diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py
--- a/swh/provenance/cli.py
+++ b/swh/provenance/cli.py
@@ -127,16 +127,17 @@
     """Process a provided list of revisions."""
     from . import get_archive, get_provenance
     from .provenance import revision_add
-    from .revision import FileRevisionIterator
+    from .revision import CSVRevisionIterator
 
     archive = get_archive(**ctx.obj["config"]["archive"])
     provenance = get_provenance(**ctx.obj["config"]["provenance"])
-    revisions = FileRevisionIterator(filename, archive, limit=limit)
-
-    while True:
-        revision = revisions.next()
-        if revision is None:
-            break
-        revision_add(provenance, archive, revision)
+    # Keep the CSV file open only while the lazy iterator below consumes it.
+    with open(filename, "r") as csv_file:
+        revisions_provider = (
+            line.strip().split(",") for line in csv_file if line.strip()
+        )
+        revisions = CSVRevisionIterator(revisions_provider, archive, limit=limit)
+        for revision in revisions:
+            revision_add(provenance, archive, revision)
 
 
diff --git a/swh/provenance/revision.py b/swh/provenance/revision.py
--- a/swh/provenance/revision.py
+++ b/swh/provenance/revision.py
@@ -1,6 +1,7 @@
 from datetime import datetime
+from itertools import islice
 import threading
-from typing import Optional
+from typing import Iterable, Iterator, Optional, Tuple, Union
 
 from swh.model.hashutil import hash_to_bytes
 
@@ -44,45 +45,60 @@
 ########################################################################################
 
 
-class RevisionIterator:
-    """Iterator interface."""
-
-    def __iter__(self):
-        pass
-
-    def __next__(self):
-        pass
-
-
-class FileRevisionIterator(RevisionIterator):
-    """Iterator over revisions present in the given CSV file."""
+# A row as consumed by CSVRevisionIterator: (revision id, author date, root
+# directory id). Ids may be hex strings or raw bytes; the date may be an ISO 8601
+# string or an already-parsed datetime.
+CSVRevisionRow = Tuple[Union[bytes, str], Union[datetime, str], Union[bytes, str]]
+
+
+class CSVRevisionIterator:
+    """Iterator over revisions typically read from a CSV file.
+
+    The input is an iterable that produces 3 elements per row:
+
+      (id, date, root)
+
+    where:
+    - id: is the id (sha1_git) of the revision, as a hex string or bytes
+    - date: is the author date, as an ISO 8601 string or a datetime
+    - root: is the id of the revision's root directory, as a hex string or bytes
+
+    Advancing the iterator is serialized by a lock, so a single instance can be
+    consumed from several threads.
+    """
 
     def __init__(
-        self, filename: str, archive: ArchiveInterface, limit: Optional[int] = None
+        self,
+        revisions: Iterable[CSVRevisionRow],
+        archive: ArchiveInterface,
+        limit: Optional[int] = None,
     ):
-        self.file = open(filename)
-        self.idx = 0
-        self.limit = limit
+        self.revisions: Iterator[CSVRevisionRow]
+        if limit is not None:
+            # Lazily stop after `limit` rows without materializing the input.
+            self.revisions = islice(revisions, limit)
+        else:
+            self.revisions = iter(revisions)
         self.mutex = threading.Lock()
         self.archive = archive
 
-    def next(self):
-        self.mutex.acquire()
-        line = self.file.readline().strip()
-        if line and (self.limit is None or self.idx < self.limit):
-            self.idx = self.idx + 1
-            id, date, root = line.strip().split(",")
-            self.mutex.release()
-
-            return RevisionEntry(
-                self.archive,
-                hash_to_bytes(id),
-                date=datetime.fromisoformat(date),
-                root=hash_to_bytes(root),
-            )
-        else:
-            self.mutex.release()
-            return None
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        # Only one thread at a time may advance the shared input iterator. Once
+        # it is exhausted, StopIteration propagates out of the `with` block
+        # (releasing the lock) and ends the iteration.
+        with self.mutex:
+            rev_id, date, root = next(self.revisions)
+            if not isinstance(date, datetime):
+                date = datetime.fromisoformat(date)
+            return RevisionEntry(
+                self.archive,
+                hash_to_bytes(rev_id),
+                date=date,
+                root=hash_to_bytes(root),
+            )
 
 
 # class ArchiveRevisionIterator(RevisionIterator):
diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -12,6 +12,8 @@
 from swh.core.utils import numfile_sortkey as sortkey
 from swh.model.tests.swh_model_data import TEST_OBJECTS
 import swh.provenance
+from swh.provenance.postgresql.archive import ArchivePostgreSQL
+from swh.provenance.storage.archive import ArchiveStorage
 
 SQL_DIR = path.join(path.dirname(swh.provenance.__file__), "sql")
 SQL_FILES = [
@@ -55,3 +57,15 @@
     ):
         getattr(swh_storage, f"{obj_type}_add")(TEST_OBJECTS[obj_type])
     return swh_storage
+
+
+@pytest.fixture
+def archive_direct(swh_storage_with_objects):
+    """Archive reading directly from the test storage's database connection."""
+    return ArchivePostgreSQL(swh_storage_with_objects.get_db().conn)
+
+
+@pytest.fixture
+def archive_api(swh_storage_with_objects):
+    """Archive going through the storage object of the test storage."""
+    return ArchiveStorage(swh_storage_with_objects)
diff --git a/swh/provenance/tests/test_revision_iterator.py b/swh/provenance/tests/test_revision_iterator.py
new file mode 100644
--- /dev/null
+++ b/swh/provenance/tests/test_revision_iterator.py
@@ -0,0 +1,37 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+import datetime
+
+from swh.model.model import TimestampWithTimezone
+from swh.model.tests.swh_model_data import TEST_OBJECTS
+from swh.provenance.revision import CSVRevisionIterator
+
+
+def ts_to_dt(ts_with_tz: TimestampWithTimezone) -> datetime.datetime:
+    """Convert a TimestampWithTimezone into a UTC datetime (offset dropped)"""
+    ts = ts_with_tz.timestamp
+    timestamp = datetime.datetime.fromtimestamp(ts.seconds, datetime.timezone.utc)
+    timestamp = timestamp.replace(microsecond=ts.microseconds)
+    return timestamp
+
+
+def test_archive_direct_revision_iterator(swh_storage_with_objects, archive_direct):
+    """Test CSVRevisionIterator"""
+    revisions_csv = [
+        (rev.id, ts_to_dt(rev.date).isoformat(), rev.directory)
+        for rev in TEST_OBJECTS["revision"]
+    ]
+    revisions = list(CSVRevisionIterator(revisions_csv, archive_direct))
+    assert revisions
+    assert len(revisions) == len(TEST_OBJECTS["revision"])
+
+
+def test_archive_direct_revision_iterator_limit(archive_direct):
+    """CSVRevisionIterator yields at most `limit` revisions (datetime dates)"""
+    revisions_csv = [
+        (rev.id, ts_to_dt(rev.date), rev.directory) for rev in TEST_OBJECTS["revision"]
+    ]
+    revisions = list(CSVRevisionIterator(revisions_csv, archive_direct, limit=1))
+    assert len(revisions) == 1