diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py --- a/swh/provenance/interface.py +++ b/swh/provenance/interface.py @@ -23,6 +23,11 @@ ORIGIN = "origin" +class UnsupportedEntityError(Exception): + def __init__(self, entity: EntityType) -> None: + super().__init__(f"Unsupported entity: {entity.value}") + + class RelationType(enum.Enum): CNT_EARLY_IN_REV = "content_in_revision" CNT_IN_DIR = "content_in_directory" @@ -107,6 +112,15 @@ """ ... + @remote_api_endpoint("entity_add") + def entity_add(self, entity: EntityType, ids: Iterable[Sha1Git]) -> bool: + """Add entries to the selected `entity` with `None` values in all optional + fields. `EntityType.ORIGIN` is not supported by this method (it raises an + `UnsupportedEntityError`) since origins have non-optional associated fields + (ie. `url`). See `origin_set_url` for adding origin entries to the storage. + """ + ... + @remote_api_endpoint("entity_get_all") def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: """Retrieve all sha1 ids for entities of type `entity` present in the provenance @@ -114,6 +128,11 @@ """ ... + @remote_api_endpoint("location_add") + def location_add(self, paths: Iterable[bytes]) -> bool: + """Register the given `paths` in the storage.""" + ... + @remote_api_endpoint("location_get") def location_get(self) -> Set[bytes]: """Retrieve all paths present in the provenance model.""" ... @@ -144,7 +163,8 @@ def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool: """Associate origins to revisions identified by sha1 ids, as paired in `origins` (revision ids are keys and origin ids, values). Return a boolean - stating whether the information was successfully stored. + stating whether the information was successfully stored. This method assumes + all origins are already registered in the storage. See `origin_set_url`. """ ... @@ -160,7 +180,10 @@ def relation_add( self, relation: RelationType, data: Iterable[RelationData] ) -> bool: - """Add entries in the selected `relation`.""" + """Add entries to the selected `relation`. This method assumes all entities + being related are already registered in the storage. See `entity_add` and + `origin_set_url`. + """ ...
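The docstrings above define an insertion-order contract: entities and locations must already be registered before `relation_add` references them, and origins must go through `origin_set_url`. A minimal sketch of a compliant caller, assuming only that `storage` implements `ProvenanceStorageInterface` (the helper name and its arguments are illustrative, not part of this diff):

    from swh.provenance.interface import EntityType, RelationData, RelationType

    def store_content_in_revision(storage, cnt, rev, path):
        # Register both entities and the location first; these calls are
        # idempotent, so re-registering already-known ids is harmless.
        assert storage.entity_add(EntityType.CONTENT, [cnt])
        assert storage.entity_add(EntityType.REVISION, [rev])
        assert storage.location_add([path])
        # Only then add the relation tying them together.
        assert storage.relation_add(
            RelationType.CNT_EARLY_IN_REV, [RelationData(cnt, rev, path)]
        )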
@remote_api_endpoint("relation_get") diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py --- a/swh/provenance/mongo/backend.py +++ b/swh/provenance/mongo/backend.py @@ -18,6 +18,7 @@ RelationData, RelationType, RevisionData, + UnsupportedEntityError, ) @@ -171,6 +172,36 @@ ) } + def entity_add(self, entity: EntityType, ids: Iterable[Sha1Git]) -> bool: + if entity == EntityType.ORIGIN: + raise UnsupportedEntityError(entity) + + sha1s = list(set(ids)) + if sha1s: + obj: Dict[str, Any] = {"ts": None} + if entity.value == "content": + obj["revision"] = {} + obj["directory"] = {} + if entity.value == "directory": + obj["revision"] = {} + if entity.value == "revision": + obj["preferred"] = None + obj["origin"] = [] + obj["revision"] = [] + + existing = { + x["sha1"] + for x in self.db.get_collection(entity.value).find( + {"sha1": {"$in": sha1s}}, {"_id": 0, "sha1": 1} + ) + } + for sha1 in sha1s: + if sha1 not in existing: + self.db.get_collection(entity.value).insert_one( + dict(obj, **{"sha1": sha1}) + ) + return True + def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: return { x["sha1"] @@ -179,6 +210,10 @@ ) } + def location_add(self, paths: Iterable[bytes]) -> bool: + # TODO: implement this method if paths are to be stored in a separate collection + return True + def location_get(self) -> Set[bytes]: contents = self.db.content.find({}, {"revision": 1, "_id": 0, "directory": 1}) paths: List[Iterable[bytes]] = [] @@ -283,40 +318,10 @@ src_relation, *_, dst_relation = relation.value.split("_") set_data = set(data) - dst_sha1s = {x.dst for x in set_data} - if dst_relation in ["content", "directory", "revision"]: - dst_obj: Dict[str, Any] = {"ts": None} - if dst_relation == "content": - dst_obj["revision"] = {} - dst_obj["directory"] = {} - if dst_relation == "directory": - dst_obj["revision"] = {} - if dst_relation == "revision": - dst_obj["preferred"] = None - dst_obj["origin"] = [] - dst_obj["revision"] = [] - - existing = { - x["sha1"] - for x in self.db.get_collection(dst_relation).find( - {"sha1": {"$in": list(dst_sha1s)}}, {"_id": 0, "sha1": 1} - ) - } - - for sha1 in dst_sha1s: - if sha1 not in existing: - self.db.get_collection(dst_relation).insert_one( - dict(dst_obj, **{"sha1": sha1}) - ) - elif dst_relation == "origin": - # TODO, check origins are already in the DB - # if not, algo has something wrong (algo inserts it initially) - pass - dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst_relation).find( - {"sha1": {"$in": list(dst_sha1s)}}, {"_id": 1, "sha1": 1} + {"sha1": {"$in": [x.dst for x in set_data]}}, {"_id": 1, "sha1": 1} ) } @@ -337,42 +342,24 @@ } for sha1, dsts in denorm.items(): - if sha1 in src_objs: - # update - if src_relation != "revision": - k = { - obj_id: list(set(paths + dsts.get(obj_id, []))) - for obj_id, paths in src_objs[sha1][dst_relation].items() - } - self.db.get_collection(src_relation).update_one( - {"_id": src_objs[sha1]["_id"]}, - {"$set": {dst_relation: dict(dsts, **k)}}, - ) - else: - self.db.get_collection(src_relation).update_one( - {"_id": src_objs[sha1]["_id"]}, - { - "$set": { - dst_relation: list( - set(src_objs[sha1][dst_relation] + dsts) - ) - } - }, - ) + # the document was registered beforehand via entity_add, so just update it + if src_relation != "revision": + k = { + obj_id: list(set(paths + dsts.get(obj_id, []))) + for obj_id, paths in src_objs[sha1][dst_relation].items() + } + self.db.get_collection(src_relation).update_one( + {"_id": src_objs[sha1]["_id"]}, + {"$set": {dst_relation: dict(dsts, **k)}}, + ) else: - # add new rev - src_obj:
Dict[str, Any] = {"ts": None} - if src_relation == "content": - src_obj["revision"] = {} - src_obj["directory"] = {} - if src_relation == "directory": - src_obj["revision"] = {} - if src_relation == "revision": - src_obj["preferred"] = None - src_obj["origin"] = [] - src_obj["revision"] = [] - self.db.get_collection(src_relation).insert_one( - dict(src_obj, **{"sha1": sha1, dst_relation: dsts}) + self.db.get_collection(src_relation).update_one( + {"_id": src_objs[sha1]["_id"]}, + { + "$set": { + dst_relation: list(set(src_objs[sha1][dst_relation] + dsts)) + } + }, ) return True diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -22,6 +22,7 @@ RelationData, RelationType, RevisionData, + UnsupportedEntityError, ) LOGGER = logging.getLogger(__name__) @@ -87,11 +88,52 @@ def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("directory", ids) + def entity_add(self, entity: EntityType, ids: Iterable[Sha1Git]) -> bool: + if entity == EntityType.ORIGIN: + raise UnsupportedEntityError(entity) + + try: + sha1s = [(sha1,) for sha1 in ids] + if sha1s: + sql = f""" + INSERT INTO {entity.value}(sha1) VALUES %s + ON CONFLICT DO NOTHING + """ + with self.transaction() as cursor: + psycopg2.extras.execute_values(cursor, sql, argslist=sha1s) + return True + except: # noqa: E722 + # Unexpected error occurred; roll back all changes and log the exception + LOGGER.exception("Unexpected error") + if self.raise_on_commit: + raise + return False + def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: with self.transaction(readonly=True) as cursor: cursor.execute(f"SELECT sha1 FROM {entity.value}") return {row["sha1"] for row in cursor} + def location_add(self, paths: Iterable[bytes]) -> bool: + if not self.with_path(): + return True + try: + values = [(path,) for path in paths] + if values: + sql = """ + INSERT INTO location(path) VALUES %s + ON CONFLICT DO NOTHING + """ + with self.transaction() as cursor: + psycopg2.extras.execute_values(cursor, sql, argslist=values) + return True + except: # noqa: E722 + # Unexpected error occurred; roll back all changes and log the exception + LOGGER.exception("Unexpected error") + if self.raise_on_commit: + raise + return False + def location_get(self) -> Set[bytes]: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT location.path AS path FROM location") @@ -187,32 +229,6 @@ rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") - if src_table != "origin": - # Origin entries should be inserted previously as they require extra - # non-null information - srcs = tuple(set((sha1,) for (sha1, _, _) in rows)) - sql = f""" - INSERT INTO {src_table}(sha1) VALUES %s - ON CONFLICT DO NOTHING - """ - with self.transaction() as cursor: - psycopg2.extras.execute_values( - cur=cursor, sql=sql, argslist=srcs - ) - - if dst_table != "origin": - # Origin entries should be inserted previously as they require extra - # non-null information - dsts = tuple(set((sha1,) for (_, sha1, _) in rows)) - sql = f""" - INSERT INTO {dst_table}(sha1) VALUES %s - ON CONFLICT DO NOTHING - """ - with self.transaction() as cursor: - psycopg2.extras.execute_values( - cur=cursor, sql=sql, argslist=dsts - ) - # Put the next three queries in a manual single transaction: # they use the same temp table with self.transaction() as cursor: diff --git a/swh/provenance/provenance.py
b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -13,6 +13,7 @@ from swh.model.model import Sha1Git from .interface import ( + EntityType, ProvenanceResult, ProvenanceStorageInterface, RelationData, @@ -79,6 +80,49 @@ def flush(self) -> None: # Revision-content layer insertions ############################################ + # Entities and locations are registered first, since `relation_add` assumes all + # entities being related already exist in the storage. Dates are set only after + # the relations: once an entity has a date, it won't need to be reprocessed in + # case of failure. + sha1s = { + src + for src, _, _ in self.cache["content_in_revision"] + | self.cache["content_in_directory"] + } + if sha1s: + while not self.storage.entity_add(EntityType.CONTENT, sha1s): + LOGGER.warning( + "Unable to write content entities to the storage. Retrying..." + ) + + sha1s = {dst for _, dst, _ in self.cache["content_in_directory"]} + if sha1s: + while not self.storage.entity_add(EntityType.DIRECTORY, sha1s): + LOGGER.warning( + "Unable to write directory entities to the storage. Retrying..." + ) + + sha1s = { + dst + for _, dst, _ in self.cache["content_in_revision"] + | self.cache["directory_in_revision"] + } + if sha1s: + while not self.storage.entity_add(EntityType.REVISION, sha1s): + LOGGER.warning( + "Unable to write revision entities to the storage. Retrying..." + ) + + paths = { + path + for _, _, path in self.cache["content_in_revision"] + | self.cache["content_in_directory"] + | self.cache["directory_in_revision"] + } + if paths: + while not self.storage.location_add(paths): + LOGGER.warning( + "Unable to write location entities to the storage. Retrying..." + ) + # For this layer, relations need to be inserted first so that, in case of # failure, reprocessing the input does not generate an inconsistent database. if self.cache["content_in_revision"]: @@ -170,6 +214,17 @@ "Unable to write origin urls to the storage. Retrying..." ) + sha1s = ( + {src for src in self.cache["revision_origin"]["added"]} + # Destinations in this relation should match origins in the previous one + | {src for src in self.cache["revision_before_revision"]} + ) + if sha1s: + while not self.storage.entity_add(EntityType.REVISION, sha1s): + LOGGER.warning( + "Unable to write revision entities to the storage. Retrying..." + ) + # Second, flat models for revisions' histories (ie. revision-before-revision).
data: Iterable[RelationData] = sum( [ diff --git a/swh/provenance/sql/40-funcs.sql b/swh/provenance/sql/40-funcs.sql --- a/swh/provenance/sql/40-funcs.sql +++ b/swh/provenance/sql/40-funcs.sql @@ -99,11 +99,6 @@ join_location text; begin if src_table in ('content'::regclass, 'directory'::regclass) then - insert into location(path) - select V.path - from tmp_relation_add as V - on conflict (path) do nothing; - select_fields := 'D.id, L.id'; join_location := 'inner join location as L on (L.path = V.path)'; else @@ -419,11 +414,6 @@ on_conflict text; begin if src_table in ('content'::regclass, 'directory'::regclass) then - insert into location(path) - select V.path - from tmp_relation_add as V - on conflict (path) do nothing; - select_fields := 'array_agg((D.id, L.id)::rel_dst)'; join_location := 'inner join location as L on (L.path = V.path)'; group_entries := 'group by S.id'; diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py --- a/swh/provenance/tests/test_provenance_storage.py +++ b/swh/provenance/tests/test_provenance_storage.py @@ -3,7 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from datetime import datetime +from datetime import datetime, timezone import inspect import os from typing import Any, Dict, Iterable, Optional, Set @@ -13,6 +13,8 @@ from swh.model.hashutil import hash_to_bytes from swh.model.identifiers import origin_identifier from swh.model.model import Sha1Git +from swh.provenance.api.client import RemoteProvenanceStorage +from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ( EntityType, ProvenanceInterface, @@ -20,54 +22,211 @@ ProvenanceStorageInterface, RelationData, RelationType, + RevisionData, + UnsupportedEntityError, ) -from swh.provenance.tests.conftest import load_repo_data, ts2dt +from swh.provenance.model import OriginEntry, RevisionEntry +from swh.provenance.mongo.backend import ProvenanceStorageMongoDb +from swh.provenance.origin import origin_add +from swh.provenance.provenance import Provenance +from swh.provenance.revision import revision_add +from swh.provenance.tests.conftest import fill_storage, load_repo_data, ts2dt -def relation_add_and_compare_result( - relation: RelationType, - data: Set[RelationData], - refstorage: ProvenanceStorageInterface, - storage: ProvenanceStorageInterface, - with_path: bool = True, +@pytest.mark.parametrize( + "repo", + ("cmdbts2",), +) +def test_provenance_storage_content( + provenance_storage: ProvenanceStorageInterface, + repo: str, ) -> None: - assert data - assert refstorage.relation_add(relation, data) == storage.relation_add( - relation, data - ) + """Tests content methods for every `ProvenanceStorageInterface` implementation.""" - assert relation_compare_result( - refstorage.relation_get(relation, (reldata.src for reldata in data)), - storage.relation_get(relation, (reldata.src for reldata in data)), - with_path, - ) - assert relation_compare_result( - refstorage.relation_get( - relation, - (reldata.dst for reldata in data), - reverse=True, - ), - storage.relation_get( - relation, - (reldata.dst for reldata in data), - reverse=True, - ), - with_path, - ) - assert relation_compare_result( - refstorage.relation_get_all(relation), - storage.relation_get_all(relation), - with_path, - ) + # Read data/README.md for more details on how these datasets are generated. 
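+ # (as used throughout these tests, `load_repo_data` returns a dict mapping an + # object type, such as "content", "directory", "revision" or "origin", to the + # list of objects of that type in the dataset)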
+ data = load_repo_data(repo) + # Add all content present in the current repo to the storage, just assigning their + # creation dates. Then check that querying returns exactly the inserted data. + dates = {cnt["sha1_git"]: cnt["ctime"] for cnt in data["content"]} + assert dates + assert provenance_storage.content_set_date(dates) + assert provenance_storage.content_get(set(dates.keys())) == dates + assert provenance_storage.entity_get_all(EntityType.CONTENT) == set(dates.keys()) -def relation_compare_result( - expected: Set[RelationData], computed: Set[RelationData], with_path: bool -) -> bool: - return { - RelationData(reldata.src, reldata.dst, reldata.path if with_path else None) - for reldata in expected - } == computed + +@pytest.mark.parametrize( + "repo", + ("cmdbts2",), +) +def test_provenance_storage_directory( + provenance_storage: ProvenanceStorageInterface, + repo: str, +) -> None: + """Tests directory methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data(repo) + + # Of all directories present in the current repo, only assign a date to those + # containing blobs (picking the max date among the available ones). Then check that + # querying returns exactly the inserted dates. + def getmaxdate( + directory: Dict[str, Any], contents: Iterable[Dict[str, Any]] + ) -> datetime: + dates = [ + content["ctime"] + for entry in directory["entries"] + for content in contents + if entry["type"] == "file" and entry["target"] == content["sha1_git"] + ] + return max(dates) if dates else datetime.now(tz=timezone.utc) + + dates = {dir["id"]: getmaxdate(dir, data["content"]) for dir in data["directory"]} + assert dates + assert provenance_storage.directory_set_date(dates) + assert provenance_storage.directory_get(set(dates.keys())) == dates + assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == set(dates.keys()) + + +@pytest.mark.parametrize( + "repo", + ("cmdbts2",), +) +def test_provenance_storage_entity( + provenance_storage: ProvenanceStorageInterface, + repo: str, +) -> None: + """Tests entity methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data(repo) + + # Test EntityType.CONTENT + # Add all contents present in the current repo to the storage. Then check that + # querying returns exactly the inserted sha1 ids. + sha1s = {cnt["sha1_git"] for cnt in data["content"]} + assert sha1s + assert provenance_storage.entity_add(EntityType.CONTENT, sha1s) + assert provenance_storage.entity_get_all(EntityType.CONTENT) == sha1s + + # Test EntityType.DIRECTORY + # Add all directories present in the current repo to the storage. Then check that + # querying returns exactly the inserted sha1 ids. + sha1s = {dir["id"] for dir in data["directory"]} + assert sha1s + assert provenance_storage.entity_add(EntityType.DIRECTORY, sha1s) + assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == sha1s + + # Test EntityType.REVISION + # Add all revisions present in the current repo to the storage. Then check that + # querying returns exactly the inserted sha1 ids. + sha1s = {rev["id"] for rev in data["revision"]} + assert sha1s + assert provenance_storage.entity_add(EntityType.REVISION, sha1s) + assert provenance_storage.entity_get_all(EntityType.REVISION) == sha1s + + # Test EntityType.ORIGIN + # Add all origins present in the current repo.
It should fail with an + # `UnsupportedEntityError`. Then check that indeed nothing was inserted. + if not isinstance(provenance_storage, RemoteProvenanceStorage): + sha1s = {hash_to_bytes(origin_identifier(org)) for org in data["origin"]} + assert sha1s + with pytest.raises(UnsupportedEntityError) as error: + provenance_storage.entity_add(EntityType.ORIGIN, sha1s) + assert "Unsupported entity: origin" in str(error.value) + assert provenance_storage.entity_get_all(EntityType.ORIGIN) == set() + + +@pytest.mark.parametrize( + "repo", + ("cmdbts2",), +) +def test_provenance_storage_location( + provenance_storage: ProvenanceStorageInterface, + repo: str, +) -> None: + """Tests location methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data(repo) + + # Add all names of entries present in the directories of the current repo as paths + # to the storage. Then check that querying returns exactly the inserted paths. + paths = {entry["name"] for dir in data["directory"] for entry in dir["entries"]} + assert provenance_storage.location_add(paths) + + if isinstance(provenance_storage, ProvenanceStorageMongoDb): + # TODO: remove this when `location_add` is properly implemented for MongoDb. + return + + if provenance_storage.with_path(): + assert provenance_storage.location_get() == paths + else: + assert provenance_storage.location_get() == set() + + +@pytest.mark.parametrize( + "repo", + ("cmdbts2",), +) +def test_provenance_storage_origin( + provenance_storage: ProvenanceStorageInterface, + repo: str, +) -> None: + """Tests origin methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data(repo) + + # Test origin methods. + # Add all origins present in the current repo to the storage. Then check that + # querying returns exactly the inserted urls. + urls = {hash_to_bytes(origin_identifier(org)): org["url"] for org in data["origin"]} + assert urls + assert provenance_storage.origin_set_url(urls) + assert provenance_storage.origin_get(set(urls.keys())) == urls + assert provenance_storage.entity_get_all(EntityType.ORIGIN) == set(urls.keys()) + + +@pytest.mark.parametrize( + "repo", + ("cmdbts2",), +) +def test_provenance_storage_revision( + provenance_storage: ProvenanceStorageInterface, + repo: str, +) -> None: + """Tests revision methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data(repo) + + # Test revision methods. + # Add all revisions present in the current repo to the storage, assigning their + # dates and an arbitrary origin to each one. Then check that querying returns + # exactly the inserted data. + origin = next(iter(data["origin"])) + org_sha1 = hash_to_bytes(origin_identifier(origin)) + # Origin must be inserted in advance.
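+ # (`revision_set_origin` assumes all origins are already registered in the + # storage; see its docstring in `ProvenanceStorageInterface`)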
+ assert provenance_storage.origin_set_url({org_sha1: origin["url"]}) + + dates = {rev["id"]: ts2dt(rev["date"]) for rev in data["revision"]} + orgs = {rev["id"]: org_sha1 for rev in data["revision"]} + assert set(dates.keys()) == set(orgs.keys()) + revs = { + sha1: RevisionData(date, orgs[sha1]) for sha1, date in dates.items() + } + + assert dates + assert orgs + assert provenance_storage.revision_set_date(dates) + assert provenance_storage.revision_set_origin(orgs) + assert provenance_storage.revision_get(set(revs.keys())) == revs + assert provenance_storage.entity_get_all(EntityType.REVISION) == set(revs.keys()) def dircontent( @@ -94,41 +253,66 @@ return content +def relation_add_and_compare_result( + relation: RelationType, data: Set[RelationData], storage: ProvenanceStorageInterface +) -> None: + # Sources, destinations and locations must be added in advance. + src, *_, dst = relation.value.split("_") + if src != "origin": + assert storage.entity_add(EntityType(src), {entry.src for entry in data}) + if dst != "origin": + assert storage.entity_add(EntityType(dst), {entry.dst for entry in data}) + if storage.with_path(): + assert storage.location_add( + {entry.path for entry in data if entry.path is not None} + ) + + assert data + assert storage.relation_add(relation, data) + + for row in data: + assert relation_compare_result( + storage.relation_get(relation, [row.src]), + {entry for entry in data if entry.src == row.src}, + storage.with_path(), + ) + assert relation_compare_result( + storage.relation_get( + relation, + [row.dst], + reverse=True, + ), + {entry for entry in data if entry.dst == row.dst}, + storage.with_path(), + ) + + assert relation_compare_result( + storage.relation_get_all(relation), data, storage.with_path() + ) + + +def relation_compare_result( + computed: Set[RelationData], expected: Set[RelationData], with_path: bool +) -> bool: + return { + RelationData(row.src, row.dst, row.path if with_path else None) + for row in expected + } == computed + + @pytest.mark.parametrize( "repo", - ("cmdbts2", "out-of-order", "with-merges"), + ("cmdbts2",), ) -def test_provenance_storage( - provenance: ProvenanceInterface, +def test_provenance_storage_relation( provenance_storage: ProvenanceStorageInterface, repo: str, ) -> None: - """Tests every ProvenanceStorageInterface implementation against the one provided - for provenance.storage.""" + """Tests relation methods for every `ProvenanceStorageInterface` implementation.""" + # Read data/README.md for more details on how these datasets are generated. data = load_repo_data(repo) - # Assuming provenance.storage has the 'with-path' flavor. - assert provenance.storage.with_path() - - # Test origin methods. - # Add all origins present in the current repo to both storages. Then check that the - # inserted data is the same in both cases. - org_urls = { - hash_to_bytes(origin_identifier(org)): org["url"] for org in data["origin"] - } - assert org_urls - assert provenance.storage.origin_set_url( - org_urls - ) == provenance_storage.origin_set_url(org_urls) - - assert provenance.storage.origin_get(org_urls) == provenance_storage.origin_get( - org_urls - ) - assert provenance.storage.entity_get_all( - EntityType.ORIGIN - ) == provenance_storage.entity_get_all(EntityType.ORIGIN) - # Test content-in-revision relation. # Create flat models of every root directory for the revisions in the dataset.
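# (`dircontent`, defined above, flattens a directory into RelationData entries, # presumably one per blob with its path relative to the given root)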
cnt_in_rev: Set[RelationData] = set() @@ -137,13 +321,8 @@ subdir for subdir in data["directory"] if subdir["id"] == rev["directory"] ) cnt_in_rev.update(dircontent(data, rev["id"], root)) - relation_add_and_compare_result( - RelationType.CNT_EARLY_IN_REV, - cnt_in_rev, - provenance.storage, - provenance_storage, - provenance_storage.with_path(), + RelationType.CNT_EARLY_IN_REV, cnt_in_rev, provenance_storage ) # Test content-in-directory relation. @@ -151,13 +330,8 @@ cnt_in_dir: Set[RelationData] = set() for dir in data["directory"]: cnt_in_dir.update(dircontent(data, dir["id"], dir)) - relation_add_and_compare_result( - RelationType.CNT_IN_DIR, - cnt_in_dir, - provenance.storage, - provenance_storage, - provenance_storage.with_path(), + RelationType.CNT_IN_DIR, cnt_in_dir, provenance_storage ) # Test directory-in-revision relation. @@ -165,13 +339,8 @@ dir_in_rev = { RelationData(rev["directory"], rev["id"], b".") for rev in data["revision"] } - relation_add_and_compare_result( - RelationType.DIR_IN_REV, - dir_in_rev, - provenance.storage, - provenance_storage, - provenance_storage.with_path(), + RelationType.DIR_IN_REV, dir_in_rev, provenance_storage ) # Test revision-in-origin relation. @@ -190,12 +359,16 @@ for _, branch in snapshot["branches"].items() if branch["target_type"] == "revision" } + # Origins must be inserted in advance (cannot be done by `entity_add` inside + # `relation_add_and_compare_result`). + urls = { + hash_to_bytes(origin_identifier(origin)): origin["url"] + for origin in data["origin"] + } + assert provenance_storage.origin_set_url(urls) relation_add_and_compare_result( - RelationType.REV_IN_ORG, - rev_in_org, - provenance.storage, - provenance_storage, + RelationType.REV_IN_ORG, rev_in_org, provenance_storage ) # Test revision-before-revision relation. @@ -205,87 +378,45 @@ for rev in data["revision"] for parent in rev["parents"] } - relation_add_and_compare_result( - RelationType.REV_BEFORE_REV, - rev_before_rev, - provenance.storage, - provenance_storage, + RelationType.REV_BEFORE_REV, rev_before_rev, provenance_storage ) - # Test content methods. - # Add all content present in the current repo to both storages, just assigning their - # creation dates. Then check that the inserted content is the same in both cases. - cnt_dates = {cnt["sha1_git"]: cnt["ctime"] for cnt in data["content"]} - assert cnt_dates - assert provenance.storage.content_set_date( - cnt_dates - ) == provenance_storage.content_set_date(cnt_dates) - - assert provenance.storage.content_get(cnt_dates) == provenance_storage.content_get( - cnt_dates - ) - assert provenance.storage.entity_get_all( - EntityType.CONTENT - ) == provenance_storage.entity_get_all(EntityType.CONTENT) - # Test directory methods. - # Of all directories present in the current repo, only assign a date to those - # containing blobs (picking the max date among the available ones). Then check that - # the inserted data is the same in both storages.
- def getmaxdate( - dir: Dict[str, Any], cnt_dates: Dict[Sha1Git, datetime] - ) -> Optional[datetime]: - dates = [ - cnt_dates[entry["target"]] - for entry in dir["entries"] - if entry["type"] == "file" - ] - return max(dates) if dates else None - - dir_dates = {dir["id"]: getmaxdate(dir, cnt_dates) for dir in data["directory"]} - assert dir_dates - assert provenance.storage.directory_set_date( - {sha1: date for sha1, date in dir_dates.items() if date is not None} - ) == provenance_storage.directory_set_date( - {sha1: date for sha1, date in dir_dates.items() if date is not None} - ) - assert provenance.storage.directory_get( - dir_dates - ) == provenance_storage.directory_get(dir_dates) - assert provenance.storage.entity_get_all( - EntityType.DIRECTORY - ) == provenance_storage.entity_get_all(EntityType.DIRECTORY) +@pytest.mark.parametrize( + "repo", + ("cmdbts2",), +) +def test_provenance_storage_find( + archive: ArchiveInterface, + provenance: ProvenanceInterface, + provenance_storage: ProvenanceStorageInterface, + repo: str, +) -> None: + """Tests `content_find_first` and `content_find_all` methods for every + `ProvenanceStorageInterface` implementation. + """ - # Test revision methods. - # Add all revisions present in the current repo to both storages, assigning their - # dataes and an arbitrary origin to each one. Then check that the inserted data is - # the same in both cases. - rev_dates = {rev["id"]: ts2dt(rev["date"]) for rev in data["revision"]} - assert rev_dates - assert provenance.storage.revision_set_date( - rev_dates - ) == provenance_storage.revision_set_date(rev_dates) - - rev_origins = { - rev["id"]: next(iter(org_urls)) # any arbitrary origin will do + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data(repo) + fill_storage(archive.storage, data) + + # Execute the origin-revision algorithm on both storages. + origins = [ + OriginEntry(url=sta["origin"], snapshot=sta["snapshot"]) + for sta in data["origin_visit_status"] + if sta["snapshot"] is not None + ] + origin_add(provenance, archive, origins) + origin_add(Provenance(provenance_storage), archive, origins) + + # Execute the revision-content algorithm on both storages. + revisions = [ + RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"]) for rev in data["revision"] - } - assert rev_origins - assert provenance.storage.revision_set_origin( - rev_origins - ) == provenance_storage.revision_set_origin(rev_origins) - - assert provenance.storage.revision_get( - rev_dates - ) == provenance_storage.revision_get(rev_dates) - assert provenance.storage.entity_get_all( - EntityType.REVISION - ) == provenance_storage.entity_get_all(EntityType.REVISION) - - # Test location_get. - if provenance_storage.with_path(): - assert provenance.storage.location_get() == provenance_storage.location_get() + ] + revision_add(provenance, archive, revisions) + revision_add(Provenance(provenance_storage), archive, revisions) # Test content_find_first and content_find_all. 
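# (`adapt_result` below drops the path component from the reference results when # the storage under test lacks the with-path flavor, so both can be compared)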
def adapt_result( @@ -301,7 +432,7 @@ ) return result - for cnt in cnt_dates: + for cnt in {entry["sha1_git"] for entry in data["content"]}: assert adapt_result( provenance.storage.content_find_first(cnt), provenance_storage.with_path() ) == provenance_storage.content_find_first(cnt) @@ -312,7 +443,7 @@ } == set(provenance_storage.content_find_all(cnt)) -def test_types(provenance: ProvenanceInterface) -> None: +def test_types(provenance_storage: ProvenanceStorageInterface) -> None: """Checks all methods of ProvenanceStorageInterface are implemented by this backend, and that they have the same signature.""" # Create an instance of the protocol (which cannot be instantiated @@ -328,7 +459,7 @@ continue interface_meth = getattr(interface, meth_name) try: - concrete_meth = getattr(provenance.storage, meth_name) + concrete_meth = getattr(provenance_storage, meth_name) except AttributeError: if not getattr(interface_meth, "deprecated_endpoint", False): # The backend is missing a (non-deprecated) endpoint @@ -346,4 +477,4 @@ # But there's no harm in double-checking. # And we could replace the assertions above by this one, but unlike # the assertions above, it doesn't explain what is missing. - assert isinstance(provenance.storage, ProvenanceStorageInterface) + assert isinstance(provenance_storage, ProvenanceStorageInterface)
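A closing sketch of the write ordering that the modified `flush` (above) now follows: entities and locations first, then relations, then dates, each bulk write retried until the storage reports success. The helper below is illustrative only, assuming just the `ProvenanceStorageInterface` contract:

    import logging
    from typing import Set

    from swh.model.model import Sha1Git
    from swh.provenance.interface import EntityType, ProvenanceStorageInterface

    LOGGER = logging.getLogger(__name__)

    def entity_add_with_retry(
        storage: ProvenanceStorageInterface, entity: EntityType, sha1s: Set[Sha1Git]
    ) -> None:
        # Storage methods report failure by returning False rather than raising,
        # so the caller keeps retrying until the bulk write goes through.
        while not storage.entity_add(entity, sha1s):
            LOGGER.warning(
                "Unable to write %s entities to the storage. Retrying...",
                entity.value,
            )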