diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py --- a/swh/provenance/cli.py +++ b/swh/provenance/cli.py @@ -137,7 +137,7 @@ revisions_provider = ( line.strip().split(",") for line in open(filename, "r") if line.strip() ) - revisions = CSVRevisionIterator(revisions_provider, archive, limit=limit) + revisions = CSVRevisionIterator(revisions_provider, limit=limit) for revision in revisions: revision_add( @@ -157,14 +157,18 @@ def iter_origins(ctx, filename, limit): """Process a provided list of origins.""" from . import get_archive, get_provenance - from .origin import FileOriginIterator + from .origin import CSVOriginIterator from .provenance import origin_add archive = get_archive(**ctx.obj["config"]["archive"]) provenance = get_provenance(**ctx.obj["config"]["provenance"]) + origins_provider = ( + line.strip().split(",") for line in open(filename, "r") if line.strip() + ) + origins = CSVOriginIterator(origins_provider, limit=limit) - for origin in FileOriginIterator(filename, archive, limit=limit): - origin_add(archive, provenance, origin) + for origin in origins: + origin_add(provenance, archive, [origin]) @cli.command(name="find-first") diff --git a/swh/provenance/model.py b/swh/provenance/model.py --- a/swh/provenance/model.py +++ b/swh/provenance/model.py @@ -4,16 +4,64 @@ # See top-level LICENSE file for more information from datetime import datetime -from typing import Iterable, Iterator, List, Optional +from typing import Iterable, Iterator, List, Optional, Set + +from swh.core.utils import grouper +from swh.model.model import ObjectType, TargetType from .archive import ArchiveInterface class OriginEntry: - def __init__(self, url, revisions: Iterable["RevisionEntry"], id=None): - self.id = id + def __init__( + self, url: str, date: datetime, snapshot: bytes, id: Optional[int] = None + ): self.url = url - self.revisions = revisions + self.date = date + self.snapshot = snapshot + self.id = id + self._revisions: Optional[List[RevisionEntry]] = None + 
+ def retrieve_revisions(self, archive: ArchiveInterface): + if self._revisions is None: + snapshot = archive.snapshot_get_all_branches(self.snapshot) + assert snapshot is not None + targets_set = set() + releases_set = set() + if snapshot is not None: + for branch in snapshot.branches: + if snapshot.branches[branch].target_type == TargetType.REVISION: + targets_set.add(snapshot.branches[branch].target) + elif snapshot.branches[branch].target_type == TargetType.RELEASE: + releases_set.add(snapshot.branches[branch].target) + + batchsize = 100 + for releases in grouper(releases_set, batchsize): + targets_set.update( + release.target + for release in archive.release_get(releases) + if release is not None + and release.target_type == ObjectType.REVISION + ) + + revisions: Set[RevisionEntry] = set() + for targets in grouper(targets_set, batchsize): + revisions.update( + RevisionEntry(revision.id) + for revision in archive.revision_get(targets) + if revision is not None + ) + + self._revisions = list(revisions) + + @property + def revisions(self) -> Iterator["RevisionEntry"]: + if self._revisions is None: + raise RuntimeError( + "Revisions of this node has not yet been retrieved. " + "Please call retrieve_revisions() before using this property." 
+ ) + return (x for x in self._revisions) class RevisionEntry: @@ -35,7 +83,7 @@ if self._parents is None: revision = archive.revision_get([self.id]) if revision: - self._parents = revision[0].parents + self._parents = list(revision)[0].parents if self._parents and not self._nodes: self._nodes = [ RevisionEntry( diff --git a/swh/provenance/origin.py b/swh/provenance/origin.py --- a/swh/provenance/origin.py +++ b/swh/provenance/origin.py @@ -1,91 +1,41 @@ -from typing import List, Optional +from datetime import datetime, timezone +from itertools import islice +from typing import Iterable, Iterator, Optional, Tuple -from swh.model.model import ObjectType, Origin, TargetType +import iso8601 -from .archive import ArchiveInterface -from .model import OriginEntry, RevisionEntry +from .model import OriginEntry ################################################################################ ################################################################################ -class FileOriginIterator: - """Iterator over origins present in the given CSV file.""" +class CSVOriginIterator: + """Iterator over origin visit statuses typically present in the given CSV + file. 
- def __init__( - self, filename: str, archive: ArchiveInterface, limit: Optional[int] = None - ): - self.file = open(filename) - self.limit = limit - self.archive = archive - - def __iter__(self): - yield from iterate_statuses( - [Origin(url.strip()) for url in self.file], self.archive, self.limit - ) + The input is an iterator that produces 3 elements per row: + (url, date, snap) -class ArchiveOriginIterator: - """Iterator over origins present in the given storage.""" + where: + - url: is the origin url of the visit + - date: is the date of the visit + - snap: sha1_git of the snapshot pointed by the visit status + """ - def __init__(self, archive: ArchiveInterface, limit: Optional[int] = None): - self.limit = limit - self.archive = archive + def __init__( + self, + statuses: Iterable[Tuple[str, datetime, bytes]], + limit: Optional[int] = None, + ): + self.statuses: Iterator[Tuple[str, datetime, bytes]] + if limit is not None: + self.statuses = islice(statuses, limit) + else: + self.statuses = iter(statuses) def __iter__(self): - yield from iterate_statuses( - self.archive.iter_origins(), self.archive, self.limit - ) - - -def iterate_statuses( - origins: List[Origin], archive: ArchiveInterface, limit: Optional[int] = None -): - idx = 0 - for origin in origins: - for visit in archive.iter_origin_visits(origin.url): - for status in archive.iter_origin_visit_statuses(origin.url, visit.visit): - snapshot = archive.snapshot_get_all_branches(status.snapshot) - if snapshot is None: - continue - # TODO: may filter only those whose status is 'full'?? 
- targets_set = set() - releases_set = set() - if snapshot is not None: - for branch in snapshot.branches: - if snapshot.branches[branch].target_type == TargetType.REVISION: - targets_set.add(snapshot.branches[branch].target) - elif ( - snapshot.branches[branch].target_type == TargetType.RELEASE - ): - releases_set.add(snapshot.branches[branch].target) - - # This is done to keep the query in release_get small, hence avoiding - # a timeout. - batchsize = 100 - while releases_set: - releases = [ - releases_set.pop() for i in range(batchsize) if releases_set - ] - for release in archive.release_get(releases): - if release is not None: - if release.target_type == ObjectType.REVISION: - targets_set.add(release.target) - - # This is done to keep the query in revision_get small, hence avoiding - # a timeout. - revisions = set() - while targets_set: - targets = [ - targets_set.pop() for i in range(batchsize) if targets_set - ] - for revision in archive.revision_get(targets): - if revision is not None: - revisions.add(RevisionEntry(revision.id)) - # target_set |= set(revision.parents) - - yield OriginEntry(status.origin, list(revisions)) - - idx += 1 - if idx == limit: - return + for url, date, snap in self.statuses: + date = iso8601.parse_date(date, default_timezone=timezone.utc) + yield OriginEntry(url, date, snap) diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -133,20 +133,30 @@ def origin_add( - archive: ArchiveInterface, provenance: ProvenanceInterface, origin: OriginEntry + provenance: ProvenanceInterface, + archive: ArchiveInterface, + origins: List[OriginEntry], ) -> None: - # TODO: refactor to iterate over origin visit statuses and commit only once - # per status. 
- origin.id = provenance.origin_get_id(origin) - for revision in origin.revisions: - origin_add_revision(archive, provenance, origin, revision) - # Commit after each revision - provenance.commit() # TODO: verify this! + start = time.time() + for origin in origins: + origin.retrieve_revisions(archive) + for revision in origin.revisions: + origin_add_revision(provenance, archive, origin, revision) + done = time.time() + provenance.commit() + stop = time.time() + logging.debug( + "Origins " + ";".join( + [origin.url + ":" + hash_to_hex(origin.snapshot) for origin in origins] + ) + + f" were processed in {stop - start} secs (commit took {stop - done} secs)!" + ) def origin_add_revision( - archive: ArchiveInterface, provenance: ProvenanceInterface, + archive: ArchiveInterface, origin: OriginEntry, revision: RevisionEntry, ) -> None: diff --git a/swh/provenance/revision.py b/swh/provenance/revision.py --- a/swh/provenance/revision.py +++ b/swh/provenance/revision.py @@ -1,12 +1,10 @@ from datetime import datetime, timezone from itertools import islice -import threading from typing import Iterable, Iterator, Optional, Tuple import iso8601 from swh.model.hashutil import hash_to_bytes -from swh.provenance.archive import ArchiveInterface from swh.provenance.model import RevisionEntry ######################################################################################## @@ -29,7 +27,6 @@ def __init__( self, revisions: Iterable[Tuple[bytes, datetime, bytes]], - archive: ArchiveInterface, limit: Optional[int] = None, ): self.revisions: Iterator[Tuple[bytes, datetime, bytes]] @@ -37,20 +34,17 @@ self.revisions = islice(revisions, limit) else: self.revisions = iter(revisions) - self.mutex = threading.Lock() - self.archive = archive def __iter__(self): return self def __next__(self): - with self.mutex: - id, date, root = next(self.revisions) - date = iso8601.parse_date(date) - if date.tzinfo is None: - date = date.replace(tzinfo=timezone.utc) - return RevisionEntry( - 
hash_to_bytes(id), - date=date, - root=hash_to_bytes(root), - ) + id, date, root = next(self.revisions) + date = iso8601.parse_date(date) + if date.tzinfo is None: + date = date.replace(tzinfo=timezone.utc) + return RevisionEntry( + hash_to_bytes(id), + date=date, + root=hash_to_bytes(root), + ) diff --git a/swh/provenance/tests/test_origin_iterator.py b/swh/provenance/tests/test_origin_iterator.py --- a/swh/provenance/tests/test_origin_iterator.py +++ b/swh/provenance/tests/test_origin_iterator.py @@ -2,23 +2,36 @@ # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import pytest from swh.model.tests.swh_model_data import TEST_OBJECTS -from swh.provenance.origin import ArchiveOriginIterator +from swh.provenance.origin import CSVOriginIterator +from swh.storage.algos.origin import ( + iter_origins, + iter_origin_visits, + iter_origin_visit_statuses, +) -def test_archive_direct_origin_iterator(swh_storage_with_objects, archive_direct): - """Test ArchiveOriginIterator against the ArchivePostgreSQL""" - # XXX - pytest.xfail("Iterate Origins is currently unsupported by ArchivePostgreSQL") - origins = list(ArchiveOriginIterator(archive_direct)) +def test_origin_iterator(swh_storage_with_objects): + """Test CSVOriginIterator""" + origins_csv = [] + for origin in iter_origins(swh_storage_with_objects): + for visit in iter_origin_visits(swh_storage_with_objects, origin.url): + for status in iter_origin_visit_statuses( + swh_storage_with_objects, origin.url, visit.visit + ): + if status.snapshot is not None: + origins_csv.append( + (status.origin, status.date.isoformat(), status.snapshot) + ) + origins = list(CSVOriginIterator(origins_csv)) assert origins - assert len(origins) == len(TEST_OBJECTS["origin"]) - - -def test_archive_api_origin_iterator(swh_storage_with_objects, archive_api): - """Test ArchiveOriginIterator against the 
ArchiveStorage""" - origins = list(ArchiveOriginIterator(archive_api)) - assert origins - assert len(origins) == len(TEST_OBJECTS["origin"]) + assert len(origins) == len( + list( + { + status.origin + for status in TEST_OBJECTS["origin_visit_status"] + if status.snapshot is not None + } + ) + ) diff --git a/swh/provenance/tests/test_provenance_db.py b/swh/provenance/tests/test_provenance_db.py --- a/swh/provenance/tests/test_provenance_db.py +++ b/swh/provenance/tests/test_provenance_db.py @@ -20,8 +20,12 @@ def test_provenance_origin_add(provenance, swh_storage_with_objects): - """Test the ProvenanceDB.origin_add() method""" - for origin in TEST_OBJECTS["origin"]: - entry = OriginEntry(url=origin.url, revisions=[]) - origin_add(ArchiveStorage(swh_storage_with_objects), provenance, entry) + """Test the origin_add function""" + archive = ArchiveStorage(swh_storage_with_objects) + for status in TEST_OBJECTS["origin_visit_status"]: + if status.snapshot is not None: + entry = OriginEntry( + url=status.origin, date=status.date, snapshot=status.snapshot + ) + origin_add(provenance, archive, [entry]) # TODO: check some facts here diff --git a/swh/provenance/tests/test_revision_iterator.py b/swh/provenance/tests/test_revision_iterator.py --- a/swh/provenance/tests/test_revision_iterator.py +++ b/swh/provenance/tests/test_revision_iterator.py @@ -17,7 +17,7 @@ "out-of-order", ), ) -def test_archive_direct_revision_iterator(swh_storage, archive_direct, repo): +def test_archive_direct_revision_iterator(swh_storage, repo): """Test CSVRevisionIterator""" data = load_repo_data(repo) fill_storage(swh_storage, data) @@ -25,6 +25,6 @@ (rev["id"], ts2dt(rev["date"]).isoformat(), rev["directory"]) for rev in data["revision"] ] - revisions = list(CSVRevisionIterator(revisions_csv, archive_direct)) + revisions = list(CSVRevisionIterator(revisions_csv)) assert revisions assert len(revisions) == len(data["revision"])