diff --git a/mypy.ini b/mypy.ini
--- a/mypy.ini
+++ b/mypy.ini
@@ -5,6 +5,9 @@
 
 # 3rd party libraries without stubs (yet)
 
+[mypy-bson.*]
+ignore_missing_imports = True
+
 [mypy-iso8601.*]
 ignore_missing_imports = True
 
@@ -17,6 +20,9 @@
 [mypy-pkg_resources.*]
 ignore_missing_imports = True
 
+[mypy-pymongo.*]
+ignore_missing_imports = True
+
 [mypy-pytest.*]
 ignore_missing_imports = True
 
diff --git a/requirements-test.txt b/requirements-test.txt
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -2,3 +2,4 @@
 swh.loader.git >= 0.8
 swh.journal >= 0.8
 types-Werkzeug
+mongomock
diff --git a/requirements.txt b/requirements.txt
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,8 @@
 click
 iso8601
 methodtools
+pymongo
 PyYAML
 types-click
 types-PyYAML
+types-Werkzeug
diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -87,6 +87,11 @@
         raise_on_commit = kwargs.get("raise_on_commit", False)
         return ProvenanceStoragePostgreSql(conn, raise_on_commit)
 
+    elif cls == "mongodb":
+        from .mongo.backend import ProvenanceStorageMongoDb
+
+        return ProvenanceStorageMongoDb(**kwargs["db"])
+
     elif cls == "remote":
         from .api.client import RemoteProvenanceStorage
 
diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py
--- a/swh/provenance/cli.py
+++ b/swh/provenance/cli.py
@@ -26,11 +26,13 @@
 DEFAULT_CONFIG: Dict[str, Any] = {
     "provenance": {
         "archive": {
+            # Storage API based Archive object
             # "cls": "api",
             # "storage": {
             #     "cls": "remote",
             #     "url": "http://uffizi.internal.softwareheritage.org:5002",
             # }
+            # Direct access Archive object
             "cls": "direct",
             "db": {
                 "host": "db.internal.softwareheritage.org",
@@ -39,8 +41,17 @@
             },
         },
         "storage": {
+            # Local PostgreSQL Storage
             "cls": "postgresql",
             "db": {"host": "localhost", "dbname": "provenance"},
+            # Local MongoDB Storage
+            # "cls": "mongodb",
+            # "db": {
+            #     "dbname": "provenance",
+            # },
+            # Remote REST-API/PostgreSQL
+            # "cls": "remote",
+            # "url": "http://localhost:8080/%2f",
         },
     }
 }
diff --git a/swh/provenance/mongo/README.md b/swh/provenance/mongo/README.md
new file mode 100644
--- /dev/null
+++ b/swh/provenance/mongo/README.md
@@ -0,0 +1,44 @@
+mongo backend
+=============
+
+Provenance storage implementation using MongoDB
+
+initial data-model
+------------------
+
+```json
+content
+{
+  sha1: sha1
+  ts: float  // optional, POSIX timestamp
+  revision: {<revision ObjectId as str>: [<path>]}
+  directory: {<directory ObjectId as str>: [<path>]}
+}
+
+directory
+{
+  sha1: sha1
+  ts: float  // optional, POSIX timestamp
+  revision: {<revision ObjectId as str>: [<path>]}
+}
+
+revision
+{
+  sha1: sha1
+  ts: float  // optional, POSIX timestamp
+  preferred: <origin sha1>  // optional
+  origin: [<origin ObjectId>]
+  revision: [<revision ObjectId>]
+}
+
+origin
+{
+  sha1: sha1
+  url: str
+}
+
+path
+{
+  path: str
+}
+```
diff --git a/swh/provenance/mongo/__init__.py b/swh/provenance/mongo/__init__.py
new file mode 100644
diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py
new file mode 100644
--- /dev/null
+++ b/swh/provenance/mongo/backend.py
@@ -0,0 +1,526 @@
+# Copyright (C) 2021  The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from datetime import datetime, timezone
+import os
+from typing import Any, Dict, Generator, Iterable, List, Optional, Set
+
+from bson import ObjectId
+from pymongo import MongoClient
+
+from swh.model.model import Sha1Git
+
+from ..interface import (
+    EntityType,
+    ProvenanceResult,
+    RelationData,
+    RelationType,
+    RevisionData,
+)
+
+
+class ProvenanceStorageMongoDb:
+    """Provenance storage backend persisting entities and relations in MongoDB.
+
+    Every entity collection (``content``, ``directory``, ``revision``, ``origin``)
+    is looked up by its ``sha1`` field; relations are denormalized into the source
+    documents (see this package's ``README.md`` for the layout).
+    """
+
+    def __init__(self, dbname: str, **kwargs):
+        self.db = MongoClient(**kwargs).get_database(dbname)
+
+    def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
+        """Return the earliest occurrence of content ``id`` in a revision.
+
+        Only direct content-in-revision occurrences are considered. Returns ``None``
+        when the content is unknown or has no recorded occurrence in any revision.
+        """
+        content = self.db.content.find_one({"sha1": id})
+        if not content:
+            return None
+
+        occurs = []
+        for revision in self.db.revision.find(
+            {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}}
+        ):
+            origin = self.db.origin.find_one({"sha1": revision["preferred"]})
+            assert origin is not None
+
+            for path in content["revision"][str(revision["_id"])]:
+                occurs.append(
+                    ProvenanceResult(
+                        content=id,
+                        revision=revision["sha1"],
+                        date=datetime.fromtimestamp(revision["ts"], timezone.utc),
+                        origin=origin["url"],
+                        path=path,
+                    )
+                )
+        # ``min`` selects the same element as ``sorted(...)[0]`` but returns ``None``
+        # instead of raising IndexError when there is no occurrence at all.
+        return min(
+            occurs,
+            key=lambda x: (x.date, x.revision, x.origin, x.path),
+            default=None,
+        )
+
+    def content_find_all(
+        self, id: Sha1Git, limit: Optional[int] = None
+    ) -> Generator[ProvenanceResult, None, None]:
+        """Yield every known occurrence of content ``id``, earliest first.
+
+        Occurrences come from direct content-in-revision relations and from
+        content-in-directory relations joined with directory-in-revision ones. At
+        most ``limit`` results are yielded when ``limit`` is not ``None``.
+        """
+        content = self.db.content.find_one({"sha1": id})
+        if not content:
+            return None
+
+        occurs = []
+        for revision in self.db.revision.find(
+            {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}}
+        ):
+            origin = self.db.origin.find_one({"sha1": revision["preferred"]})
+            assert origin is not None
+
+            for path in content["revision"][str(revision["_id"])]:
+                occurs.append(
+                    ProvenanceResult(
+                        content=id,
+                        revision=revision["sha1"],
+                        date=datetime.fromtimestamp(revision["ts"], timezone.utc),
+                        origin=origin["url"],
+                        path=path,
+                    )
+                )
+        for directory in self.db.directory.find(
+            {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["directory"]]}}
+        ):
+            for revision in self.db.revision.find(
+                {"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}}
+            ):
+                origin = self.db.origin.find_one({"sha1": revision["preferred"]})
+                assert origin is not None
+
+                for suffix in content["directory"][str(directory["_id"])]:
+                    for prefix in directory["revision"][str(revision["_id"])]:
+                        path = (
+                            os.path.join(prefix, suffix)
+                            if prefix not in [b".", b""]
+                            else suffix
+                        )
+                        occurs.append(
+                            ProvenanceResult(
+                                content=id,
+                                revision=revision["sha1"],
+                                date=datetime.fromtimestamp(
+                                    revision["ts"], timezone.utc
+                                ),
+                                origin=origin["url"],
+                                path=path,
+                            )
+                        )
+        occurs.sort(key=lambda x: (x.date, x.revision, x.origin, x.path))
+        # Slicing with ``None`` keeps every element, so no limit yields everything.
+        yield from occurs[:limit]
+
+    def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
+        return {
+            x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc)
+            for x in self.db.content.find(
+                {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}},
+                {"sha1": 1, "ts": 1, "_id": 0},
+            )
+        }
+
+    def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
+        # get all the documents with the id, add date, add missing records
+        cnts = {
+            x["sha1"]: x
+            for x in self.db.content.find(
+                {"sha1": {"$in": list(dates)}}, {"sha1": 1, "ts": 1, "_id": 1}
+            )
+        }
+
+        for sha1, date in dates.items():
+            ts = datetime.timestamp(date)
+            if sha1 in cnts:
+                # update
+                if cnts[sha1]["ts"] is None or ts < cnts[sha1]["ts"]:
+                    self.db.content.update_one(
+                        {"_id": cnts[sha1]["_id"]}, {"$set": {"ts": ts}}
+                    )
+            else:
+                # add new content
+                self.db.content.insert_one(
+                    {
+                        "sha1": sha1,
+                        "ts": ts,
+                        "revision": {},
+                        "directory": {},
+                    }
+                )
+        return True
+
+    def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
+        dirs = {
+            x["sha1"]: x
+            for x in self.db.directory.find(
+                {"sha1": {"$in": list(dates)}}, {"sha1": 1, "ts": 1, "_id": 1}
+            )
+        }
+        for sha1, date in dates.items():
+            ts = datetime.timestamp(date)
+            if sha1 in dirs:
+                # update
+                if dirs[sha1]["ts"] is None or ts < dirs[sha1]["ts"]:
+                    self.db.directory.update_one(
+                        {"_id": dirs[sha1]["_id"]}, {"$set": {"ts": ts}}
+                    )
+            else:
+                # add new dir
+                self.db.directory.insert_one({"sha1": sha1, "ts": ts, "revision": {}})
+        return True
+
+    def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
+        return {
+            x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc)
+            for x in self.db.directory.find(
+                {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}},
+                {"sha1": 1, "ts": 1, "_id": 0},
+            )
+        }
+
+    def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
+        return {
+            x["sha1"]
+            for x in self.db.get_collection(entity.value).find(
+                {}, {"sha1": 1, "_id": 0}
+            )
+        }
+
+    def location_get(self) -> Set[bytes]:
+        contents = self.db.content.find({}, {"revision": 1, "_id": 0, "directory": 1})
+        paths: List[Iterable[bytes]] = []
+        for content in contents:
+            paths.extend(value for _, value in content["revision"].items())
+            paths.extend(value for _, value in content["directory"].items())
+
+        dirs = self.db.directory.find({}, {"revision": 1, "_id": 0})
+        for each_dir in dirs:
+            paths.extend(value for _, value in each_dir["revision"].items())
+        return {path for group in paths for path in group}
+
+    def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool:
+        origins = {
+            x["sha1"]: x
+            for x in self.db.origin.find(
+                {"sha1": {"$in": list(urls)}}, {"sha1": 1, "url": 1, "_id": 1}
+            )
+        }
+        for sha1, url in urls.items():
+            if sha1 not in origins:
+                # add new origin
+                self.db.origin.insert_one({"sha1": sha1, "url": url})
+        return True
+
+    def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
+        return {
+            x["sha1"]: x["url"]
+            for x in self.db.origin.find(
+                {"sha1": {"$in": list(ids)}}, {"sha1": 1, "url": 1, "_id": 0}
+            )
+        }
+
+    def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
+        revs = {
+            x["sha1"]: x
+            for x in self.db.revision.find(
+                {"sha1": {"$in": list(dates)}}, {"sha1": 1, "ts": 1, "_id": 1}
+            )
+        }
+        for sha1, date in dates.items():
+            ts = datetime.timestamp(date)
+            if sha1 in revs:
+                # update
+                if revs[sha1]["ts"] is None or ts < revs[sha1]["ts"]:
+                    self.db.revision.update_one(
+                        {"_id": revs[sha1]["_id"]}, {"$set": {"ts": ts}}
+                    )
+            else:
+                # add new rev
+                self.db.revision.insert_one(
+                    {
+                        "sha1": sha1,
+                        "preferred": None,
+                        "origin": [],
+                        "revision": [],
+                        "ts": ts,
+                    }
+                )
+        return True
+
+    def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool:
+        revs = {
+            x["sha1"]: x
+            for x in self.db.revision.find(
+                {"sha1": {"$in": list(origins)}}, {"sha1": 1, "preferred": 1, "_id": 1}
+            )
+        }
+        for sha1, origin in origins.items():
+            if sha1 in revs:
+                self.db.revision.update_one(
+                    {"_id": revs[sha1]["_id"]}, {"$set": {"preferred": origin}}
+                )
+            else:
+                # add new rev
+                self.db.revision.insert_one(
+                    {
+                        "sha1": sha1,
+                        "preferred": origin,
+                        "origin": [],
+                        "revision": [],
+                        "ts": None,
+                    }
+                )
+        return True
+
+    def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]:
+        return {
+            x["sha1"]: RevisionData(
+                # Compare with None explicitly: a timestamp of 0 (the epoch) is a
+                # valid date and must not be reported as missing.
+                date=(
+                    datetime.fromtimestamp(x["ts"], timezone.utc)
+                    if x["ts"] is not None
+                    else None
+                ),
+                origin=x["preferred"],
+            )
+            for x in self.db.revision.find(
+                {"sha1": {"$in": list(ids)}},
+                {"sha1": 1, "preferred": 1, "ts": 1, "_id": 0},
+            )
+        }
+
+    def relation_add(
+        self, relation: RelationType, data: Iterable[RelationData]
+    ) -> bool:
+        """Denormalize ``data`` into the source documents of ``relation``.
+
+        Missing source documents are created, as are missing destination documents
+        unless the destination is an origin. Always returns ``True``.
+        """
+        src_relation, *_, dst_relation = relation.value.split("_")
+        # Materialize once: ``data`` may be a one-shot iterator that cannot be re-read.
+        set_data = set(data)
+
+        dst_sha1s = {x.dst for x in set_data}
+        if dst_relation in ["content", "directory", "revision"]:
+            dst_obj: Dict[str, Any] = {"ts": None}
+            if dst_relation == "content":
+                dst_obj["revision"] = {}
+                dst_obj["directory"] = {}
+            if dst_relation == "directory":
+                dst_obj["revision"] = {}
+            if dst_relation == "revision":
+                dst_obj["preferred"] = None
+                dst_obj["origin"] = []
+                dst_obj["revision"] = []
+
+            existing = {
+                x["sha1"]
+                for x in self.db.get_collection(dst_relation).find(
+                    {"sha1": {"$in": list(dst_sha1s)}}, {"_id": 0, "sha1": 1}
+                )
+            }
+
+            for sha1 in dst_sha1s:
+                if sha1 not in existing:
+                    self.db.get_collection(dst_relation).insert_one(
+                        dict(dst_obj, **{"sha1": sha1})
+                    )
+        elif dst_relation == "origin":
+            # TODO, check origins are already in the DB
+            # if not, algo has something wrong (algo inserts it initially)
+            pass
+
+        dst_objs = {
+            x["sha1"]: x["_id"]
+            for x in self.db.get_collection(dst_relation).find(
+                {"sha1": {"$in": list(dst_sha1s)}}, {"_id": 1, "sha1": 1}
+            )
+        }
+
+        denorm: Dict[Sha1Git, Any] = {}
+        for each in set_data:
+            if src_relation != "revision":
+                denorm.setdefault(each.src, {}).setdefault(
+                    str(dst_objs[each.dst]), []
+                ).append(each.path)
+            else:
+                denorm.setdefault(each.src, []).append(dst_objs[each.dst])
+
+        src_objs = {
+            x["sha1"]: x
+            for x in self.db.get_collection(src_relation).find(
+                {"sha1": {"$in": list(denorm)}}
+            )
+        }
+
+        for sha1, _ in denorm.items():
+            if sha1 in src_objs:
+                # update
+                if src_relation != "revision":
+                    # Merge stored paths with the new ones; destinations only present
+                    # in the stored document keep their paths unchanged.
+                    k = {
+                        obj_id: list(set(paths + denorm[sha1].get(obj_id, [])))
+                        for obj_id, paths in src_objs[sha1][dst_relation].items()
+                    }
+                    self.db.get_collection(src_relation).update_one(
+                        {"_id": src_objs[sha1]["_id"]},
+                        {"$set": {dst_relation: dict(denorm[sha1], **k)}},
+                    )
+                else:
+                    self.db.get_collection(src_relation).update_one(
+                        {"_id": src_objs[sha1]["_id"]},
+                        {
+                            "$set": {
+                                dst_relation: list(
+                                    set(src_objs[sha1][dst_relation] + denorm[sha1])
+                                )
+                            }
+                        },
+                    )
+            else:
+                # add new rev
+                src_obj: Dict[str, Any] = {"ts": None}
+                if src_relation == "content":
+                    src_obj["revision"] = {}
+                    src_obj["directory"] = {}
+                if src_relation == "directory":
+                    src_obj["revision"] = {}
+                if src_relation == "revision":
+                    src_obj["preferred"] = None
+                    src_obj["origin"] = []
+                    src_obj["revision"] = []
+                self.db.get_collection(src_relation).insert_one(
+                    dict(src_obj, **{"sha1": sha1, dst_relation: denorm[sha1]})
+                )
+        return True
+
+    def relation_get(
+        self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
+    ) -> Set[RelationData]:
+        src, *_, dst = relation.value.split("_")
+        sha1s = set(ids)
+        if not reverse:
+            src_objs = {
+                x["sha1"]: x[dst]
+                for x in self.db.get_collection(src).find(
+                    {"sha1": {"$in": list(sha1s)}}, {"_id": 0, "sha1": 1, dst: 1}
+                )
+            }
+            dst_ids = list(
+                {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value}
+            )
+            dst_objs = {
+                x["sha1"]: x["_id"]
+                for x in self.db.get_collection(dst).find(
+                    {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1}
+                )
+            }
+            if src != "revision":
+                return {
+                    RelationData(src=src_sha1, dst=dst_sha1, path=path)
+                    for src_sha1, denorm in src_objs.items()
+                    for dst_sha1, dst_obj_id in dst_objs.items()
+                    for dst_obj_str, paths in denorm.items()
+                    for path in paths
+                    if dst_obj_id == ObjectId(dst_obj_str)
+                }
+            else:
+                return {
+                    RelationData(src=src_sha1, dst=dst_sha1, path=None)
+                    for src_sha1, denorm in src_objs.items()
+                    for dst_sha1, dst_obj_id in dst_objs.items()
+                    for dst_obj_ref in denorm
+                    if dst_obj_id == dst_obj_ref
+                }
+        else:
+            dst_objs = {
+                x["sha1"]: x["_id"]
+                for x in self.db.get_collection(dst).find(
+                    {"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1}
+                )
+            }
+            src_objs = {
+                x["sha1"]: x[dst]
+                for x in self.db.get_collection(src).find(
+                    {}, {"_id": 0, "sha1": 1, dst: 1}
+                )
+            }
+            if src != "revision":
+                return {
+                    RelationData(src=src_sha1, dst=dst_sha1, path=path)
+                    for src_sha1, denorm in src_objs.items()
+                    for dst_sha1, dst_obj_id in dst_objs.items()
+                    for dst_obj_str, paths in denorm.items()
+                    for path in paths
+                    if dst_obj_id == ObjectId(dst_obj_str)
+                }
+            else:
+                return {
+                    RelationData(src=src_sha1, dst=dst_sha1, path=None)
+                    for src_sha1, denorm in src_objs.items()
+                    for dst_sha1, dst_obj_id in dst_objs.items()
+                    for dst_obj_ref in denorm
+                    if dst_obj_id == dst_obj_ref
+                }
+
+    def relation_get_all(self, relation: RelationType) -> Set[RelationData]:
+        src, *_, dst = relation.value.split("_")
+        src_objs = {
+            x["sha1"]: x[dst]
+            for x in self.db.get_collection(src).find({}, {"_id": 0, "sha1": 1, dst: 1})
+        }
+        dst_ids = list(
+            {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value}
+        )
+        if src != "revision":
+            dst_objs = {
+                x["_id"]: x["sha1"]
+                for x in self.db.get_collection(dst).find(
+                    {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1}
+                )
+            }
+            return {
+                RelationData(src=src_sha1, dst=dst_sha1, path=path)
+                for src_sha1, denorm in src_objs.items()
+                for dst_obj_id, dst_sha1 in dst_objs.items()
+                for dst_obj_str, paths in denorm.items()
+                for path in paths
+                if dst_obj_id == ObjectId(dst_obj_str)
+            }
+        else:
+            dst_objs = {
+                x["_id"]: x["sha1"]
+                for x in self.db.get_collection(dst).find(
+                    {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1}
+                )
+            }
+            return {
+                RelationData(src=src_sha1, dst=dst_sha1, path=None)
+                for src_sha1, denorm in src_objs.items()
+                for dst_obj_id, dst_sha1 in dst_objs.items()
+                for dst_obj_ref in denorm
+                if dst_obj_id == dst_obj_ref
+            }
+
+    def with_path(self) -> bool:
+        return True
diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py
--- a/swh/provenance/postgresql/provenance.py
+++ b/swh/provenance/postgresql/provenance.py
@@ -268,6 +268,7 @@
                 SELECT sha1, date
                   FROM {entity}
                   WHERE sha1 IN ({values})
+                    AND date IS NOT NULL
                 """
             self.cursor.execute(sql, sha1s)
             dates.update((row["sha1"], row["date"]) for row in self.cursor.fetchall())
diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py
--- a/swh/provenance/provenance.py
+++ b/swh/provenance/provenance.py
@@ -155,8 +155,8 @@
         # Origins urls should be inserted first so that internal ids' resolution works
         # properly.
         urls = {
-            sha1: date
-            for sha1, date in self.cache["origin"]["data"].items()
+            sha1: url
+            for sha1, url in self.cache["origin"]["data"].items()
             if sha1 in self.cache["origin"]["added"]
         }
         while not self.storage.origin_set_url(urls):
@@ -280,14 +280,13 @@
                 updated = {
                     id: rev.date
                     for id, rev in self.storage.revision_get(missing_ids).items()
-                    if rev.date is not None
                 }
             else:
                 updated = getattr(self.storage, f"{entity}_get")(missing_ids)
             cache["data"].update(updated)
         dates: Dict[Sha1Git, datetime] = {}
         for sha1 in ids:
-            date = cache["data"].get(sha1)
+            date = cache["data"].setdefault(sha1, None)
             if date is not None:
                 dates[sha1] = date
         return dates
diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -11,6 +11,7 @@
 import msgpack
 import psycopg2.extensions
 import pytest
+from unittest.mock import patch
 from pytest_postgresql.factories import postgresql
 
 from swh.journal.serializers import msgpack_ext_hook
@@ -22,6 +23,7 @@
 from swh.provenance.storage.archive import ArchiveStorage
 from swh.storage.interface import StorageInterface
 from swh.storage.replay import process_replay_objects
+import mongomock
 
 
 @pytest.fixture(
@@ -32,7 +34,7 @@
         "without-path-denormalized",
     ]
 )
-def populated_db(
+def provenance_postgresqldb(
     request: SubRequest,
     postgresql: psycopg2.extensions.connection,
 ) -> Dict[str, str]:
@@ -47,9 +49,13 @@
 
 # the Flask app used as server in these tests
 @pytest.fixture
-def app(populated_db: Dict[str, str]) -> Iterator[server.ProvenanceStorageServerApp]:
+def app(
+    provenance_postgresqldb: Dict[str, str]
+) -> Iterator[server.ProvenanceStorageServerApp]:
     assert hasattr(server, "storage")
-    server.storage = get_provenance_storage(cls="postgresql", db=populated_db)
+    server.storage = get_provenance_storage(
+        cls="postgresql", db=provenance_postgresqldb
+    )
     yield server.app
 
 
@@ -59,10 +65,10 @@
     return RemoteProvenanceStorage
 
 
-@pytest.fixture(params=["postgresql", "remote"])
+@pytest.fixture(params=["mongodb", "postgresql", "remote"])
 def provenance_storage(
     request: SubRequest,
-    populated_db: Dict[str, str],
+    provenance_postgresqldb: Dict[str, str],
     swh_rpc_client: RemoteProvenanceStorage,
 ) -> ProvenanceStorageInterface:
     """Return a working and initialized ProvenanceStorageInterface object"""
@@ -71,10 +77,14 @@
         assert isinstance(swh_rpc_client, ProvenanceStorageInterface)
         return swh_rpc_client
 
+    elif request.param == "mongodb":
+        with patch("pymongo.MongoClient.get_database") as patched_mongo:
+            patched_mongo.return_value = mongomock.MongoClient().get_database("test")
+            return get_provenance_storage(cls=request.param, db={"dbname": "test"})
     else:
         # in test sessions, we DO want to raise any exception occurring at commit time
         return get_provenance_storage(
-            cls=request.param, db=populated_db, raise_on_commit=True
+            cls=request.param, db=provenance_postgresqldb, raise_on_commit=True
         )
 
 
diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py
--- a/swh/provenance/tests/test_provenance_storage.py
+++ b/swh/provenance/tests/test_provenance_storage.py
@@ -111,51 +111,6 @@
     # Assuming provenance.storage has the 'with-path' flavor.
     assert provenance.storage.with_path()
 
-    # Test content methods.
-    # Add all content present in the current repo to both storages, just assigning their
-    # creation dates. Then check that the inserted content is the same in both cases.
-    cnt_dates = {cnt["sha1_git"]: cnt["ctime"] for cnt in data["content"]}
-    assert cnt_dates
-    assert provenance.storage.content_set_date(
-        cnt_dates
-    ) == provenance_storage.content_set_date(cnt_dates)
-
-    assert provenance.storage.content_get(cnt_dates) == provenance_storage.content_get(
-        cnt_dates
-    )
-    assert provenance.storage.entity_get_all(
-        EntityType.CONTENT
-    ) == provenance_storage.entity_get_all(EntityType.CONTENT)
-
-    # Test directory methods.
-    # Of all directories present in the current repo, only assign a date to those
-    # containing blobs (picking the max date among the available ones). Then check that
-    # the inserted data is the same in both storages.
-    def getmaxdate(
-        dir: Dict[str, Any], cnt_dates: Dict[Sha1Git, datetime]
-    ) -> Optional[datetime]:
-        dates = [
-            cnt_dates[entry["target"]]
-            for entry in dir["entries"]
-            if entry["type"] == "file"
-        ]
-        return max(dates) if dates else None
-
-    dir_dates = {dir["id"]: getmaxdate(dir, cnt_dates) for dir in data["directory"]}
-    assert dir_dates
-    assert provenance.storage.directory_set_date(
-        {sha1: date for sha1, date in dir_dates.items() if date is not None}
-    ) == provenance_storage.directory_set_date(
-        {sha1: date for sha1, date in dir_dates.items() if date is not None}
-    )
-
-    assert provenance.storage.directory_get(
-        dir_dates
-    ) == provenance_storage.directory_get(dir_dates)
-    assert provenance.storage.entity_get_all(
-        EntityType.DIRECTORY
-    ) == provenance_storage.entity_get_all(EntityType.DIRECTORY)
-
     # Test origin methods.
     # Add all origins present in the current repo to both storages. Then check that the
     # inserted data is the same in both cases.
@@ -174,32 +129,6 @@
         EntityType.ORIGIN
     ) == provenance_storage.entity_get_all(EntityType.ORIGIN)
 
-    # Test revision methods.
-    # Add all revisions present in the current repo to both storages, assigning their
-    # dataes and an arbitrary origin to each one. Then check that the inserted data is
-    # the same in both cases.
-    rev_dates = {rev["id"]: ts2dt(rev["date"]) for rev in data["revision"]}
-    assert rev_dates
-    assert provenance.storage.revision_set_date(
-        rev_dates
-    ) == provenance_storage.revision_set_date(rev_dates)
-
-    rev_origins = {
-        rev["id"]: next(iter(org_urls))  # any arbitrary origin will do
-        for rev in data["revision"]
-    }
-    assert rev_origins
-    assert provenance.storage.revision_set_origin(
-        rev_origins
-    ) == provenance_storage.revision_set_origin(rev_origins)
-
-    assert provenance.storage.revision_get(
-        rev_dates
-    ) == provenance_storage.revision_get(rev_dates)
-    assert provenance.storage.entity_get_all(
-        EntityType.REVISION
-    ) == provenance_storage.entity_get_all(EntityType.REVISION)
-
     # Test content-in-revision relation.
     # Create flat models of every root directory for the revisions in the dataset.
     cnt_in_rev: Set[RelationData] = set()
@@ -284,6 +213,76 @@
         provenance_storage,
     )
 
+    # Test content methods.
+    # Add all content present in the current repo to both storages, just assigning their
+    # creation dates. Then check that the inserted content is the same in both cases.
+    cnt_dates = {cnt["sha1_git"]: cnt["ctime"] for cnt in data["content"]}
+    assert cnt_dates
+    assert provenance.storage.content_set_date(
+        cnt_dates
+    ) == provenance_storage.content_set_date(cnt_dates)
+
+    assert provenance.storage.content_get(cnt_dates) == provenance_storage.content_get(
+        cnt_dates
+    )
+    assert provenance.storage.entity_get_all(
+        EntityType.CONTENT
+    ) == provenance_storage.entity_get_all(EntityType.CONTENT)
+
+    # Test directory methods.
+    # Of all directories present in the current repo, only assign a date to those
+    # containing blobs (picking the max date among the available ones). Then check that
+    # the inserted data is the same in both storages.
+    def getmaxdate(
+        dir: Dict[str, Any], cnt_dates: Dict[Sha1Git, datetime]
+    ) -> Optional[datetime]:
+        dates = [
+            cnt_dates[entry["target"]]
+            for entry in dir["entries"]
+            if entry["type"] == "file"
+        ]
+        return max(dates) if dates else None
+
+    dir_dates = {dir["id"]: getmaxdate(dir, cnt_dates) for dir in data["directory"]}
+    assert dir_dates
+    assert provenance.storage.directory_set_date(
+        {sha1: date for sha1, date in dir_dates.items() if date is not None}
+    ) == provenance_storage.directory_set_date(
+        {sha1: date for sha1, date in dir_dates.items() if date is not None}
+    )
+    assert provenance.storage.directory_get(
+        dir_dates
+    ) == provenance_storage.directory_get(dir_dates)
+    assert provenance.storage.entity_get_all(
+        EntityType.DIRECTORY
+    ) == provenance_storage.entity_get_all(EntityType.DIRECTORY)
+
+    # Test revision methods.
+    # Add all revisions present in the current repo to both storages, assigning their
+    # dates and an arbitrary origin to each one. Then check that the inserted data is
+    # the same in both cases.
+    rev_dates = {rev["id"]: ts2dt(rev["date"]) for rev in data["revision"]}
+    assert rev_dates
+    assert provenance.storage.revision_set_date(
+        rev_dates
+    ) == provenance_storage.revision_set_date(rev_dates)
+
+    rev_origins = {
+        rev["id"]: next(iter(org_urls))  # any arbitrary origin will do
+        for rev in data["revision"]
+    }
+    assert rev_origins
+    assert provenance.storage.revision_set_origin(
+        rev_origins
+    ) == provenance_storage.revision_set_origin(rev_origins)
+
+    assert provenance.storage.revision_get(
+        rev_dates
+    ) == provenance_storage.revision_get(rev_dates)
+    assert provenance.storage.entity_get_all(
+        EntityType.REVISION
+    ) == provenance_storage.entity_get_all(EntityType.REVISION)
+
     # Test location_get.
     if provenance_storage.with_path():
         assert provenance.storage.location_get() == provenance_storage.location_get()