diff --git a/swh/provenance/origin.py b/swh/provenance/origin.py
--- a/swh/provenance/origin.py
+++ b/swh/provenance/origin.py
@@ -63,36 +63,38 @@
     for origin in origins:
         for visit in archive.iter_origin_visits(origin.url):
             for status in archive.iter_origin_visit_statuses(origin.url, visit.visit):
-                # TODO: may filter only those whose status is 'full'??
-                targets = []
-                releases = []
-
                 snapshot = archive.snapshot_get_all_branches(status.snapshot)
+                if snapshot is None:
+                    continue
+                # TODO: may filter only those whose status is 'full'??
+                targets_set = set()
+                releases_set = set()
                 if snapshot is not None:
                     for branch in snapshot.branches:
                         if snapshot.branches[branch].target_type == TargetType.REVISION:
-                            targets.append(snapshot.branches[branch].target)
-
+                            targets_set.add(snapshot.branches[branch].target)
                         elif (
                             snapshot.branches[branch].target_type == TargetType.RELEASE
                         ):
-                            releases.append(snapshot.branches[branch].target)
+                            releases_set.add(snapshot.branches[branch].target)
 
                 # This is done to keep the query in release_get small, hence avoiding
                 # a timeout.
-                limit = 100
-                for i in range(0, len(releases), limit):
-                    for release in archive.release_get(releases[i : i + limit]):
+                batchsize = 100
+                releases = list(releases_set)
+                while releases:
+                    for release in archive.release_get(releases[:batchsize]):
                         if release is not None:
                             if release.target_type == ObjectType.REVISION:
-                                targets.append(release.target)
+                                targets_set.add(release.target)
+                    releases[:batchsize] = []
 
                 # This is done to keep the query in revision_get small, hence avoiding
                 # a timeout.
-                revisions = []
-                limit = 100
-                for i in range(0, len(targets), limit):
-                    for revision in archive.revision_get(targets[i : i + limit]):
+                revisions = set()
+                targets = list(targets_set)
+                while targets:
+                    for revision in archive.revision_get(targets[:batchsize]):
                         if revision is not None:
                             parents = list(
                                 map(
@@ -100,12 +102,13 @@
                                     revision.parents,
                                 )
                             )
-                            revisions.append(
+                            revisions.add(
                                 RevisionEntry(archive, revision.id, parents=parents)
                             )
+                    targets[:batchsize] = []
 
-                yield OriginEntry(status.origin, revisions)
+                yield OriginEntry(status.origin, list(revisions))
 
-                idx = idx + 1
+                idx += 1
                 if idx == limit:
                     return
diff --git a/swh/provenance/tests/test_origin_iterator.py b/swh/provenance/tests/test_origin_iterator.py
new file mode 100644
--- /dev/null
+++ b/swh/provenance/tests/test_origin_iterator.py
@@ -0,0 +1,24 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+import pytest
+
+from swh.model.tests.swh_model_data import TEST_OBJECTS
+from swh.provenance.origin import ArchiveOriginIterator
+
+
+def test_archive_direct_origin_iterator(swh_storage_with_objects, archive_direct):
+    """Test ArchiveOriginIterator against the ArchivePostgreSQL"""
+    # XXX
+    pytest.xfail("Iterate Origins is currently unsupported by ArchivePostgreSQL")
+    origins = list(ArchiveOriginIterator(archive_direct))
+    assert origins
+    assert len(origins) == len(TEST_OBJECTS["origin"])
+
+
+def test_archive_api_origin_iterator(swh_storage_with_objects, archive_api):
+    """Test ArchiveOriginIterator against the ArchiveStorage"""
+    origins = list(ArchiveOriginIterator(archive_api))
+    assert origins
+    assert len(origins) == len(TEST_OBJECTS["origin"])