diff --git a/mypy.ini b/mypy.ini index fbd44d7..1413db7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,27 +1,33 @@ [mypy] namespace_packages = True warn_unused_ignores = True # 3rd party libraries without stubs (yet) +[mypy-bson.*] +ignore_missing_imports = True + [mypy-iso8601.*] ignore_missing_imports = True [mypy-methodtools.*] ignore_missing_imports = True [mypy-msgpack.*] ignore_missing_imports = True [mypy-pkg_resources.*] ignore_missing_imports = True +[mypy-pymongo.*] +ignore_missing_imports = True + [mypy-pytest.*] ignore_missing_imports = True [mypy-pytest_postgresql.*] ignore_missing_imports = True [mypy-psycopg2.*] ignore_missing_imports = True diff --git a/pytest.ini b/pytest.ini index b712d00..9634a17 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,6 @@ [pytest] norecursedirs = docs .* + +mongodb_fixture_dir = swh/provenance/tests/data/mongo +mongodb_engine = mongomock +mongodb_dbname = test diff --git a/requirements-test.txt b/requirements-test.txt index bcb961b..c530a6c 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,4 +1,5 @@ pytest +pytest-mongodb swh.loader.git >= 0.8 swh.journal >= 0.8 types-Werkzeug diff --git a/requirements.txt b/requirements.txt index 2201aff..8f39169 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ # Add here external Python modules dependencies, one per line. Module names # should match https://pypi.python.org/pypi names. For the full spec or # dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html click iso8601 methodtools +pymongo PyYAML types-click types-PyYAML +types-Werkzeug diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py index 9dee9f7..681e476 100644 --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -1,98 +1,107 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from typing import TYPE_CHECKING import warnings if TYPE_CHECKING: from .archive import ArchiveInterface from .interface import ProvenanceInterface, ProvenanceStorageInterface def get_archive(cls: str, **kwargs) -> ArchiveInterface: """Get an archive object of class ``cls`` with arguments ``args``. Args: cls: archive's class, either 'api' or 'direct' args: dictionary of arguments passed to the archive class constructor Returns: an instance of archive object (either using swh.storage API or direct queries to the archive's database) Raises: :cls:`ValueError` if passed an unknown archive class. """ if cls == "api": from swh.storage import get_storage from .storage.archive import ArchiveStorage return ArchiveStorage(get_storage(**kwargs["storage"])) elif cls == "direct": from swh.core.db import BaseDb from .postgresql.archive import ArchivePostgreSQL return ArchivePostgreSQL(BaseDb.connect(**kwargs["db"]).conn) else: raise ValueError def get_provenance(**kwargs) -> ProvenanceInterface: """Get an provenance object with arguments ``args``. Args: args: dictionary of arguments to retrieve a swh.provenance.storage class (see :func:`get_provenance_storage` for details) Returns: an instance of provenance object """ from .provenance import Provenance return Provenance(get_provenance_storage(**kwargs)) def get_provenance_storage(cls: str, **kwargs) -> ProvenanceStorageInterface: """Get an archive object of class ``cls`` with arguments ``args``. 
Args: cls: storage's class, only 'local' is currently supported args: dictionary of arguments passed to the storage class constructor Returns: an instance of storage object Raises: :cls:`ValueError` if passed an unknown archive class. """ if cls in ["local", "postgresql"]: from swh.core.db import BaseDb from .postgresql.provenance import ProvenanceStoragePostgreSql if cls == "local": warnings.warn( '"local" class is deprecated for provenance storage, please ' 'use "postgresql" class instead.', DeprecationWarning, ) conn = BaseDb.connect(**kwargs["db"]).conn raise_on_commit = kwargs.get("raise_on_commit", False) return ProvenanceStoragePostgreSql(conn, raise_on_commit) + elif cls == "mongodb": + from pymongo import MongoClient + + from .mongo.backend import ProvenanceStorageMongoDb + + dbname = kwargs["db"].pop("dbname") + db = MongoClient(**kwargs["db"]).get_database(dbname) + return ProvenanceStorageMongoDb(db) + elif cls == "remote": from .api.client import RemoteProvenanceStorage storage = RemoteProvenanceStorage(**kwargs) assert isinstance(storage, ProvenanceStorageInterface) return storage else: raise ValueError diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py index 5734129..cd2570d 100644 --- a/swh/provenance/cli.py +++ b/swh/provenance/cli.py @@ -1,225 +1,241 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # WARNING: do not import unnecessary things here to keep cli startup time under # control from datetime import datetime, timezone import os from typing import Any, Dict, Generator, Optional, Tuple import click import iso8601 import yaml from swh.core import config from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group from swh.model.hashutil import hash_to_bytes, hash_to_hex from swh.model.model import Sha1Git # All generic config code should reside in swh.core.config CONFIG_ENVVAR = "SWH_CONFIG_FILENAME" DEFAULT_PATH = os.environ.get(CONFIG_ENVVAR, None) DEFAULT_CONFIG: Dict[str, Any] = { "provenance": { "archive": { + # Storage API based Archive object # "cls": "api", # "storage": { # "cls": "remote", # "url": "http://uffizi.internal.softwareheritage.org:5002", # } + # Direct access Archive object "cls": "direct", "db": { "host": "db.internal.softwareheritage.org", "dbname": "softwareheritage", "user": "guest", }, }, "storage": { + # Local PostgreSQL Storage "cls": "postgresql", - "db": {"host": "localhost", "dbname": "provenance"}, + "db": { + "host": "localhost", + "user": "postgres", + "password": "postgres", + "dbname": "provenance", + }, + # Local MongoDB Storage + # "cls": "mongodb", + # "db": { + # "dbname": "provenance", + # }, + # Remote REST-API/PostgreSQL + # "cls": "remote", + # "url": "http://localhost:8080/%2f", }, } } CONFIG_FILE_HELP = f""" \b Configuration can be loaded from a yaml file given either as --config-file option or the {CONFIG_ENVVAR} environment variable. 
If no configuration file is specified, use the following default configuration:: \b {yaml.dump(DEFAULT_CONFIG)}""" PROVENANCE_HELP = f"""Software Heritage provenance index database tools {CONFIG_FILE_HELP} """ @swh_cli_group.group( name="provenance", context_settings=CONTEXT_SETTINGS, help=PROVENANCE_HELP ) @click.option( "-C", "--config-file", default=None, type=click.Path(exists=True, dir_okay=False, path_type=str), help="""YAML configuration file.""", ) @click.option( "-P", "--profile", default=None, type=click.Path(exists=False, dir_okay=False, path_type=str), help="""Enable profiling to specified file.""", ) @click.pass_context def cli(ctx: click.core.Context, config_file: Optional[str], profile: str) -> None: if ( config_file is None and DEFAULT_PATH is not None and config.config_exists(DEFAULT_PATH) ): config_file = DEFAULT_PATH if config_file is None: conf = DEFAULT_CONFIG else: # read_raw_config do not fail on ENOENT if not os.path.exists(config_file): raise FileNotFoundError(config_file) conf = yaml.safe_load(open(config_file, "rb")) ctx.ensure_object(dict) ctx.obj["config"] = conf if profile: import atexit import cProfile print("Profiling...") pr = cProfile.Profile() pr.enable() def exit() -> None: pr.disable() pr.dump_stats(profile) atexit.register(exit) @cli.command(name="iter-revisions") @click.argument("filename") @click.option("-a", "--track-all", default=True, type=bool) @click.option("-l", "--limit", type=int) @click.option("-m", "--min-depth", default=1, type=int) @click.option("-r", "--reuse", default=True, type=bool) @click.pass_context def iter_revisions( ctx: click.core.Context, filename: str, track_all: bool, limit: Optional[int], min_depth: int, reuse: bool, ) -> None: # TODO: add file size filtering """Process a provided list of revisions.""" from . import get_archive, get_provenance from .revision import CSVRevisionIterator, revision_add archive = get_archive(**ctx.obj["config"]["provenance"]["archive"]) provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) revisions_provider = generate_revision_tuples(filename) revisions = CSVRevisionIterator(revisions_provider, limit=limit) for revision in revisions: revision_add( provenance, archive, [revision], trackall=track_all, lower=reuse, mindepth=min_depth, ) def generate_revision_tuples( filename: str, ) -> Generator[Tuple[Sha1Git, datetime, Sha1Git], None, None]: for line in open(filename, "r"): if line.strip(): revision, date, root = line.strip().split(",") yield ( hash_to_bytes(revision), iso8601.parse_date(date, default_timezone=timezone.utc), hash_to_bytes(root), ) @cli.command(name="iter-origins") @click.argument("filename") @click.option("-l", "--limit", type=int) @click.pass_context def iter_origins(ctx: click.core.Context, filename: str, limit: Optional[int]) -> None: """Process a provided list of origins.""" from . 
import get_archive, get_provenance from .origin import CSVOriginIterator, origin_add archive = get_archive(**ctx.obj["config"]["provenance"]["archive"]) provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) origins_provider = generate_origin_tuples(filename) origins = CSVOriginIterator(origins_provider, limit=limit) for origin in origins: origin_add(provenance, archive, [origin]) def generate_origin_tuples(filename: str) -> Generator[Tuple[str, bytes], None, None]: for line in open(filename, "r"): if line.strip(): url, snapshot = line.strip().split(",") yield (url, hash_to_bytes(snapshot)) @cli.command(name="find-first") @click.argument("swhid") @click.pass_context def find_first(ctx: click.core.Context, swhid: str) -> None: """Find first occurrence of the requested blob.""" from . import get_provenance provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) occur = provenance.content_find_first(hash_to_bytes(swhid)) if occur is not None: print( f"swh:1:cnt:{hash_to_hex(occur.content)}, " f"swh:1:rev:{hash_to_hex(occur.revision)}, " f"{occur.date}, " f"{occur.origin}, " f"{os.fsdecode(occur.path)}" ) else: print(f"Cannot find a content with the id {swhid}") @cli.command(name="find-all") @click.argument("swhid") @click.option("-l", "--limit", type=int) @click.pass_context def find_all(ctx: click.core.Context, swhid: str, limit: Optional[int]) -> None: """Find all occurrences of the requested blob.""" from . import get_provenance provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit): print( f"swh:1:cnt:{hash_to_hex(occur.content)}, " f"swh:1:rev:{hash_to_hex(occur.revision)}, " f"{occur.date}, " f"{occur.origin}, " f"{os.fsdecode(occur.path)}" ) diff --git a/swh/provenance/mongo/README.md b/swh/provenance/mongo/README.md new file mode 100644 index 0000000..da25c85 --- /dev/null +++ b/swh/provenance/mongo/README.md @@ -0,0 +1,44 @@ +mongo backend +============= + +Provenance storage implementation using MongoDB + +initial data-model +------------------ + +```json +content +{ + id: sha1 + ts: int //optional + revision: {: []} + directory: {: []} +} + +directory +{ + id: sha1 + ts: int //optional + revision: {: []} +} + +revision +{ + id: sha1 + ts: int // optional + preferred //optinal + origin [] + revision [] +} + +origin +{ + id: sha1 + url: str +} + +path +{ + path: str +} +``` diff --git a/swh/provenance/mongo/__init__.py b/swh/provenance/mongo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py new file mode 100644 index 0000000..acff329 --- /dev/null +++ b/swh/provenance/mongo/backend.py @@ -0,0 +1,488 @@ +# Copyright (C) 2021 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from datetime import datetime, timezone +import os +from typing import Any, Dict, Generator, Iterable, List, Optional, Set + +from bson import ObjectId +import pymongo.database + +from swh.model.model import Sha1Git + +from ..interface import ( + EntityType, + ProvenanceResult, + RelationData, + RelationType, + RevisionData, +) + + +class ProvenanceStorageMongoDb: + def __init__(self, db: pymongo.database.Database): + self.db = db + + def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: + # get all the 
revisions + # iterate and find the earliest + content = self.db.content.find_one({"sha1": id}) + if not content: + return None + + occurs = [] + for revision in self.db.revision.find( + {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} + ): + origin = self.db.origin.find_one({"sha1": revision["preferred"]}) + assert origin is not None + + for path in content["revision"][str(revision["_id"])]: + occurs.append( + ProvenanceResult( + content=id, + revision=revision["sha1"], + date=datetime.fromtimestamp(revision["ts"], timezone.utc), + origin=origin["url"], + path=path, + ) + ) + return sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path))[0] + + def content_find_all( + self, id: Sha1Git, limit: Optional[int] = None + ) -> Generator[ProvenanceResult, None, None]: + content = self.db.content.find_one({"sha1": id}) + if not content: + return None + + occurs = [] + for revision in self.db.revision.find( + {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} + ): + origin = self.db.origin.find_one({"sha1": revision["preferred"]}) + assert origin is not None + + for path in content["revision"][str(revision["_id"])]: + occurs.append( + ProvenanceResult( + content=id, + revision=revision["sha1"], + date=datetime.fromtimestamp(revision["ts"], timezone.utc), + origin=origin["url"], + path=path, + ) + ) + for directory in self.db.directory.find( + {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["directory"]]}} + ): + for revision in self.db.revision.find( + {"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}} + ): + origin = self.db.origin.find_one({"sha1": revision["preferred"]}) + assert origin is not None + + for suffix in content["directory"][str(directory["_id"])]: + for prefix in directory["revision"][str(revision["_id"])]: + path = ( + os.path.join(prefix, suffix) + if prefix not in [b".", b""] + else suffix + ) + occurs.append( + ProvenanceResult( + content=id, + revision=revision["sha1"], + date=datetime.fromtimestamp( + revision["ts"], timezone.utc + ), + origin=origin["url"], + path=path, + ) + ) + yield from sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path)) + + def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: + return { + x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) + for x in self.db.content.find( + {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, + {"sha1": 1, "ts": 1, "_id": 0}, + ) + } + + def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: + # get all the docuemtns with the id, add date, add missing records + cnts = { + x["sha1"]: x + for x in self.db.content.find( + {"sha1": {"$in": list(dates)}}, {"sha1": 1, "ts": 1, "_id": 1} + ) + } + + for sha1, date in dates.items(): + ts = datetime.timestamp(date) + if sha1 in cnts: + # update + if cnts[sha1]["ts"] is None or ts < cnts[sha1]["ts"]: + self.db.content.update_one( + {"_id": cnts[sha1]["_id"]}, {"$set": {"ts": ts}} + ) + else: + # add new content + self.db.content.insert_one( + { + "sha1": sha1, + "ts": ts, + "revision": {}, + "directory": {}, + } + ) + return True + + def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: + dirs = { + x["sha1"]: x + for x in self.db.directory.find( + {"sha1": {"$in": list(dates)}}, {"sha1": 1, "ts": 1, "_id": 1} + ) + } + for sha1, date in dates.items(): + ts = datetime.timestamp(date) + if sha1 in dirs: + # update + if dirs[sha1]["ts"] is None or ts < dirs[sha1]["ts"]: + self.db.directory.update_one( + {"_id": dirs[sha1]["_id"]}, 
{"$set": {"ts": ts}} + ) + else: + # add new dir + self.db.directory.insert_one({"sha1": sha1, "ts": ts, "revision": {}}) + return True + + def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: + return { + x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) + for x in self.db.directory.find( + {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, + {"sha1": 1, "ts": 1, "_id": 0}, + ) + } + + def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: + return { + x["sha1"] + for x in self.db.get_collection(entity.value).find( + {}, {"sha1": 1, "_id": 0} + ) + } + + def location_get(self) -> Set[bytes]: + contents = self.db.content.find({}, {"revision": 1, "_id": 0, "directory": 1}) + paths: List[Iterable[bytes]] = [] + for content in contents: + paths.extend(value for _, value in content["revision"].items()) + paths.extend(value for _, value in content["directory"].items()) + + dirs = self.db.directory.find({}, {"revision": 1, "_id": 0}) + for each_dir in dirs: + paths.extend(value for _, value in each_dir["revision"].items()) + return set(sum(paths, [])) + + def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: + origins = { + x["sha1"]: x + for x in self.db.origin.find( + {"sha1": {"$in": list(urls)}}, {"sha1": 1, "url": 1, "_id": 1} + ) + } + for sha1, url in urls.items(): + if sha1 not in origins: + # add new origin + self.db.origin.insert_one({"sha1": sha1, "url": url}) + return True + + def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: + return { + x["sha1"]: x["url"] + for x in self.db.origin.find( + {"sha1": {"$in": list(ids)}}, {"sha1": 1, "url": 1, "_id": 0} + ) + } + + def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: + revs = { + x["sha1"]: x + for x in self.db.revision.find( + {"sha1": {"$in": list(dates)}}, {"sha1": 1, "ts": 1, "_id": 1} + ) + } + for sha1, date in dates.items(): + ts = datetime.timestamp(date) + if sha1 in revs: + # update + if revs[sha1]["ts"] is None or ts < revs[sha1]["ts"]: + self.db.revision.update_one( + {"_id": revs[sha1]["_id"]}, {"$set": {"ts": ts}} + ) + else: + # add new rev + self.db.revision.insert_one( + { + "sha1": sha1, + "preferred": None, + "origin": [], + "revision": [], + "ts": ts, + } + ) + return True + + def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool: + revs = { + x["sha1"]: x + for x in self.db.revision.find( + {"sha1": {"$in": list(origins)}}, {"sha1": 1, "preferred": 1, "_id": 1} + ) + } + for sha1, origin in origins.items(): + if sha1 in revs: + self.db.revision.update_one( + {"_id": revs[sha1]["_id"]}, {"$set": {"preferred": origin}} + ) + else: + # add new rev + self.db.revision.insert_one( + { + "sha1": sha1, + "preferred": origin, + "origin": [], + "revision": [], + "ts": None, + } + ) + return True + + def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: + return { + x["sha1"]: RevisionData( + date=datetime.fromtimestamp(x["ts"], timezone.utc) if x["ts"] else None, + origin=x["preferred"], + ) + for x in self.db.revision.find( + {"sha1": {"$in": list(ids)}}, + {"sha1": 1, "preferred": 1, "ts": 1, "_id": 0}, + ) + } + + def relation_add( + self, relation: RelationType, data: Iterable[RelationData] + ) -> bool: + src_relation, *_, dst_relation = relation.value.split("_") + set_data = set(data) + + dst_sha1s = {x.dst for x in data} + if dst_relation in ["content", "directory", "revision"]: + dst_obj: Dict[str, Any] = {"ts": None} + if dst_relation == "content": + dst_obj["revision"] = {} + dst_obj["directory"] = 
{} + if dst_relation == "directory": + dst_obj["revision"] = {} + if dst_relation == "revision": + dst_obj["preferred"] = None + dst_obj["origin"] = [] + dst_obj["revision"] = [] + + existing = { + x["sha1"] + for x in self.db.get_collection(dst_relation).find( + {"sha1": {"$in": list(dst_sha1s)}}, {"_id": 0, "sha1": 1} + ) + } + + for sha1 in dst_sha1s: + if sha1 not in existing: + self.db.get_collection(dst_relation).insert_one( + dict(dst_obj, **{"sha1": sha1}) + ) + elif dst_relation == "origin": + # TODO, check origins are already in the DB + # if not, algo has something wrong (algo inserts it initially) + pass + + dst_objs = { + x["sha1"]: x["_id"] + for x in self.db.get_collection(dst_relation).find( + {"sha1": {"$in": list(dst_sha1s)}}, {"_id": 1, "sha1": 1} + ) + } + + denorm: Dict[Sha1Git, Any] = {} + for each in set_data: + if src_relation != "revision": + denorm.setdefault(each.src, {}).setdefault( + str(dst_objs[each.dst]), [] + ).append(each.path) + else: + denorm.setdefault(each.src, []).append(dst_objs[each.dst]) + + src_objs = { + x["sha1"]: x + for x in self.db.get_collection(src_relation).find( + {"sha1": {"$in": list(denorm)}} + ) + } + + for sha1, _ in denorm.items(): + if sha1 in src_objs: + # update + if src_relation != "revision": + k = { + obj_id: list(set(paths + denorm[sha1][obj_id])) + for obj_id, paths in src_objs[sha1][dst_relation].items() + } + self.db.get_collection(src_relation).update_one( + {"_id": src_objs[sha1]["_id"]}, + {"$set": {dst_relation: dict(denorm[sha1], **k)}}, + ) + else: + self.db.get_collection(src_relation).update_one( + {"_id": src_objs[sha1]["_id"]}, + { + "$set": { + dst_relation: list( + set(src_objs[sha1][dst_relation] + denorm[sha1]) + ) + } + }, + ) + else: + # add new rev + src_obj: Dict[str, Any] = {"ts": None} + if src_relation == "content": + src_obj["revision"] = {} + src_obj["directory"] = {} + if src_relation == "directory": + src_obj["revision"] = {} + if src_relation == "revision": + src_obj["preferred"] = None + src_obj["origin"] = [] + src_obj["revision"] = [] + self.db.get_collection(src_relation).insert_one( + dict(src_obj, **{"sha1": sha1, dst_relation: denorm[sha1]}) + ) + return True + + def relation_get( + self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False + ) -> Set[RelationData]: + src, *_, dst = relation.value.split("_") + sha1s = set(ids) + if not reverse: + src_objs = { + x["sha1"]: x[dst] + for x in self.db.get_collection(src).find( + {"sha1": {"$in": list(sha1s)}}, {"_id": 0, "sha1": 1, dst: 1} + ) + } + dst_ids = list( + {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} + ) + dst_objs = { + x["sha1"]: x["_id"] + for x in self.db.get_collection(dst).find( + {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} + ) + } + if src != "revision": + return { + RelationData(src=src_sha1, dst=dst_sha1, path=path) + for src_sha1, denorm in src_objs.items() + for dst_sha1, dst_obj_id in dst_objs.items() + for dst_obj_str, paths in denorm.items() + for path in paths + if dst_obj_id == ObjectId(dst_obj_str) + } + else: + return { + RelationData(src=src_sha1, dst=dst_sha1, path=None) + for src_sha1, denorm in src_objs.items() + for dst_sha1, dst_obj_id in dst_objs.items() + for dst_obj_ref in denorm + if dst_obj_id == dst_obj_ref + } + else: + dst_objs = { + x["sha1"]: x["_id"] + for x in self.db.get_collection(dst).find( + {"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1} + ) + } + src_objs = { + x["sha1"]: x[dst] + for x in self.db.get_collection(src).find( + {}, {"_id": 
0, "sha1": 1, dst: 1} + ) + } + if src != "revision": + return { + RelationData(src=src_sha1, dst=dst_sha1, path=path) + for src_sha1, denorm in src_objs.items() + for dst_sha1, dst_obj_id in dst_objs.items() + for dst_obj_str, paths in denorm.items() + for path in paths + if dst_obj_id == ObjectId(dst_obj_str) + } + else: + return { + RelationData(src=src_sha1, dst=dst_sha1, path=None) + for src_sha1, denorm in src_objs.items() + for dst_sha1, dst_obj_id in dst_objs.items() + for dst_obj_ref in denorm + if dst_obj_id == dst_obj_ref + } + + def relation_get_all(self, relation: RelationType) -> Set[RelationData]: + src, *_, dst = relation.value.split("_") + src_objs = { + x["sha1"]: x[dst] + for x in self.db.get_collection(src).find({}, {"_id": 0, "sha1": 1, dst: 1}) + } + dst_ids = list( + {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} + ) + if src != "revision": + dst_objs = { + x["_id"]: x["sha1"] + for x in self.db.get_collection(dst).find( + {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} + ) + } + return { + RelationData(src=src_sha1, dst=dst_sha1, path=path) + for src_sha1, denorm in src_objs.items() + for dst_obj_id, dst_sha1 in dst_objs.items() + for dst_obj_str, paths in denorm.items() + for path in paths + if dst_obj_id == ObjectId(dst_obj_str) + } + else: + dst_objs = { + x["_id"]: x["sha1"] + for x in self.db.get_collection(dst).find( + {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} + ) + } + return { + RelationData(src=src_sha1, dst=dst_sha1, path=None) + for src_sha1, denorm in src_objs.items() + for dst_obj_id, dst_sha1 in dst_objs.items() + for dst_obj_ref in denorm + if dst_obj_id == dst_obj_ref + } + + def with_path(self) -> bool: + return True diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py index 7b7df66..108beef 100644 --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -1,373 +1,374 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime import itertools import logging from typing import Dict, Generator, Iterable, Optional, Set, Tuple import psycopg2.extensions import psycopg2.extras from typing_extensions import Literal from swh.core.db import BaseDb from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, RelationData, RelationType, RevisionData, ) class ProvenanceStoragePostgreSql: def __init__( self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False ) -> None: BaseDb.adapt_conn(conn) conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) conn.set_session(autocommit=True) self.conn = conn self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) # XXX: not sure this is the best place to do it! 
sql = "SET timezone TO 'UTC'" self.cursor.execute(sql) self._flavor: Optional[str] = None self.raise_on_commit = raise_on_commit @property def flavor(self) -> str: if self._flavor is None: sql = "SELECT swh_get_dbflavor() AS flavor" self.cursor.execute(sql) self._flavor = self.cursor.fetchone()["flavor"] assert self._flavor is not None return self._flavor @property def denormalized(self) -> bool: return "denormalized" in self.flavor def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: sql = "SELECT * FROM swh_provenance_content_find_first(%s)" self.cursor.execute(sql, (id,)) row = self.cursor.fetchone() return ProvenanceResult(**row) if row is not None else None def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" self.cursor.execute(sql, (id, limit)) yield from (ProvenanceResult(**row) for row in self.cursor.fetchall()) def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: return self._entity_set_date("content", dates) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("content", ids) def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: return self._entity_set_date("directory", dates) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("directory", ids) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: sql = f"SELECT sha1 FROM {entity.value}" self.cursor.execute(sql) return {row["sha1"] for row in self.cursor.fetchall()} def location_get(self) -> Set[bytes]: sql = "SELECT location.path AS path FROM location" self.cursor.execute(sql) return {row["path"] for row in self.cursor.fetchall()} def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: try: if urls: sql = """ LOCK TABLE ONLY origin; INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ psycopg2.extras.execute_values(self.cursor, sql, urls.items()) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message logging.exception("Unexpected error") if self.raise_on_commit: raise return False def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: urls: Dict[Sha1Git, str] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! 
values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, url FROM origin WHERE sha1 IN ({values}) """ self.cursor.execute(sql, sha1s) urls.update((row["sha1"], row["url"]) for row in self.cursor.fetchall()) return urls def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: return self._entity_set_date("revision", dates) def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool: try: if origins: sql = """ LOCK TABLE ONLY revision; INSERT INTO revision(sha1, origin) (SELECT V.rev AS sha1, O.id AS origin FROM (VALUES %s) AS V(rev, org) JOIN origin AS O ON (O.sha1=V.org)) ON CONFLICT (sha1) DO UPDATE SET origin=EXCLUDED.origin """ psycopg2.extras.execute_values(self.cursor, sql, origins.items()) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message logging.exception("Unexpected error") if self.raise_on_commit: raise return False def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: result: Dict[Sha1Git, RevisionData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT R.sha1, R.date, O.sha1 AS origin FROM revision AS R LEFT JOIN origin AS O ON (O.id=R.origin) WHERE R.sha1 IN ({values}) """ self.cursor.execute(sql, sha1s) result.update( (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) for row in self.cursor.fetchall() ) return result def relation_add( self, relation: RelationType, data: Iterable[RelationData] ) -> bool: try: rows = tuple((rel.src, rel.dst, rel.path) for rel in data) if rows: table = relation.value src, *_, dst = table.split("_") if src != "origin": # Origin entries should be inserted previously as they require extra # non-null information srcs = tuple(set((sha1,) for (sha1, _, _) in rows)) sql = f""" LOCK TABLE ONLY {src}; INSERT INTO {src}(sha1) VALUES %s ON CONFLICT DO NOTHING """ psycopg2.extras.execute_values(self.cursor, sql, srcs) if dst != "origin": # Origin entries should be inserted previously as they require extra # non-null information dsts = tuple(set((sha1,) for (_, sha1, _) in rows)) sql = f""" LOCK TABLE ONLY {dst}; INSERT INTO {dst}(sha1) VALUES %s ON CONFLICT DO NOTHING """ psycopg2.extras.execute_values(self.cursor, sql, dsts) joins = [ f"INNER JOIN {src} AS S ON (S.sha1=V.src)", f"INNER JOIN {dst} AS D ON (D.sha1=V.dst)", ] nope = (RelationType.REV_BEFORE_REV, RelationType.REV_IN_ORG) selected = ["S.id"] if self.denormalized and relation not in nope: selected.append("ARRAY_AGG(D.id)") else: selected.append("D.id") if self._relation_uses_location_table(relation): locations = tuple(set((path,) for (_, _, path) in rows)) sql = """ LOCK TABLE ONLY location; INSERT INTO location(path) VALUES %s ON CONFLICT (path) DO NOTHING """ psycopg2.extras.execute_values(self.cursor, sql, locations) joins.append("INNER JOIN location AS L ON (L.path=V.path)") if self.denormalized: selected.append("ARRAY_AGG(L.id)") else: selected.append("L.id") sql_l = [ f"INSERT INTO {table}", f" SELECT {', '.join(selected)}", " FROM (VALUES %s) AS V(src, dst, path)", *joins, ] if self.denormalized and relation not in nope: sql_l.append("GROUP BY S.id") sql_l.append( f"""ON CONFLICT ({src}) DO UPDATE SET {dst}=ARRAY( SELECT UNNEST({table}.{dst} || EXCLUDED.{dst}) ), location=ARRAY( SELECT UNNEST({relation.value}.location || EXCLUDED.location) ) """ ) else: sql_l.append("ON CONFLICT DO NOTHING") sql = "\n".join(sql_l) 
psycopg2.extras.execute_values(self.cursor, sql, rows) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message logging.exception("Unexpected error") if self.raise_on_commit: raise return False def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Set[RelationData]: return self._relation_get(relation, ids, reverse) def relation_get_all(self, relation: RelationType) -> Set[RelationData]: return self._relation_get(relation, None) def _entity_get_date( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: dates: Dict[Sha1Git, datetime] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date FROM {entity} WHERE sha1 IN ({values}) + AND date IS NOT NULL """ self.cursor.execute(sql, sha1s) dates.update((row["sha1"], row["date"]) for row in self.cursor.fetchall()) return dates def _entity_set_date( self, entity: Literal["content", "directory", "revision"], data: Dict[Sha1Git, datetime], ) -> bool: try: if data: sql = f""" LOCK TABLE ONLY {entity}; INSERT INTO {entity}(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) """ psycopg2.extras.execute_values(self.cursor, sql, data.items()) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message logging.exception("Unexpected error") if self.raise_on_commit: raise return False def _relation_get( self, relation: RelationType, ids: Optional[Iterable[Sha1Git]], reverse: bool = False, ) -> Set[RelationData]: result: Set[RelationData] = set() sha1s: Optional[Tuple[Tuple[Sha1Git, ...]]] if ids is not None: sha1s = (tuple(ids),) where = f"WHERE {'S' if not reverse else 'D'}.sha1 IN %s" else: sha1s = None where = "" aggreg_dst = self.denormalized and relation in ( RelationType.CNT_EARLY_IN_REV, RelationType.CNT_IN_DIR, RelationType.DIR_IN_REV, ) if sha1s is None or sha1s[0]: table = relation.value src, *_, dst = table.split("_") # TODO: improve this! 
if src == "revision" and dst == "revision": src_field = "prev" dst_field = "next" else: src_field = src dst_field = dst if aggreg_dst: revloc = f"UNNEST(R.{dst_field}) AS dst" if self._relation_uses_location_table(relation): revloc += ", UNNEST(R.location) AS path" else: revloc = f"R.{dst_field} AS dst" if self._relation_uses_location_table(relation): revloc += ", R.location AS path" inner_sql = f""" SELECT S.sha1 AS src, {revloc} FROM {table} AS R INNER JOIN {src} AS S ON (S.id=R.{src_field}) """ if where != "" and not reverse: inner_sql += where if self._relation_uses_location_table(relation): loc = "L.path AS path" else: loc = "NULL AS path" sql = f""" SELECT CL.src, D.sha1 AS dst, {loc} FROM ({inner_sql}) AS CL INNER JOIN {dst} AS D ON (D.id=CL.dst) """ if self._relation_uses_location_table(relation): sql += "INNER JOIN location AS L ON (L.id=CL.path)" if where != "" and reverse: sql += where self.cursor.execute(sql, sha1s) result.update(RelationData(**row) for row in self.cursor.fetchall()) return result def _relation_uses_location_table(self, relation: RelationType) -> bool: if self.with_path(): src = relation.value.split("_")[0] return src in ("content", "directory") return False def with_path(self) -> bool: return "with-path" in self.flavor diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py index c65c4bf..f30cb02 100644 --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -1,348 +1,347 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime import logging import os from typing import Dict, Generator, Iterable, Optional, Set, Tuple from typing_extensions import Literal, TypedDict from swh.model.model import Sha1Git from .interface import ( ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, ) from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry class DatetimeCache(TypedDict): data: Dict[Sha1Git, Optional[datetime]] added: Set[Sha1Git] class OriginCache(TypedDict): data: Dict[Sha1Git, str] added: Set[Sha1Git] class RevisionCache(TypedDict): data: Dict[Sha1Git, Sha1Git] added: Set[Sha1Git] class ProvenanceCache(TypedDict): content: DatetimeCache directory: DatetimeCache revision: DatetimeCache # below are insertion caches only content_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] content_in_directory: Set[Tuple[Sha1Git, Sha1Git, bytes]] directory_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] # these two are for the origin layer origin: OriginCache revision_origin: RevisionCache revision_before_revision: Dict[Sha1Git, Set[Sha1Git]] revision_in_origin: Set[Tuple[Sha1Git, Sha1Git]] def new_cache() -> ProvenanceCache: return ProvenanceCache( content=DatetimeCache(data={}, added=set()), directory=DatetimeCache(data={}, added=set()), revision=DatetimeCache(data={}, added=set()), content_in_revision=set(), content_in_directory=set(), directory_in_revision=set(), origin=OriginCache(data={}, added=set()), revision_origin=RevisionCache(data={}, added=set()), revision_before_revision={}, revision_in_origin=set(), ) class Provenance: def __init__(self, storage: ProvenanceStorageInterface) -> None: self.storage = storage self.cache = new_cache() def clear_caches(self) -> None: self.cache = new_cache() def flush(self) -> None: # Revision-content layer insertions 
############################################ # For this layer, relations need to be inserted first so that, in case of # failure, reprocessing the input does not generated an inconsistent database. while not self.storage.relation_add( RelationType.CNT_EARLY_IN_REV, ( RelationData(src=src, dst=dst, path=path) for src, dst, path in self.cache["content_in_revision"] ), ): logging.warning( f"Unable to write {RelationType.CNT_EARLY_IN_REV} rows to the storage. " f"Data: {self.cache['content_in_revision']}. Retrying..." ) while not self.storage.relation_add( RelationType.CNT_IN_DIR, ( RelationData(src=src, dst=dst, path=path) for src, dst, path in self.cache["content_in_directory"] ), ): logging.warning( f"Unable to write {RelationType.CNT_IN_DIR} rows to the storage. " f"Data: {self.cache['content_in_directory']}. Retrying..." ) while not self.storage.relation_add( RelationType.DIR_IN_REV, ( RelationData(src=src, dst=dst, path=path) for src, dst, path in self.cache["directory_in_revision"] ), ): logging.warning( f"Unable to write {RelationType.DIR_IN_REV} rows to the storage. " f"Data: {self.cache['directory_in_revision']}. Retrying..." ) # After relations, dates for the entities can be safely set, acknowledging that # these entities won't need to be reprocessed in case of failure. dates = { sha1: date for sha1, date in self.cache["content"]["data"].items() if sha1 in self.cache["content"]["added"] and date is not None } while not self.storage.content_set_date(dates): logging.warning( f"Unable to write content dates to the storage. " f"Data: {dates}. Retrying..." ) dates = { sha1: date for sha1, date in self.cache["directory"]["data"].items() if sha1 in self.cache["directory"]["added"] and date is not None } while not self.storage.directory_set_date(dates): logging.warning( f"Unable to write directory dates to the storage. " f"Data: {dates}. Retrying..." ) dates = { sha1: date for sha1, date in self.cache["revision"]["data"].items() if sha1 in self.cache["revision"]["added"] and date is not None } while not self.storage.revision_set_date(dates): logging.warning( f"Unable to write revision dates to the storage. " f"Data: {dates}. Retrying..." ) # Origin-revision layer insertions ############################################# # Origins urls should be inserted first so that internal ids' resolution works # properly. urls = { - sha1: date - for sha1, date in self.cache["origin"]["data"].items() + sha1: url + for sha1, url in self.cache["origin"]["data"].items() if sha1 in self.cache["origin"]["added"] } while not self.storage.origin_set_url(urls): logging.warning( f"Unable to write origins urls to the storage. " f"Data: {urls}. Retrying..." ) # Second, flat models for revisions' histories (ie. revision-before-revision). data: Iterable[RelationData] = sum( [ [ RelationData(src=prev, dst=next, path=None) for next in self.cache["revision_before_revision"][prev] ] for prev in self.cache["revision_before_revision"] ], [], ) while not self.storage.relation_add(RelationType.REV_BEFORE_REV, data): logging.warning( f"Unable to write {RelationType.REV_BEFORE_REV} rows to the storage. " f"Data: {data}. Retrying..." ) # Heads (ie. revision-in-origin entries) should be inserted once flat models for # their histories were already added. This is to guarantee consistent results if # something needs to be reprocessed due to a failure: already inserted heads # won't get reprocessed in such a case. 
data = ( RelationData(src=rev, dst=org, path=None) for rev, org in self.cache["revision_in_origin"] ) while not self.storage.relation_add(RelationType.REV_IN_ORG, data): logging.warning( f"Unable to write {RelationType.REV_IN_ORG} rows to the storage. " f"Data: {data}. Retrying..." ) # Finally, preferred origins for the visited revisions are set (this step can be # reordered if required). origins = { sha1: self.cache["revision_origin"]["data"][sha1] for sha1 in self.cache["revision_origin"]["added"] } while not self.storage.revision_set_origin(origins): logging.warning( f"Unable to write preferred origins to the storage. " f"Data: {origins}. Retrying..." ) # clear local cache ############################################################ self.clear_caches() def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes ) -> None: self.cache["content_in_directory"].add( (blob.id, directory.id, normalize(os.path.join(prefix, blob.name))) ) def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes ) -> None: self.cache["content_in_revision"].add( (blob.id, revision.id, normalize(os.path.join(prefix, blob.name))) ) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: return self.storage.content_find_first(id) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: yield from self.storage.content_find_all(id, limit=limit) def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: return self.get_dates("content", [blob.id]).get(blob.id) def content_get_early_dates( self, blobs: Iterable[FileEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("content", [blob.id for blob in blobs]) def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: self.cache["content"]["data"][blob.id] = date self.cache["content"]["added"].add(blob.id) def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes ) -> None: self.cache["directory_in_revision"].add( (directory.id, revision.id, normalize(path)) ) def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: return self.get_dates("directory", [directory.id]).get(directory.id) def directory_get_dates_in_isochrone_frontier( self, dirs: Iterable[DirectoryEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("directory", [directory.id for directory in dirs]) def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime ) -> None: self.cache["directory"]["data"][directory.id] = date self.cache["directory"]["added"].add(directory.id) def get_dates( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: cache = self.cache[entity] missing_ids = set(id for id in ids if id not in cache) if missing_ids: if entity == "revision": updated = { id: rev.date for id, rev in self.storage.revision_get(missing_ids).items() - if rev.date is not None } else: updated = getattr(self.storage, f"{entity}_get")(missing_ids) cache["data"].update(updated) dates: Dict[Sha1Git, datetime] = {} for sha1 in ids: - date = cache["data"].get(sha1) + date = cache["data"].setdefault(sha1, None) if date is not None: dates[sha1] = date return dates def origin_add(self, origin: OriginEntry) -> None: self.cache["origin"]["data"][origin.id] = origin.url self.cache["origin"]["added"].add(origin.id) def revision_add(self, revision: RevisionEntry) -> None: 
self.cache["revision"]["data"][revision.id] = revision.date self.cache["revision"]["added"].add(revision.id) def revision_add_before_revision( self, head: RevisionEntry, revision: RevisionEntry ) -> None: self.cache["revision_before_revision"].setdefault(revision.id, set()).add( head.id ) def revision_add_to_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: self.cache["revision_in_origin"].add((revision.id, origin.id)) def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]: return self.get_dates("revision", [revision.id]).get(revision.id) def revision_get_preferred_origin( self, revision: RevisionEntry ) -> Optional[Sha1Git]: cache = self.cache["revision_origin"]["data"] if revision.id not in cache: ret = self.storage.revision_get([revision.id]) if revision.id in ret: origin = ret[revision.id].origin if origin is not None: cache[revision.id] = origin return cache.get(revision.id) def revision_in_history(self, revision: RevisionEntry) -> bool: return revision.id in self.cache["revision_before_revision"] or bool( self.storage.relation_get(RelationType.REV_BEFORE_REV, [revision.id]) ) def revision_set_preferred_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: self.cache["revision_origin"]["data"][revision.id] = origin.id self.cache["revision_origin"]["added"].add(revision.id) def revision_visited(self, revision: RevisionEntry) -> bool: return revision.id in dict(self.cache["revision_in_origin"]) or bool( self.storage.relation_get(RelationType.REV_IN_ORG, [revision.id]) ) def normalize(path: bytes) -> bytes: return path[2:] if path.startswith(bytes("." + os.path.sep, "utf-8")) else path diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py index 25a6b37..e9cb748 100644 --- a/swh/provenance/tests/conftest.py +++ b/swh/provenance/tests/conftest.py @@ -1,142 +1,153 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timedelta, timezone from os import path from typing import Any, Dict, Iterable, Iterator from _pytest.fixtures import SubRequest import msgpack import psycopg2.extensions +import pymongo.database import pytest from pytest_postgresql.factories import postgresql from swh.journal.serializers import msgpack_ext_hook from swh.provenance import get_provenance, get_provenance_storage from swh.provenance.api.client import RemoteProvenanceStorage import swh.provenance.api.server as server from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface from swh.provenance.storage.archive import ArchiveStorage from swh.storage.interface import StorageInterface from swh.storage.replay import process_replay_objects @pytest.fixture( params=[ "with-path", "without-path", "with-path-denormalized", "without-path-denormalized", ] ) -def populated_db( +def provenance_postgresqldb( request: SubRequest, postgresql: psycopg2.extensions.connection, ) -> Dict[str, str]: """return a working and initialized provenance db""" from swh.core.cli.db import populate_database_for_package populate_database_for_package( "swh.provenance", postgresql.dsn, flavor=request.param ) return postgresql.get_dsn_parameters() # the Flask app used as server in these tests @pytest.fixture -def app(populated_db: Dict[str, str]) -> 
Iterator[server.ProvenanceStorageServerApp]: +def app( + provenance_postgresqldb: Dict[str, str] +) -> Iterator[server.ProvenanceStorageServerApp]: assert hasattr(server, "storage") - server.storage = get_provenance_storage(cls="postgresql", db=populated_db) + server.storage = get_provenance_storage( + cls="postgresql", db=provenance_postgresqldb + ) yield server.app # the RPCClient class used as client used in these tests @pytest.fixture def swh_rpc_client_class() -> type: return RemoteProvenanceStorage -@pytest.fixture(params=["postgresql", "remote"]) +@pytest.fixture(params=["mongodb"]) def provenance_storage( request: SubRequest, - populated_db: Dict[str, str], + provenance_postgresqldb: Dict[str, str], + mongodb: pymongo.database.Database, swh_rpc_client: RemoteProvenanceStorage, ) -> ProvenanceStorageInterface: """Return a working and initialized ProvenanceStorageInterface object""" if request.param == "remote": assert isinstance(swh_rpc_client, ProvenanceStorageInterface) return swh_rpc_client + elif request.param == "mongodb": + from swh.provenance.mongo.backend import ProvenanceStorageMongoDb + + return ProvenanceStorageMongoDb(mongodb) + else: # in test sessions, we DO want to raise any exception occurring at commit time return get_provenance_storage( - cls=request.param, db=populated_db, raise_on_commit=True + cls=request.param, db=provenance_postgresqldb, raise_on_commit=True ) provenance_postgresql = postgresql("postgresql_proc", dbname="provenance_tests") @pytest.fixture def provenance( provenance_postgresql: psycopg2.extensions.connection, ) -> ProvenanceInterface: """Return a working and initialized ProvenanceInterface object""" from swh.core.cli.db import populate_database_for_package populate_database_for_package( "swh.provenance", provenance_postgresql.dsn, flavor="with-path" ) # in test sessions, we DO want to raise any exception occurring at commit time return get_provenance( cls="postgresql", db=provenance_postgresql.get_dsn_parameters(), raise_on_commit=True, ) @pytest.fixture def archive(swh_storage: StorageInterface) -> ArchiveInterface: """Return an ArchiveStorage-based ArchiveInterface object""" return ArchiveStorage(swh_storage) def get_datafile(fname: str) -> str: return path.join(path.dirname(__file__), "data", fname) def load_repo_data(repo: str) -> Dict[str, Any]: data: Dict[str, Any] = {} with open(get_datafile(f"{repo}.msgpack"), "rb") as fobj: unpacker = msgpack.Unpacker( fobj, raw=False, ext_hook=msgpack_ext_hook, strict_map_key=False, timestamp=3, # convert Timestamp in datetime objects (tz UTC) ) for objtype, objd in unpacker: data.setdefault(objtype, []).append(objd) return data def filter_dict(d: Dict[Any, Any], keys: Iterable[Any]) -> Dict[Any, Any]: return {k: v for (k, v) in d.items() if k in keys} def fill_storage(storage: StorageInterface, data: Dict[str, Any]) -> None: process_replay_objects(data, storage=storage) # TODO: remove this function in favour of TimestampWithTimezone.to_datetime # from swh.model.model def ts2dt(ts: Dict[str, Any]) -> datetime: timestamp = datetime.fromtimestamp( ts["timestamp"]["seconds"], timezone(timedelta(minutes=ts["offset"])) ) return timestamp.replace(microsecond=ts["timestamp"]["microseconds"]) diff --git a/swh/provenance/tests/data/mongo/.gitkeep b/swh/provenance/tests/data/mongo/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py index 6cc0a7c..32a8629 100644 --- 
a/swh/provenance/tests/test_provenance_storage.py +++ b/swh/provenance/tests/test_provenance_storage.py @@ -1,350 +1,349 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime import inspect import os from typing import Any, Dict, Iterable, Optional, Set import pytest from swh.model.hashutil import hash_to_bytes from swh.model.identifiers import origin_identifier from swh.model.model import Sha1Git from swh.provenance.interface import ( EntityType, ProvenanceInterface, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, ) from swh.provenance.tests.conftest import load_repo_data, ts2dt def relation_add_and_compare_result( relation: RelationType, data: Set[RelationData], refstorage: ProvenanceStorageInterface, storage: ProvenanceStorageInterface, with_path: bool = True, ) -> None: assert data assert refstorage.relation_add(relation, data) == storage.relation_add( relation, data ) assert relation_compare_result( refstorage.relation_get(relation, (reldata.src for reldata in data)), storage.relation_get(relation, (reldata.src for reldata in data)), with_path, ) assert relation_compare_result( refstorage.relation_get( relation, (reldata.dst for reldata in data), reverse=True, ), storage.relation_get( relation, (reldata.dst for reldata in data), reverse=True, ), with_path, ) assert relation_compare_result( refstorage.relation_get_all(relation), storage.relation_get_all(relation), with_path, ) def relation_compare_result( expected: Set[RelationData], computed: Set[RelationData], with_path: bool ) -> bool: return { RelationData(reldata.src, reldata.dst, reldata.path if with_path else None) for reldata in expected } == computed def dircontent( data: Dict[str, Any], ref: Sha1Git, dir: Dict[str, Any], prefix: bytes = b"", ) -> Iterable[RelationData]: content = { RelationData(entry["target"], ref, os.path.join(prefix, entry["name"])) for entry in dir["entries"] if entry["type"] == "file" } for entry in dir["entries"]: if entry["type"] == "dir": child = next( subdir for subdir in data["directory"] if subdir["id"] == entry["target"] ) content.update( dircontent(data, ref, child, os.path.join(prefix, entry["name"])) ) return content @pytest.mark.parametrize( "repo", ("cmdbts2", "out-of-order", "with-merges"), ) def test_provenance_storage( provenance: ProvenanceInterface, provenance_storage: ProvenanceStorageInterface, repo: str, ) -> None: """Tests every ProvenanceStorageInterface implementation against the one provided for provenance.storage.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data(repo) # Assuming provenance.storage has the 'with-path' flavor. assert provenance.storage.with_path() - # Test content methods. - # Add all content present in the current repo to both storages, just assigning their - # creation dates. Then check that the inserted content is the same in both cases. 
- cnt_dates = {cnt["sha1_git"]: cnt["ctime"] for cnt in data["content"]} - assert cnt_dates - assert provenance.storage.content_set_date( - cnt_dates - ) == provenance_storage.content_set_date(cnt_dates) - - assert provenance.storage.content_get(cnt_dates) == provenance_storage.content_get( - cnt_dates - ) - assert provenance.storage.entity_get_all( - EntityType.CONTENT - ) == provenance_storage.entity_get_all(EntityType.CONTENT) - - # Test directory methods. - # Of all directories present in the current repo, only assign a date to those - # containing blobs (picking the max date among the available ones). Then check that - # the inserted data is the same in both storages. - def getmaxdate( - dir: Dict[str, Any], cnt_dates: Dict[Sha1Git, datetime] - ) -> Optional[datetime]: - dates = [ - cnt_dates[entry["target"]] - for entry in dir["entries"] - if entry["type"] == "file" - ] - return max(dates) if dates else None - - dir_dates = {dir["id"]: getmaxdate(dir, cnt_dates) for dir in data["directory"]} - assert dir_dates - assert provenance.storage.directory_set_date( - {sha1: date for sha1, date in dir_dates.items() if date is not None} - ) == provenance_storage.directory_set_date( - {sha1: date for sha1, date in dir_dates.items() if date is not None} - ) - - assert provenance.storage.directory_get( - dir_dates - ) == provenance_storage.directory_get(dir_dates) - assert provenance.storage.entity_get_all( - EntityType.DIRECTORY - ) == provenance_storage.entity_get_all(EntityType.DIRECTORY) - # Test origin methods. # Add all origins present in the current repo to both storages. Then check that the # inserted data is the same in both cases. org_urls = { hash_to_bytes(origin_identifier(org)): org["url"] for org in data["origin"] } assert org_urls assert provenance.storage.origin_set_url( org_urls ) == provenance_storage.origin_set_url(org_urls) assert provenance.storage.origin_get(org_urls) == provenance_storage.origin_get( org_urls ) assert provenance.storage.entity_get_all( EntityType.ORIGIN ) == provenance_storage.entity_get_all(EntityType.ORIGIN) - # Test revision methods. - # Add all revisions present in the current repo to both storages, assigning their - # dataes and an arbitrary origin to each one. Then check that the inserted data is - # the same in both cases. - rev_dates = {rev["id"]: ts2dt(rev["date"]) for rev in data["revision"]} - assert rev_dates - assert provenance.storage.revision_set_date( - rev_dates - ) == provenance_storage.revision_set_date(rev_dates) - - rev_origins = { - rev["id"]: next(iter(org_urls)) # any arbitrary origin will do - for rev in data["revision"] - } - assert rev_origins - assert provenance.storage.revision_set_origin( - rev_origins - ) == provenance_storage.revision_set_origin(rev_origins) - - assert provenance.storage.revision_get( - rev_dates - ) == provenance_storage.revision_get(rev_dates) - assert provenance.storage.entity_get_all( - EntityType.REVISION - ) == provenance_storage.entity_get_all(EntityType.REVISION) - # Test content-in-revision relation. # Create flat models of every root directory for the revisions in the dataset. cnt_in_rev: Set[RelationData] = set() for rev in data["revision"]: root = next( subdir for subdir in data["directory"] if subdir["id"] == rev["directory"] ) cnt_in_rev.update(dircontent(data, rev["id"], root)) relation_add_and_compare_result( RelationType.CNT_EARLY_IN_REV, cnt_in_rev, provenance.storage, provenance_storage, provenance_storage.with_path(), ) # Test content-in-directory relation. 
# Create flat models for every directory in the dataset. cnt_in_dir: Set[RelationData] = set() for dir in data["directory"]: cnt_in_dir.update(dircontent(data, dir["id"], dir)) relation_add_and_compare_result( RelationType.CNT_IN_DIR, cnt_in_dir, provenance.storage, provenance_storage, provenance_storage.with_path(), ) # Test directory-in-revision relation. # Add root directories to their corresponding revision in the dataset. dir_in_rev = { RelationData(rev["directory"], rev["id"], b".") for rev in data["revision"] } relation_add_and_compare_result( RelationType.DIR_IN_REV, dir_in_rev, provenance.storage, provenance_storage, provenance_storage.with_path(), ) # Test revision-in-origin relation. # Add all revisions that are head of some snapshot branch to the corresponding # origin. rev_in_org = { RelationData( branch["target"], hash_to_bytes(origin_identifier({"url": status["origin"]})), None, ) for status in data["origin_visit_status"] if status["snapshot"] is not None for snapshot in data["snapshot"] if snapshot["id"] == status["snapshot"] for _, branch in snapshot["branches"].items() if branch["target_type"] == "revision" } relation_add_and_compare_result( RelationType.REV_IN_ORG, rev_in_org, provenance.storage, provenance_storage, ) # Test revision-before-revision relation. # For each revision in the dataset, add an entry for each parent to the relation. rev_before_rev = { RelationData(parent, rev["id"], None) for rev in data["revision"] for parent in rev["parents"] } relation_add_and_compare_result( RelationType.REV_BEFORE_REV, rev_before_rev, provenance.storage, provenance_storage, ) + # Test content methods. + # Add all content present in the current repo to both storages, just assigning their + # creation dates. Then check that the inserted content is the same in both cases. + cnt_dates = {cnt["sha1_git"]: cnt["ctime"] for cnt in data["content"]} + assert cnt_dates + assert provenance.storage.content_set_date( + cnt_dates + ) == provenance_storage.content_set_date(cnt_dates) + + assert provenance.storage.content_get(cnt_dates) == provenance_storage.content_get( + cnt_dates + ) + assert provenance.storage.entity_get_all( + EntityType.CONTENT + ) == provenance_storage.entity_get_all(EntityType.CONTENT) + + # Test directory methods. + # Of all directories present in the current repo, only assign a date to those + # containing blobs (picking the max date among the available ones). Then check that + # the inserted data is the same in both storages. + def getmaxdate( + dir: Dict[str, Any], cnt_dates: Dict[Sha1Git, datetime] + ) -> Optional[datetime]: + dates = [ + cnt_dates[entry["target"]] + for entry in dir["entries"] + if entry["type"] == "file" + ] + return max(dates) if dates else None + + dir_dates = {dir["id"]: getmaxdate(dir, cnt_dates) for dir in data["directory"]} + assert dir_dates + assert provenance.storage.directory_set_date( + {sha1: date for sha1, date in dir_dates.items() if date is not None} + ) == provenance_storage.directory_set_date( + {sha1: date for sha1, date in dir_dates.items() if date is not None} + ) + assert provenance.storage.directory_get( + dir_dates + ) == provenance_storage.directory_get(dir_dates) + assert provenance.storage.entity_get_all( + EntityType.DIRECTORY + ) == provenance_storage.entity_get_all(EntityType.DIRECTORY) + + # Test revision methods. + # Add all revisions present in the current repo to both storages, assigning their + # dates and an arbitrary origin to each one. Then check that the inserted data is + # the same in both cases.
+ rev_dates = {rev["id"]: ts2dt(rev["date"]) for rev in data["revision"]} + assert rev_dates + assert provenance.storage.revision_set_date( + rev_dates + ) == provenance_storage.revision_set_date(rev_dates) + + rev_origins = { + rev["id"]: next(iter(org_urls)) # any arbitrary origin will do + for rev in data["revision"] + } + assert rev_origins + assert provenance.storage.revision_set_origin( + rev_origins + ) == provenance_storage.revision_set_origin(rev_origins) + + assert provenance.storage.revision_get( + rev_dates + ) == provenance_storage.revision_get(rev_dates) + assert provenance.storage.entity_get_all( + EntityType.REVISION + ) == provenance_storage.entity_get_all(EntityType.REVISION) + # Test location_get. if provenance_storage.with_path(): assert provenance.storage.location_get() == provenance_storage.location_get() # Test content_find_first and content_find_all. def adapt_result( result: Optional[ProvenanceResult], with_path: bool ) -> Optional[ProvenanceResult]: if result is not None: return ProvenanceResult( result.content, result.revision, result.date, result.origin, result.path if with_path else b"", ) return result for cnt in cnt_dates: assert adapt_result( provenance.storage.content_find_first(cnt), provenance_storage.with_path() ) == provenance_storage.content_find_first(cnt) assert { adapt_result(occur, provenance_storage.with_path()) for occur in provenance.storage.content_find_all(cnt) } == set(provenance_storage.content_find_all(cnt)) def test_types(provenance: ProvenanceInterface) -> None: """Checks all methods of ProvenanceStorageInterface are implemented by this backend, and that they have the same signature.""" # Create an instance of the protocol (which cannot be instantiated # directly, so this creates a subclass, then instantiates it) interface = type("_", (ProvenanceStorageInterface,), {})() assert "content_find_first" in dir(interface) missing_methods = [] for meth_name in dir(interface): if meth_name.startswith("_"): continue interface_meth = getattr(interface, meth_name) try: concrete_meth = getattr(provenance.storage, meth_name) except AttributeError: if not getattr(interface_meth, "deprecated_endpoint", False): # The backend is missing a (non-deprecated) endpoint missing_methods.append(meth_name) continue expected_signature = inspect.signature(interface_meth) actual_signature = inspect.signature(concrete_meth) assert expected_signature == actual_signature, meth_name assert missing_methods == [] # If all the assertions above succeed, then this one should too. # But there's no harm in double-checking. # And we could replace the assertions above by this one, but unlike # the assertions above, it doesn't explain what is missing. assert isinstance(provenance.storage, ProvenanceStorageInterface)
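Editor's note: the standalone sketches below are not part of the patch; they only illustrate the test helpers touched by this hunk. First, how the dircontent helper flattens a nested directory into content-in-revision RelationData entries. The toy data dictionary and all sha1 placeholders (b"root", b"cnt-readme", ...) are invented; the import assumes the test module shown in this diff is importable in a development checkout.

# Illustrative only: exercise dircontent on a tiny hand-made dataset.
from swh.provenance.interface import RelationData
from swh.provenance.tests.test_provenance_storage import dircontent

data = {
    "directory": [
        {
            "id": b"root",
            "entries": [
                {"type": "file", "name": b"README", "target": b"cnt-readme"},
                {"type": "dir", "name": b"src", "target": b"subdir"},
            ],
        },
        {
            "id": b"subdir",
            "entries": [
                {"type": "file", "name": b"main.py", "target": b"cnt-main"},
            ],
        },
    ],
}

root = data["directory"][0]
flat = dircontent(data, b"rev", root)
# dircontent recurses into the "src" entry, prefixing paths as it descends:
assert flat == {
    RelationData(b"cnt-readme", b"rev", b"README"),
    RelationData(b"cnt-main", b"rev", b"src/main.py"),
}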
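Second, a minimal sketch of the normalization done by adapt_result in test_provenance_storage: when the storage under test has with_path() == False, the reference result's path is blanked before the two answers are compared. The field values below are invented; the sketch only assumes ProvenanceResult is constructed positionally (content, revision, date, origin, path) and exposes those attributes, as in the test above.

# Illustrative values only; mirrors the adapt_result helper defined in the test.
from datetime import datetime, timezone

from swh.provenance.interface import ProvenanceResult

# Answer produced by the reference (with-path) storage for some content blob.
reference = ProvenanceResult(
    b"cnt-sha1",                                # content
    b"rev-sha1",                                # revision
    datetime(2021, 1, 1, tzinfo=timezone.utc),  # date
    "https://example.org/repo.git",             # origin (invented URL)
    b"src/main.py",                             # path
)

# For a path-less backend, adapt_result keeps every field but blanks the path,
# so the reference answer can still be compared against that backend's answer.
adapted = ProvenanceResult(
    reference.content, reference.revision, reference.date, reference.origin, b""
)
assert adapted.path == b""
assert (adapted.content, adapted.revision, adapted.date, adapted.origin) == (
    reference.content, reference.revision, reference.date, reference.origin,
)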
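Finally, the signature-equality trick used by test_types, reduced to a self-contained sketch with throwaway classes (FakeInterface/FakeBackend are invented names): instantiating an anonymous subclass makes the interface's methods inspectable, and comparing inspect.signature per public method catches any drift between interface and backend.

# Illustrative only: same pattern as test_types, using stand-in classes.
import inspect


class FakeInterface:  # stand-in for ProvenanceStorageInterface
    def content_find_first(self, id: bytes) -> None:
        ...


class FakeBackend:  # stand-in for a concrete storage backend
    def content_find_first(self, id: bytes) -> None:
        ...


# Build a throwaway instance of the interface, then compare each public
# method's signature against the backend's.
interface = type("_", (FakeInterface,), {})()
backend = FakeBackend()
for name in dir(interface):
    if name.startswith("_"):
        continue
    assert inspect.signature(getattr(interface, name)) == inspect.signature(
        getattr(backend, name)
    ), name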