diff --git a/.gitignore b/.gitignore
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,4 @@
 *.egg-info
 version.txt
 .vscode/
+.hypothesis/
diff --git a/requirements-test.txt b/requirements-test.txt
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1 +1,2 @@
+hypothesis
 nose
diff --git a/swh/storage/algos/snapshot.py b/swh/storage/algos/snapshot.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/algos/snapshot.py
@@ -0,0 +1,39 @@
+# Copyright (C) 2018 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+def snapshot_get_all_branches(storage, snapshot_id):
+    """Get all the branches for a given snapshot
+
+    The first page comes from ``storage.snapshot_get``; as long as a
+    ``next_branch`` marker is present, further pages are fetched with
+    ``storage.snapshot_get_branches`` and merged into the result.
+
+    Args:
+        storage (swh.storage.storage.Storage): the storage instance
+        snapshot_id (bytes): the snapshot's identifier
+    Returns:
+        dict: a dict with two keys:
+            * **id**: identifier of the snapshot
+            * **branches**: a dict of branches contained in the snapshot
+              whose keys are the branches' names.
+
+        ``None`` when ``storage.snapshot_get`` returns a falsy value
+        (presumably an unknown snapshot).
+    """
+    ret = storage.snapshot_get(snapshot_id)
+
+    if not ret:
+        return None
+
+    # The pagination marker is removed so it is not part of the result.
+    next_branch = ret.pop('next_branch', None)
+    while next_branch:
+        data = storage.snapshot_get_branches(snapshot_id,
+                                             branches_from=next_branch)
+        ret['branches'].update(data['branches'])
+        next_branch = data.get('next_branch')
+
+    return ret
diff --git a/swh/storage/tests/algos/test_snapshot.py b/swh/storage/tests/algos/test_snapshot.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/algos/test_snapshot.py
@@ -0,0 +1,144 @@
+# Copyright (C) 2018 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import unittest
+
+from hypothesis import given, settings
+from hypothesis.strategies import (binary, composite, datetimes, dictionaries,
+                                   from_regex, none, one_of, sampled_from)
+
+from swh.model.identifiers import snapshot_identifier, identifier_to_bytes
+from swh.storage.tests.storage_testing import StorageTestFixture
+
+from swh.storage.algos.snapshot import snapshot_get_all_branches
+
+
+def branch_names():
+    """Strategy producing branch names: 5 to 10 arbitrary bytes."""
+    return binary(min_size=5, max_size=10)
+
+
+@composite
+def branch_targets_object(draw):
+    """Strategy producing a non-alias branch with a 20-byte target."""
+    return {
+        'target': draw(binary(min_size=20, max_size=20)),
+        'target_type': draw(
+            sampled_from([
+                'content', 'directory', 'revision', 'release', 'snapshot',
+            ])
+        ),
+    }
+
+
+@composite
+def branch_targets_alias(draw):
+    """Strategy producing an alias branch that targets a branch name."""
+    return {
+        'target': draw(branch_names()),
+        'target_type': 'alias',
+    }
+
+
+def branch_targets(*, only_objects=False):
+    """Strategy producing branch targets.
+
+    With ``only_objects`` set, only object targets are produced; otherwise
+    a target may also be ``None`` or an alias.
+    """
+    if only_objects:
+        return branch_targets_object()
+    else:
+        return one_of(none(), branch_targets_alias(), branch_targets_object())
+
+
+@composite
+def snapshots(draw, *, min_size=0, max_size=100, only_objects=False):
+    """Strategy producing snapshot dicts with ``branches`` and ``id`` keys.
+
+    ``min_size`` and ``max_size`` bound the number of branches drawn up
+    front; branches added afterwards to resolve dangling aliases come on
+    top of that, so the result can hold more than ``max_size`` branches.
+    """
+    branches = draw(dictionaries(
+        keys=branch_names(),
+        values=branch_targets(only_objects=only_objects),
+        min_size=min_size,
+        max_size=max_size,
+    ))
+
+    if not only_objects:
+        # Make sure aliases point to actual branches
+        unresolved_aliases = {
+            target['target']
+            for target in branches.values()
+            if (target
+                and target['target_type'] == 'alias'
+                and target['target'] not in branches)
+        }
+
+        for alias in unresolved_aliases:
+            branches[alias] = draw(branch_targets(only_objects=True))
+
+    ret = {
+        'branches': branches,
+    }
+    ret['id'] = identifier_to_bytes(snapshot_identifier(ret))
+    return ret
+
+
+@composite
+def urls(draw):
+    """Strategy producing ``<protocol>://<domain>`` strings."""
+    protocol = draw(sampled_from(['git', 'http', 'https', 'deb']))
+    domain = draw(from_regex(r'\A([a-z]([a-z0-9-]*)\.){1,3}[a-z0-9]+\Z'))
+
+    return '%s://%s' % (protocol, domain)
+
+
+@composite
+def origins(draw):
+    """Strategy producing origin dicts with ``type`` and ``url`` keys."""
+    return {
+        'type': draw(sampled_from(['git', 'hg', 'svn', 'pypi', 'deb'])),
+        'url': draw(urls()),
+    }
+
+
+class TestSnapshotAllBranches(StorageTestFixture, unittest.TestCase):
+    """Check snapshot_get_all_branches returns stored snapshots unchanged."""
+
+    @given(origins(), datetimes(), snapshots(min_size=0, max_size=10,
+                                             only_objects=False))
+    def test_snapshot_small(self, origin, ts, snapshot):
+        origin_id = self.storage.origin_add_one(origin)
+        visit = self.storage.origin_visit_add(origin_id, ts)
+        self.storage.snapshot_add(origin_id, visit['visit'], snapshot)
+
+        returned_snapshot = snapshot_get_all_branches(self.storage,
+                                                      snapshot['id'])
+        self.assertEqual(snapshot, returned_snapshot)
+
+    @settings(max_examples=5, deadline=1000)
+    @given(origins(), datetimes(),
+           branch_names(), branch_targets(only_objects=True))
+    def test_snapshot_large(self, origin, ts, branch_name, branch_target):
+        origin_id = self.storage.origin_add_one(origin)
+        visit = self.storage.origin_visit_add(origin_id, ts)
+
+        # NOTE(review): presumably enough branches to need several pages
+        snapshot = {
+            'branches': {
+                b'%s%05d' % (branch_name, i): branch_target
+                for i in range(10000)
+            }
+        }
+        snapshot['id'] = identifier_to_bytes(snapshot_identifier(snapshot))
+
+        self.storage.snapshot_add(origin_id, visit['visit'], snapshot)
+
+        returned_snapshot = snapshot_get_all_branches(self.storage,
+                                                      snapshot['id'])
+        self.assertEqual(snapshot, returned_snapshot)