diff --git a/swh/provenance/origin.py b/swh/provenance/origin.py index 660e0f4..7b2964d 100644 --- a/swh/provenance/origin.py +++ b/swh/provenance/origin.py @@ -1,111 +1,114 @@ from typing import Optional from swh.model.model import ObjectType, Origin, TargetType from .archive import ArchiveInterface from .revision import RevisionEntry class OriginEntry: def __init__(self, url, revisions, id=None): self.id = id self.url = url self.revisions = revisions ################################################################################ ################################################################################ class OriginIterator: """Iterator interface.""" def __iter__(self): pass def __next__(self): pass class FileOriginIterator(OriginIterator): """Iterator over origins present in the given CSV file.""" def __init__( self, filename: str, archive: ArchiveInterface, limit: Optional[int] = None ): self.file = open(filename) self.limit = limit # self.mutex = threading.Lock() self.archive = archive def __iter__(self): yield from iterate_statuses( [Origin(url.strip()) for url in self.file], self.archive, self.limit ) class ArchiveOriginIterator: """Iterator over origins present in the given storage.""" def __init__(self, archive: ArchiveInterface, limit: Optional[int] = None): self.limit = limit # self.mutex = threading.Lock() self.archive = archive def __iter__(self): yield from iterate_statuses( self.archive.iter_origins(), self.archive, self.limit ) def iterate_statuses(origins, archive: ArchiveInterface, limit: Optional[int] = None): idx = 0 for origin in origins: for visit in archive.iter_origin_visits(origin.url): for status in archive.iter_origin_visit_statuses(origin.url, visit.visit): - # TODO: may filter only those whose status is 'full'?? - targets = [] - releases = [] - snapshot = archive.snapshot_get_all_branches(status.snapshot) + if snapshot is None: + continue + # TODO: may filter only those whose status is 'full'?? + targets_set = set() + releases_set = set() if snapshot is not None: for branch in snapshot.branches: if snapshot.branches[branch].target_type == TargetType.REVISION: - targets.append(snapshot.branches[branch].target) - + targets_set.add(snapshot.branches[branch].target) elif ( snapshot.branches[branch].target_type == TargetType.RELEASE ): - releases.append(snapshot.branches[branch].target) + releases_set.add(snapshot.branches[branch].target) # This is done to keep the query in release_get small, hence avoiding # a timeout. - limit = 100 - for i in range(0, len(releases), limit): - for release in archive.release_get(releases[i : i + limit]): + batchsize = 100 + releases = list(releases_set) + while releases: + for release in archive.release_get(releases[:batchsize]): if release is not None: if release.target_type == ObjectType.REVISION: - targets.append(release.target) + targets_set.add(release.target) + releases[:batchsize] = [] # This is done to keep the query in revision_get small, hence avoiding # a timeout. - revisions = [] - limit = 100 - for i in range(0, len(targets), limit): - for revision in archive.revision_get(targets[i : i + limit]): + revisions = set() + targets = list(targets_set) + while targets: + for revision in archive.revision_get(targets[:batchsize]): if revision is not None: parents = list( map( lambda id: RevisionEntry(archive, id), revision.parents, ) ) - revisions.append( + revisions.add( RevisionEntry(archive, revision.id, parents=parents) ) + targets[:batchsize] = [] - yield OriginEntry(status.origin, revisions) + yield OriginEntry(status.origin, list(revisions)) - idx = idx + 1 + idx += 1 if idx == limit: return diff --git a/swh/provenance/tests/test_origin_iterator.py b/swh/provenance/tests/test_origin_iterator.py new file mode 100644 index 0000000..21cd9d1 --- /dev/null +++ b/swh/provenance/tests/test_origin_iterator.py @@ -0,0 +1,24 @@ +# Copyright (C) 2021 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information +import pytest + +from swh.model.tests.swh_model_data import TEST_OBJECTS +from swh.provenance.origin import ArchiveOriginIterator + + +def test_archive_direct_origin_iterator(swh_storage_with_objects, archive_direct): + """Test ArchiveOriginIterator against the ArchivePostgreSQL""" + # XXX + pytest.xfail("Iterate Origins is currently unsupported by ArchivePostgreSQL") + origins = list(ArchiveOriginIterator(archive_direct)) + assert origins + assert len(origins) == len(TEST_OBJECTS["origin"]) + + +def test_archive_api_origin_iterator(swh_storage_with_objects, archive_api): + """Test ArchiveOriginIterator against the ArchiveStorage""" + origins = list(ArchiveOriginIterator(archive_api)) + assert origins + assert len(origins) == len(TEST_OBJECTS["origin"])