diff --git a/requirements-swh.txt b/requirements-swh.txt --- a/requirements-swh.txt +++ b/requirements-swh.txt @@ -1,4 +1,4 @@ swh.core >= 0.0.56 -swh.model >= 0.0.27 +swh.model >= 0.0.32 swh.objstorage >= 0.0.17 swh.scheduler >= 0.0.14 diff --git a/swh/storage/tests/algos/test_snapshot.py b/swh/storage/tests/algos/test_snapshot.py --- a/swh/storage/tests/algos/test_snapshot.py +++ b/swh/storage/tests/algos/test_snapshot.py @@ -8,105 +8,24 @@ import pytest from hypothesis import given -from hypothesis.strategies import (binary, composite, datetimes, dictionaries, - from_regex, none, one_of, sampled_from) +from hypothesis.strategies import datetimes from swh.model.identifiers import snapshot_identifier, identifier_to_bytes +from swh.model.hypothesis_strategies import \ + origins, snapshots, branch_names, branch_targets from swh.storage.tests.storage_testing import StorageTestFixture from swh.storage.algos.snapshot import snapshot_get_all_branches -def branch_names(): - return binary(min_size=5, max_size=10) - - -@composite -def branch_targets_object(draw): - return { - 'target': draw(binary(min_size=20, max_size=20)), - 'target_type': draw( - sampled_from([ - 'content', 'directory', 'revision', 'release', 'snapshot', - ]) - ), - } - - -@composite -def branch_targets_alias(draw): - return { - 'target': draw(branch_names()), - 'target_type': 'alias', - } - - -def branch_targets(*, only_objects=False): - if only_objects: - return branch_targets_object() - else: - return one_of(none(), branch_targets_alias(), branch_targets_object()) - - -@composite -def snapshots(draw, *, min_size=0, max_size=100, only_objects=False): - branches = draw(dictionaries( - keys=branch_names(), - values=branch_targets(only_objects=only_objects), - min_size=min_size, - max_size=max_size, - )) - - if not only_objects: - # Make sure aliases point to actual branches - unresolved_aliases = { - target['target'] - for target in branches.values() - if (target - and target['target_type'] == 'alias' - and target['target'] not in branches) - } - - for alias in unresolved_aliases: - branches[alias] = draw(branch_targets(only_objects=True)) - - ret = { - 'branches': branches, - } - while True: - try: - id_ = snapshot_identifier(ret) - except ValueError as e: - for (source, target) in e.args[1]: - ret[source] = draw(branch_targets(only_objects=True)) - else: - break - ret['id'] = identifier_to_bytes(id_) - return ret - - -@composite -def urls(draw): - protocol = draw(sampled_from(['git', 'http', 'https', 'deb'])) - domain = draw(from_regex(r'\A([a-z]([a-z0-9-]*)\.){1,3}[a-z0-9]+\Z')) - - return '%s://%s' % (protocol, domain) - - -@composite -def origins(draw): - return { - 'type': draw(sampled_from(['git', 'hg', 'svn', 'pypi', 'deb'])), - 'url': draw(urls()), - } - - @pytest.mark.db @pytest.mark.property_based class TestSnapshotAllBranches(StorageTestFixture, unittest.TestCase): - @given(origins(), datetimes(), snapshots(min_size=0, max_size=10, - only_objects=False)) + @given(origins().map(lambda x: x.to_dict()), + datetimes(), + snapshots(min_size=0, max_size=10, only_objects=False)) def test_snapshot_small(self, origin, ts, snapshot): + snapshot = snapshot.to_dict() origin_id = self.storage.origin_add_one(origin) visit = self.storage.origin_visit_add(origin_id, ts) self.storage.snapshot_add(origin_id, visit['visit'], snapshot) @@ -115,7 +34,8 @@ snapshot['id']) self.assertEqual(snapshot, returned_snapshot) - @given(origins(), datetimes(), + @given(origins().map(lambda x: x.to_dict()), + datetimes(), branch_names(), branch_targets(only_objects=True)) def test_snapshot_large(self, origin, ts, branch_name, branch_target): origin_id = self.storage.origin_add_one(origin) diff --git a/swh/storage/tests/generate_data_test.py b/swh/storage/tests/generate_data_test.py --- a/swh/storage/tests/generate_data_test.py +++ b/swh/storage/tests/generate_data_test.py @@ -3,14 +3,10 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import random - -from hypothesis.strategies import (binary, composite, just, sets) +from hypothesis.strategies import (binary, composite, sets) from swh.model.hashutil import MultiHash -from swh.storage.tests.algos.test_snapshot import origins - def gen_raw_content(): """Generate raw content binary. @@ -51,31 +47,3 @@ }) return contents - - -def gen_origins(min_size=10, max_size=100, unique=True): - """Generate a list of origins. - - Args: - **min_size** (int): Minimal number of elements to generate - (default: 10) - **max_size** (int): Maximal number of elements to generate - (default: 100) - **unique** (bool): Specify if all generated origins must be unique - - Returns: - [dict] representing origins. The list's size is between - [min_size:max_size]. - """ - size = random.randint(min_size, max_size) - new_origins = [] - origins_set = set() - while len(new_origins) != size: - new_origin = origins().example() - if unique: - key = (new_origin['type'], new_origin['url']) - if key in origins_set: - continue - origins_set.add(key) - new_origins.append(new_origin) - return just(new_origins) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -17,10 +17,11 @@ from swh.model import from_disk, identifiers from swh.model.hashutil import hash_to_bytes +from swh.model.hypothesis_strategies import origins from swh.storage.tests.storage_testing import StorageTestFixture from swh.storage import HashCollision -from .generate_data_test import gen_contents, gen_origins +from .generate_data_test import gen_contents @pytest.mark.db @@ -2927,14 +2928,16 @@ origin_visits = list(self.storage.origin_visit_get(1)) self.assertEqual(origin_visits, []) - @given(gen_origins(min_size=100, max_size=100)) + @given(strategies.sets(origins().map(lambda x: tuple(x.to_dict().items())), + min_size=20)) def test_origin_get_range(self, new_origins): + new_origins = list(map(dict, new_origins)) nb_origins = len(new_origins) self.storage.origin_add(new_origins) - origin_from = random.randint(1, nb_origins) + origin_from = random.randint(1, nb_origins-1) origin_count = random.randint(1, nb_origins - origin_from) expected_origins = []