diff --git a/swh/provenance/archive.py b/swh/provenance/archive.py --- a/swh/provenance/archive.py +++ b/swh/provenance/archive.py @@ -8,10 +8,13 @@ from typing_extensions import Protocol, runtime_checkable from swh.model.model import Sha1Git +from swh.storage.interface import StorageInterface @runtime_checkable class ArchiveInterface(Protocol): + storage: StorageInterface + def directory_ls(self, id: Sha1Git) -> Iterable[Dict[str, Any]]: """List entries for one directory. diff --git a/swh/provenance/postgresql/archive.py b/swh/provenance/postgresql/archive.py --- a/swh/provenance/postgresql/archive.py +++ b/swh/provenance/postgresql/archive.py @@ -9,13 +9,15 @@ import psycopg2.extensions from swh.model.model import Sha1Git -from swh.storage.postgresql.storage import Storage +from swh.storage import get_storage class ArchivePostgreSQL: def __init__(self, conn: psycopg2.extensions.connection) -> None: + self.storage = get_storage( + "postgresql", db=conn.dsn, objstorage={"cls": "memory"} + ) self.conn = conn - self.storage = Storage(conn, objstorage={"cls": "memory"}) def directory_ls(self, id: Sha1Git) -> Iterable[Dict[str, Any]]: entries = self._directory_ls(id) diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py --- a/swh/provenance/tests/conftest.py +++ b/swh/provenance/tests/conftest.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from datetime import datetime, timedelta, timezone from os import path from typing import Any, Dict, Iterable, Iterator @@ -12,15 +13,13 @@ import pytest from swh.journal.serializers import msgpack_ext_hook -from swh.model.tests.swh_model_data import TEST_OBJECTS from swh.provenance import get_provenance, get_provenance_storage from swh.provenance.api.client import RemoteProvenanceStorage import swh.provenance.api.server as server from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface -from swh.provenance.postgresql.archive import ArchivePostgreSQL from swh.provenance.storage.archive import ArchiveStorage -from swh.storage.postgresql.storage import Storage +from swh.storage.interface import StorageInterface from swh.storage.replay import process_replay_objects @@ -80,50 +79,9 @@ @pytest.fixture -def swh_storage_with_objects(swh_storage: Storage) -> Storage: - """return a Storage object (postgresql-based by default) with a few of each - object type in it - - The inserted content comes from swh.model.tests.swh_model_data. - """ - for obj_type in ( - "content", - "skipped_content", - "directory", - "revision", - "release", - "snapshot", - "origin", - "origin_visit", - "origin_visit_status", - ): - getattr(swh_storage, f"{obj_type}_add")(TEST_OBJECTS[obj_type]) - return swh_storage - - -@pytest.fixture -def archive_direct(swh_storage_with_objects: Storage) -> ArchiveInterface: - return ArchivePostgreSQL(swh_storage_with_objects.get_db().conn) - - -@pytest.fixture -def archive_api(swh_storage_with_objects: Storage) -> ArchiveInterface: - return ArchiveStorage(swh_storage_with_objects) - - -@pytest.fixture(params=["archive", "db"]) -def archive(request, swh_storage_with_objects: Storage) -> Iterator[ArchiveInterface]: - """Return a ArchivePostgreSQL based StorageInterface object""" - # this is a workaround to prevent tests from hanging because of an unclosed - # transaction. - # TODO: refactor the ArchivePostgreSQL to properly deal with - # transactions and get rid of this fixture - if request.param == "db": - archive = ArchivePostgreSQL(conn=swh_storage_with_objects.get_db().conn) - yield archive - archive.conn.rollback() - else: - yield ArchiveStorage(swh_storage_with_objects) +def archive(swh_storage: StorageInterface) -> ArchiveInterface: + """Return an ArchiveStorage-based ArchiveInterface object""" + return ArchiveStorage(swh_storage) def get_datafile(fname: str) -> str: @@ -149,5 +107,14 @@ return {k: v for (k, v) in d.items() if k in keys} -def fill_storage(storage: Storage, data: Dict[str, Any]) -> None: +def fill_storage(storage: StorageInterface, data: Dict[str, Any]) -> None: process_replay_objects(data, storage=storage) + + +# TODO: remove this function in favour of TimestampWithTimezone.to_datetime +# from swh.model.model +def ts2dt(ts: Dict[str, Any]) -> datetime: + timestamp = datetime.fromtimestamp( + ts["timestamp"]["seconds"], timezone(timedelta(minutes=ts["offset"])) + ) + return timestamp.replace(microsecond=ts["timestamp"]["microseconds"]) diff --git a/swh/provenance/tests/test_archive_interface.py b/swh/provenance/tests/test_archive_interface.py --- a/swh/provenance/tests/test_archive_interface.py +++ b/swh/provenance/tests/test_archive_interface.py @@ -12,6 +12,7 @@ from swh.provenance.postgresql.archive import ArchivePostgreSQL from swh.provenance.storage.archive import ArchiveStorage from swh.provenance.tests.conftest import fill_storage, load_repo_data +from swh.storage.interface import StorageInterface from swh.storage.postgresql.storage import Storage @@ -19,8 +20,9 @@ "repo", ("cmdbts2", "out-of-order", "with-merges"), ) -def test_archive_interface(repo: str, swh_storage: Storage) -> None: +def test_archive_interface(repo: str, swh_storage: StorageInterface) -> None: archive_api = ArchiveStorage(swh_storage) + assert isinstance(swh_storage, Storage) dsn = swh_storage.get_db().conn.dsn with BaseDb.connect(dsn).conn as conn: BaseDb.adapt_conn(conn) diff --git a/swh/provenance/tests/test_conftest.py b/swh/provenance/tests/test_conftest.py --- a/swh/provenance/tests/test_conftest.py +++ b/swh/provenance/tests/test_conftest.py @@ -4,7 +4,8 @@ # See top-level LICENSE file for more information from swh.provenance.interface import ProvenanceInterface -from swh.storage.postgresql.storage import Storage +from swh.provenance.tests.conftest import fill_storage, load_repo_data +from swh.storage.interface import StorageInterface def test_provenance_fixture(provenance: ProvenanceInterface) -> None: @@ -13,10 +14,13 @@ provenance.flush() # should be a noop -def test_storage(swh_storage_with_objects: Storage) -> None: - """Check the 'swh_storage_with_objects' fixture produce a working Storage +def test_fill_storage(swh_storage: StorageInterface) -> None: + """Check the 'fill_storage' test utility produces a working Storage object with at least some Content, Revision and Directory in it""" - assert swh_storage_with_objects - assert swh_storage_with_objects.content_get_random() - assert swh_storage_with_objects.directory_get_random() - assert swh_storage_with_objects.revision_get_random() + data = load_repo_data("cmdbts2") + fill_storage(swh_storage, data) + + assert swh_storage + assert swh_storage.content_get_random() + assert swh_storage.directory_get_random() + assert swh_storage.revision_get_random() diff --git a/swh/provenance/tests/test_history_graph.py b/swh/provenance/tests/test_history_graph.py --- a/swh/provenance/tests/test_history_graph.py +++ b/swh/provenance/tests/test_history_graph.py @@ -15,7 +15,6 @@ from swh.provenance.model import OriginEntry, RevisionEntry from swh.provenance.origin import origin_add_revision from swh.provenance.tests.conftest import fill_storage, get_datafile, load_repo_data -from swh.storage.postgresql.storage import Storage def history_graph_from_dict(d: Dict[str, Any]) -> HistoryNode: @@ -39,7 +38,6 @@ @pytest.mark.parametrize("batch", (True, False)) def test_history_graph( provenance: ProvenanceInterface, - swh_storage: Storage, archive: ArchiveInterface, repo: str, visit: str, @@ -47,7 +45,7 @@ ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) - fill_storage(swh_storage, data) + fill_storage(archive.storage, data) filename = f"history_graphs_{repo}_{visit}.yaml" diff --git a/swh/provenance/tests/test_isochrone_graph.py b/swh/provenance/tests/test_isochrone_graph.py --- a/swh/provenance/tests/test_isochrone_graph.py +++ b/swh/provenance/tests/test_isochrone_graph.py @@ -16,9 +16,12 @@ from swh.provenance.interface import ProvenanceInterface from swh.provenance.model import DirectoryEntry, RevisionEntry from swh.provenance.revision import revision_add -from swh.provenance.tests.conftest import fill_storage, get_datafile, load_repo_data -from swh.provenance.tests.test_provenance_db import ts2dt -from swh.storage.postgresql.storage import Storage +from swh.provenance.tests.conftest import ( + fill_storage, + get_datafile, + load_repo_data, + ts2dt, +) def isochrone_graph_from_dict(d: Dict[str, Any], depth: int = 0) -> IsochroneNode: @@ -63,7 +66,6 @@ @pytest.mark.parametrize("batch", (True, False)) def test_isochrone_graph( provenance: ProvenanceInterface, - swh_storage: Storage, archive: ArchiveInterface, repo: str, lower: bool, @@ -72,7 +74,7 @@ ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) - fill_storage(swh_storage, data) + fill_storage(archive.storage, data) revisions = {rev["id"]: rev for rev in data["revision"]} filename = f"graphs_{repo}_{'lower' if lower else 'upper'}_{mindepth}.yaml" diff --git a/swh/provenance/tests/test_origin_iterator.py b/swh/provenance/tests/test_origin_iterator.py --- a/swh/provenance/tests/test_origin_iterator.py +++ b/swh/provenance/tests/test_origin_iterator.py @@ -3,36 +3,40 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from swh.model.model import OriginVisitStatus -from swh.model.tests.swh_model_data import TEST_OBJECTS +import pytest + from swh.provenance.origin import CSVOriginIterator +from swh.provenance.tests.conftest import fill_storage, load_repo_data from swh.storage.algos.origin import ( iter_origin_visit_statuses, iter_origin_visits, iter_origins, ) -from swh.storage.postgresql.storage import Storage +from swh.storage.interface import StorageInterface -def test_origin_iterator(swh_storage_with_objects: Storage) -> None: +@pytest.mark.parametrize( + "repo", + ( + "cmdbts2", + "out-of-order", + ), +) +def test_origin_iterator(swh_storage: StorageInterface, repo: str) -> None: """Test CSVOriginIterator""" + data = load_repo_data(repo) + fill_storage(swh_storage, data) + origins_csv = [] - for origin in iter_origins(swh_storage_with_objects): - for visit in iter_origin_visits(swh_storage_with_objects, origin.url): + for origin in iter_origins(swh_storage): + for visit in iter_origin_visits(swh_storage, origin.url): if visit.visit is not None: for status in iter_origin_visit_statuses( - swh_storage_with_objects, origin.url, visit.visit + swh_storage, origin.url, visit.visit ): if status.snapshot is not None: origins_csv.append((status.origin, status.snapshot)) origins = list(CSVOriginIterator(origins_csv)) + assert origins - assert len(origins) == len( - list( - { - status.origin - for status in TEST_OBJECTS["origin_visit_status"] - if isinstance(status, OriginVisitStatus) and status.snapshot is not None - } - ) - ) + assert len(origins) == len(data["origin"]) diff --git a/swh/provenance/tests/test_origin_revision_layer.py b/swh/provenance/tests/test_origin_revision_layer.py --- a/swh/provenance/tests/test_origin_revision_layer.py +++ b/swh/provenance/tests/test_origin_revision_layer.py @@ -16,7 +16,6 @@ from swh.provenance.model import OriginEntry from swh.provenance.origin import origin_add from swh.provenance.tests.conftest import fill_storage, get_datafile, load_repo_data -from swh.storage.postgresql.storage import Storage class SynthRelation(TypedDict): @@ -119,14 +118,13 @@ ) def test_origin_revision_layer( provenance: ProvenanceInterface, - swh_storage: Storage, archive: ArchiveInterface, repo: str, visit: str, ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) - fill_storage(swh_storage, data) + fill_storage(archive.storage, data) syntheticfile = get_datafile(f"origin-revision_{repo}_{visit}.txt") origins = [ diff --git a/swh/provenance/tests/test_provenance_db.py b/swh/provenance/tests/test_provenance_db.py --- a/swh/provenance/tests/test_provenance_db.py +++ b/swh/provenance/tests/test_provenance_db.py @@ -3,38 +3,8 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from datetime import datetime, timedelta, timezone - -from swh.model.model import OriginVisitStatus -from swh.model.tests.swh_model_data import TEST_OBJECTS from swh.provenance.interface import ProvenanceInterface -from swh.provenance.model import OriginEntry -from swh.provenance.origin import origin_add from swh.provenance.postgresql.provenancedb import ProvenanceDB -from swh.provenance.storage.archive import ArchiveStorage -from swh.storage.postgresql.storage import Storage - - -# TODO: remove this function in favour of TimestampWithTimezone.to_datetime -# from swh.model.model -def ts2dt(ts: dict) -> datetime: - timestamp = datetime.fromtimestamp( - ts["timestamp"]["seconds"], timezone(timedelta(minutes=ts["offset"])) - ) - return timestamp.replace(microsecond=ts["timestamp"]["microseconds"]) - - -def test_provenance_origin_add( - provenance: ProvenanceInterface, swh_storage_with_objects: Storage -) -> None: - """Test the origin_add function""" - archive = ArchiveStorage(swh_storage_with_objects) - for status in TEST_OBJECTS["origin_visit_status"]: - assert isinstance(status, OriginVisitStatus) - if status.snapshot is not None: - entry = OriginEntry(url=status.origin, snapshot=status.snapshot) - origin_add(provenance, archive, [entry]) - # TODO: check some facts here def test_provenance_flavor(provenance: ProvenanceInterface) -> None: diff --git a/swh/provenance/tests/test_revision_content_layer.py b/swh/provenance/tests/test_revision_content_layer.py --- a/swh/provenance/tests/test_revision_content_layer.py +++ b/swh/provenance/tests/test_revision_content_layer.py @@ -15,9 +15,12 @@ from swh.provenance.interface import EntityType, ProvenanceInterface, RelationType from swh.provenance.model import RevisionEntry from swh.provenance.revision import revision_add -from swh.provenance.tests.conftest import fill_storage, get_datafile, load_repo_data -from swh.provenance.tests.test_provenance_db import ts2dt -from swh.storage.postgresql.storage import Storage +from swh.provenance.tests.conftest import ( + fill_storage, + get_datafile, + load_repo_data, + ts2dt, +) class SynthRelation(TypedDict): @@ -156,7 +159,6 @@ ) def test_revision_content_result( provenance: ProvenanceInterface, - swh_storage: Storage, archive: ArchiveInterface, repo: str, lower: bool, @@ -164,7 +166,7 @@ ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) - fill_storage(swh_storage, data) + fill_storage(archive.storage, data) syntheticfile = get_datafile( f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt" ) @@ -298,7 +300,6 @@ @pytest.mark.parametrize("batch", (True, False)) def test_provenance_heuristics_content_find_all( provenance: ProvenanceInterface, - swh_storage: Storage, archive: ArchiveInterface, repo: str, lower: bool, @@ -307,7 +308,7 @@ ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) - fill_storage(swh_storage, data) + fill_storage(archive.storage, data) revisions = [ RevisionEntry( id=revision["id"], @@ -381,7 +382,6 @@ @pytest.mark.parametrize("batch", (True, False)) def test_provenance_heuristics_content_find_first( provenance: ProvenanceInterface, - swh_storage: Storage, archive: ArchiveInterface, repo: str, lower: bool, @@ -390,7 +390,7 @@ ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) - fill_storage(swh_storage, data) + fill_storage(archive.storage, data) revisions = [ RevisionEntry( id=revision["id"], diff --git a/swh/provenance/tests/test_revision_iterator.py b/swh/provenance/tests/test_revision_iterator.py --- a/swh/provenance/tests/test_revision_iterator.py +++ b/swh/provenance/tests/test_revision_iterator.py @@ -6,9 +6,8 @@ import pytest from swh.provenance.revision import CSVRevisionIterator -from swh.provenance.tests.conftest import fill_storage, load_repo_data -from swh.provenance.tests.test_provenance_db import ts2dt -from swh.storage.postgresql.storage import Storage +from swh.provenance.tests.conftest import fill_storage, load_repo_data, ts2dt +from swh.storage.interface import StorageInterface @pytest.mark.parametrize( @@ -18,13 +17,15 @@ "out-of-order", ), ) -def test_archive_direct_revision_iterator(swh_storage: Storage, repo: str) -> None: +def test_revision_iterator(swh_storage: StorageInterface, repo: str) -> None: """Test CSVRevisionIterator""" data = load_repo_data(repo) fill_storage(swh_storage, data) + revisions_csv = [ (rev["id"], ts2dt(rev["date"]), rev["directory"]) for rev in data["revision"] ] revisions = list(CSVRevisionIterator(revisions_csv)) + assert revisions assert len(revisions) == len(data["revision"])