diff --git a/mypy.ini b/mypy.ini index 53b0ffb..09ce722 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,39 +1,36 @@ [mypy] namespace_packages = True warn_unused_ignores = True exclude = swh/provenance/tools/ # 3rd party libraries without stubs (yet) [mypy-bson.*] ignore_missing_imports = True [mypy-confluent_kafka.*] ignore_missing_imports = True [mypy-iso8601.*] ignore_missing_imports = True [mypy-methodtools.*] ignore_missing_imports = True [mypy-msgpack.*] ignore_missing_imports = True [mypy-pika.*] ignore_missing_imports = True [mypy-pkg_resources.*] ignore_missing_imports = True -[mypy-pymongo.*] -ignore_missing_imports = True - [mypy-pytest.*] ignore_missing_imports = True [mypy-pytest_postgresql.*] ignore_missing_imports = True [mypy-psycopg2.*] ignore_missing_imports = True diff --git a/pytest.ini b/pytest.ini index 684c551..d8fe211 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,8 +1,4 @@ [pytest] norecursedirs = docs .* -mongodb_fixture_dir = swh/provenance/tests/data/mongo -mongodb_engine = mongomock -mongodb_dbname = test - postgresql_postgres_options = -N 500 diff --git a/requirements-test.txt b/requirements-test.txt index 9d2d915..c23c4cc 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,8 +1,7 @@ pytest -pytest-mongodb pytest-rabbitmq swh.loader.git >= 0.8 swh.journal >= 0.8 swh.storage >= 0.40 swh.graph >= 0.3.2 types-Deprecated diff --git a/requirements.txt b/requirements.txt index 6b70f34..6cde9d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,13 @@ # Add here external Python modules dependencies, one per line. Module names # should match https://pypi.python.org/pypi names. For the full spec or # dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html click deprecated iso8601 methodtools -mongomock pika -pymongo PyYAML types-click types-PyYAML zmq diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py index c13d047..e201048 100644 --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -1,125 +1,119 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from typing import TYPE_CHECKING import warnings if TYPE_CHECKING: from .archive import ArchiveInterface from .interface import ProvenanceInterface, ProvenanceStorageInterface def get_archive(cls: str, **kwargs) -> ArchiveInterface: """Get an archive object of class ``cls`` with arguments ``args``. Args: cls: archive's class, either 'api', 'direct' or 'graph' args: dictionary of arguments passed to the archive class constructor Returns: an instance of archive object (either using swh.storage API or direct queries to the archive's database) Raises: :cls:`ValueError` if passed an unknown archive class. 
""" if cls == "api": from swh.storage import get_storage from .storage.archive import ArchiveStorage return ArchiveStorage(get_storage(**kwargs["storage"])) elif cls == "direct": from swh.core.db import BaseDb from .postgresql.archive import ArchivePostgreSQL return ArchivePostgreSQL(BaseDb.connect(**kwargs["db"]).conn) elif cls == "graph": try: from swh.graph.client import RemoteGraphClient from swh.storage import get_storage from .swhgraph.archive import ArchiveGraph graph = RemoteGraphClient(kwargs.get("url")) return ArchiveGraph(graph, get_storage(**kwargs["storage"])) except ModuleNotFoundError: raise EnvironmentError( "Graph configuration required but module is not installed." ) elif cls == "multiplexer": from .multiplexer.archive import ArchiveMultiplexed archives = list((get_archive(**archive) for archive in kwargs["archives"])) return ArchiveMultiplexed(archives) else: raise ValueError def get_provenance(**kwargs) -> ProvenanceInterface: """Get an provenance object with arguments ``args``. Args: args: dictionary of arguments to retrieve a swh.provenance.storage class (see :func:`get_provenance_storage` for details) Returns: an instance of provenance object """ from .provenance import Provenance return Provenance(get_provenance_storage(**kwargs)) def get_provenance_storage(cls: str, **kwargs) -> ProvenanceStorageInterface: """Get an archive object of class ``cls`` with arguments ``args``. Args: cls: storage's class, only 'local' is currently supported args: dictionary of arguments passed to the storage class constructor Returns: an instance of storage object Raises: :cls:`ValueError` if passed an unknown archive class. """ if cls in ["local", "postgresql"]: from .postgresql.provenance import ProvenanceStoragePostgreSql if cls == "local": warnings.warn( '"local" class is deprecated for provenance storage, please ' 'use "postgresql" class instead.', DeprecationWarning, ) raise_on_commit = kwargs.get("raise_on_commit", False) return ProvenanceStoragePostgreSql( raise_on_commit=raise_on_commit, **kwargs["db"] ) - elif cls == "mongodb": - from .mongo.backend import ProvenanceStorageMongoDb - - engine = kwargs.get("engine", "pymongo") - return ProvenanceStorageMongoDb(engine=engine, **kwargs["db"]) - elif cls == "rabbitmq": from .api.client import ProvenanceStorageRabbitMQClient rmq_storage = ProvenanceStorageRabbitMQClient(**kwargs) if TYPE_CHECKING: assert isinstance(rmq_storage, ProvenanceStorageInterface) return rmq_storage raise ValueError diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py index a9a5118..5a276bf 100644 --- a/swh/provenance/cli.py +++ b/swh/provenance/cli.py @@ -1,607 +1,602 @@ # Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # WARNING: do not import unnecessary things here to keep cli startup time under # control from datetime import datetime, timezone from functools import partial import os from typing import Any, Dict, Generator, Optional, Tuple import click from deprecated import deprecated import iso8601 import yaml from swh.core import config from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group from swh.model.hashutil import hash_to_bytes, hash_to_hex from swh.model.model import Sha1Git # All generic config code should reside in swh.core.config CONFIG_ENVVAR = "SWH_CONFIG_FILENAME" DEFAULT_PATH = 
os.environ.get(CONFIG_ENVVAR, None) DEFAULT_CONFIG: Dict[str, Any] = { "provenance": { "archive": { # Storage API based Archive object # "cls": "api", # "storage": { # "cls": "remote", # "url": "http://uffizi.internal.softwareheritage.org:5002", # } # Direct access Archive object "cls": "direct", "db": { "host": "belvedere.internal.softwareheritage.org", "port": 5432, "dbname": "softwareheritage", "user": "guest", }, }, "storage": { # Local PostgreSQL Storage # "cls": "postgresql", # "db": { # "host": "localhost", # "user": "postgres", # "password": "postgres", # "dbname": "provenance", # }, - # Local MongoDB Storage - # "cls": "mongodb", - # "db": { - # "dbname": "provenance", - # }, # Remote RabbitMQ/PostgreSQL Storage "cls": "rabbitmq", "url": "amqp://localhost:5672/%2f", "storage_config": { "cls": "postgresql", "db": { "host": "localhost", "user": "postgres", "password": "postgres", "dbname": "provenance", }, }, "batch_size": 100, "prefetch_count": 100, }, } } CONFIG_FILE_HELP = f""" \b Configuration can be loaded from a yaml file given either as --config-file option or the {CONFIG_ENVVAR} environment variable. If no configuration file is specified, use the following default configuration:: \b {yaml.dump(DEFAULT_CONFIG)}""" PROVENANCE_HELP = f"""Software Heritage provenance index database tools {CONFIG_FILE_HELP} """ @swh_cli_group.group( name="provenance", context_settings=CONTEXT_SETTINGS, help=PROVENANCE_HELP ) @click.option( "-C", "--config-file", default=None, type=click.Path(exists=True, dir_okay=False, path_type=str), help="""YAML configuration file.""", ) @click.option( "-P", "--profile", default=None, type=click.Path(exists=False, dir_okay=False, path_type=str), help="""Enable profiling to specified file.""", ) @click.pass_context def cli(ctx: click.core.Context, config_file: Optional[str], profile: str) -> None: if ( config_file is None and DEFAULT_PATH is not None and config.config_exists(DEFAULT_PATH) ): config_file = DEFAULT_PATH if config_file is None: conf = DEFAULT_CONFIG else: # read_raw_config do not fail on ENOENT if not os.path.exists(config_file): raise FileNotFoundError(config_file) conf = yaml.safe_load(open(config_file, "rb")) ctx.ensure_object(dict) ctx.obj["config"] = conf if profile: import atexit import cProfile print("Profiling...") pr = cProfile.Profile() pr.enable() def exit() -> None: pr.disable() pr.dump_stats(profile) atexit.register(exit) @cli.group(name="origin") @click.pass_context def origin(ctx: click.core.Context): from . 
import get_archive, get_provenance archive = get_archive(**ctx.obj["config"]["provenance"]["archive"]) provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) ctx.obj["provenance"] = provenance ctx.obj["archive"] = archive @origin.command(name="from-csv") @click.argument("filename", type=click.Path(exists=True)) @click.option( "-l", "--limit", type=int, help="""Limit the amount of entries (origins) to read from the input file.""", ) @click.pass_context def origin_from_csv(ctx: click.core.Context, filename: str, limit: Optional[int]): from .origin import CSVOriginIterator, origin_add provenance = ctx.obj["provenance"] archive = ctx.obj["archive"] origins_provider = generate_origin_tuples(filename) origins = CSVOriginIterator(origins_provider, limit=limit) with provenance: for origin in origins: origin_add(provenance, archive, [origin]) @origin.command(name="from-journal") @click.pass_context def origin_from_journal(ctx: click.core.Context): from swh.journal.client import get_journal_client from .journal_client import process_journal_origins provenance = ctx.obj["provenance"] archive = ctx.obj["archive"] journal_cfg = ctx.obj["config"].get("journal_client", {}) worker_fn = partial( process_journal_origins, archive=archive, provenance=provenance, ) cls = journal_cfg.pop("cls", None) or "kafka" client = get_journal_client( cls, **{ **journal_cfg, "object_types": ["origin_visit_status"], }, ) try: client.process(worker_fn) except KeyboardInterrupt: ctx.exit(0) else: print("Done.") finally: client.close() @cli.group(name="revision") @click.pass_context def revision(ctx: click.core.Context): from . import get_archive, get_provenance archive = get_archive(**ctx.obj["config"]["provenance"]["archive"]) provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) ctx.obj["provenance"] = provenance ctx.obj["archive"] = archive @revision.command(name="from-csv") @click.argument("filename", type=click.Path(exists=True)) @click.option( "-a", "--track-all", default=True, type=bool, help="""Index all occurrences of files in the development history.""", ) @click.option( "-f", "--flatten", default=True, type=bool, help="""Create flat models for directories in the isochrone frontier.""", ) @click.option( "-l", "--limit", type=int, help="""Limit the amount of entries (revisions) to read from the input file.""", ) @click.option( "-m", "--min-depth", default=1, type=int, help="""Set minimum depth (in the directory tree) at which an isochrone """ """frontier can be defined.""", ) @click.option( "-r", "--reuse", default=True, type=bool, help="""Prioritize the usage of previously defined isochrone frontiers """ """whenever possible.""", ) @click.option( "-s", "--min-size", default=0, type=int, help="""Set the minimum size (in bytes) of files to be indexed. 
""" """Any smaller file will be ignored.""", ) @click.pass_context def revision_from_csv( ctx: click.core.Context, filename: str, track_all: bool, flatten: bool, limit: Optional[int], min_depth: int, reuse: bool, min_size: int, ) -> None: from .revision import CSVRevisionIterator, revision_add provenance = ctx.obj["provenance"] archive = ctx.obj["archive"] revisions_provider = generate_revision_tuples(filename) revisions = CSVRevisionIterator(revisions_provider, limit=limit) with provenance: for revision in revisions: revision_add( provenance, archive, [revision], trackall=track_all, flatten=flatten, lower=reuse, mindepth=min_depth, minsize=min_size, ) @revision.command(name="from-journal") @click.option( "-a", "--track-all", default=True, type=bool, help="""Index all occurrences of files in the development history.""", ) @click.option( "-f", "--flatten", default=True, type=bool, help="""Create flat models for directories in the isochrone frontier.""", ) @click.option( "-l", "--limit", type=int, help="""Limit the amount of entries (revisions) to read from the input file.""", ) @click.option( "-m", "--min-depth", default=1, type=int, help="""Set minimum depth (in the directory tree) at which an isochrone """ """frontier can be defined.""", ) @click.option( "-r", "--reuse", default=True, type=bool, help="""Prioritize the usage of previously defined isochrone frontiers """ """whenever possible.""", ) @click.option( "-s", "--min-size", default=0, type=int, help="""Set the minimum size (in bytes) of files to be indexed. """ """Any smaller file will be ignored.""", ) @click.pass_context def revision_from_journal( ctx: click.core.Context, track_all: bool, flatten: bool, limit: Optional[int], min_depth: int, reuse: bool, min_size: int, ) -> None: from swh.journal.client import get_journal_client from .journal_client import process_journal_revisions provenance = ctx.obj["provenance"] archive = ctx.obj["archive"] journal_cfg = ctx.obj["config"].get("journal_client", {}) worker_fn = partial( process_journal_revisions, archive=archive, provenance=provenance, ) cls = journal_cfg.pop("cls", None) or "kafka" client = get_journal_client( cls, **{ **journal_cfg, "object_types": ["revision"], }, ) try: client.process(worker_fn) except KeyboardInterrupt: ctx.exit(0) else: print("Done.") finally: client.close() @cli.command(name="iter-frontiers") @click.argument("filename") @click.option( "-l", "--limit", type=int, help="""Limit the amount of entries (directories) to read from the input file.""", ) @click.option( "-s", "--min-size", default=0, type=int, help="""Set the minimum size (in bytes) of files to be indexed. """ """Any smaller file will be ignored.""", ) @click.pass_context def iter_frontiers( ctx: click.core.Context, filename: str, limit: Optional[int], min_size: int, ) -> None: """Process a provided list of directories in the isochrone frontier.""" from . 
import get_archive, get_provenance from .directory import CSVDirectoryIterator, directory_add archive = get_archive(**ctx.obj["config"]["provenance"]["archive"]) directories_provider = generate_directory_ids(filename) directories = CSVDirectoryIterator(directories_provider, limit=limit) with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance: for directory in directories: directory_add( provenance, archive, [directory], minsize=min_size, ) def generate_directory_ids( filename: str, ) -> Generator[Sha1Git, None, None]: for line in open(filename, "r"): if line.strip(): yield hash_to_bytes(line.strip()) @cli.command(name="iter-revisions") @click.argument("filename") @click.option( "-a", "--track-all", default=True, type=bool, help="""Index all occurrences of files in the development history.""", ) @click.option( "-f", "--flatten", default=True, type=bool, help="""Create flat models for directories in the isochrone frontier.""", ) @click.option( "-l", "--limit", type=int, help="""Limit the amount of entries (revisions) to read from the input file.""", ) @click.option( "-m", "--min-depth", default=1, type=int, help="""Set minimum depth (in the directory tree) at which an isochrone """ """frontier can be defined.""", ) @click.option( "-r", "--reuse", default=True, type=bool, help="""Prioritize the usage of previously defined isochrone frontiers """ """whenever possible.""", ) @click.option( "-s", "--min-size", default=0, type=int, help="""Set the minimum size (in bytes) of files to be indexed. """ """Any smaller file will be ignored.""", ) @click.pass_context def iter_revisions( ctx: click.core.Context, filename: str, track_all: bool, flatten: bool, limit: Optional[int], min_depth: int, reuse: bool, min_size: int, ) -> None: """Process a provided list of revisions.""" from . import get_archive, get_provenance from .revision import CSVRevisionIterator, revision_add archive = get_archive(**ctx.obj["config"]["provenance"]["archive"]) revisions_provider = generate_revision_tuples(filename) revisions = CSVRevisionIterator(revisions_provider, limit=limit) with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance: for revision in revisions: revision_add( provenance, archive, [revision], trackall=track_all, flatten=flatten, lower=reuse, mindepth=min_depth, minsize=min_size, ) def generate_revision_tuples( filename: str, ) -> Generator[Tuple[Sha1Git, datetime, Sha1Git], None, None]: for line in open(filename, "r"): if line.strip(): revision, date, root = line.strip().split(",") yield ( hash_to_bytes(revision), iso8601.parse_date(date, default_timezone=timezone.utc), hash_to_bytes(root), ) @cli.command(name="iter-origins") @click.argument("filename") @click.option( "-l", "--limit", type=int, help="""Limit the amount of entries (origins) to read from the input file.""", ) @click.pass_context @deprecated(version="0.0.1", reason="Use `swh provenance origin from-csv` instead") def iter_origins(ctx: click.core.Context, filename: str, limit: Optional[int]) -> None: """Process a provided list of origins.""" from . 
import get_archive, get_provenance from .origin import CSVOriginIterator, origin_add archive = get_archive(**ctx.obj["config"]["provenance"]["archive"]) origins_provider = generate_origin_tuples(filename) origins = CSVOriginIterator(origins_provider, limit=limit) with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance: for origin in origins: origin_add(provenance, archive, [origin]) def generate_origin_tuples(filename: str) -> Generator[Tuple[str, bytes], None, None]: for line in open(filename, "r"): if line.strip(): url, snapshot = line.strip().split(",") yield (url, hash_to_bytes(snapshot)) @cli.command(name="find-first") @click.argument("swhid") @click.pass_context def find_first(ctx: click.core.Context, swhid: str) -> None: """Find first occurrence of the requested blob.""" from . import get_provenance with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance: occur = provenance.content_find_first(hash_to_bytes(swhid)) if occur is not None: print( f"swh:1:cnt:{hash_to_hex(occur.content)}, " f"swh:1:rev:{hash_to_hex(occur.revision)}, " f"{occur.date}, " f"{occur.origin}, " f"{os.fsdecode(occur.path)}" ) else: print(f"Cannot find a content with the id {swhid}") @cli.command(name="find-all") @click.argument("swhid") @click.option( "-l", "--limit", type=int, help="""Limit the amount results to be retrieved.""" ) @click.pass_context def find_all(ctx: click.core.Context, swhid: str, limit: Optional[int]) -> None: """Find all occurrences of the requested blob.""" from . import get_provenance with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance: for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit): print( f"swh:1:cnt:{hash_to_hex(occur.content)}, " f"swh:1:rev:{hash_to_hex(occur.revision)}, " f"{occur.date}, " f"{occur.origin}, " f"{os.fsdecode(occur.path)}" ) diff --git a/swh/provenance/mongo/README.md b/swh/provenance/mongo/README.md deleted file mode 100644 index b8e393e..0000000 --- a/swh/provenance/mongo/README.md +++ /dev/null @@ -1,44 +0,0 @@ -mongo backend -============= - -Provenance storage implementation using MongoDB - -initial data-model ------------------- - -```json -content -{ - id: sha1 - ts: int // optional - revision: {: []} - directory: {: []} -} - -directory -{ - id: sha1 - ts: int // optional - revision: {: []} -} - -revision -{ - id: sha1 - ts: int // optional - preferred // optional - origin [] - revision [] -} - -origin -{ - id: sha1 - url: str -} - -path -{ - path: str -} -``` diff --git a/swh/provenance/mongo/__init__.py b/swh/provenance/mongo/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py deleted file mode 100644 index a08a8c7..0000000 --- a/swh/provenance/mongo/backend.py +++ /dev/null @@ -1,529 +0,0 @@ -# Copyright (C) 2021-2022 The Software Heritage developers -# See the AUTHORS file at the top-level directory of this distribution -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -from __future__ import annotations - -from datetime import datetime, timezone -import os -from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Generator, - Iterable, - List, - Optional, - Set, - Type, - Union, -) - -from bson import ObjectId - -from swh.core.statsd import statsd -from swh.model.model import Sha1Git - -from ..interface import ( - DirectoryData, - EntityType, - ProvenanceResult, - 
ProvenanceStorageInterface, - RelationData, - RelationType, - RevisionData, -) - -STORAGE_DURATION_METRIC = "swh_provenance_storage_mongodb_duration_seconds" - -if TYPE_CHECKING: - from pymongo.database import Database - - -class ProvenanceStorageMongoDb: - def __init__(self, engine: str, **kwargs): - self.engine = engine - self.dbname = kwargs.pop("dbname") - self.conn_args = kwargs - - def __enter__(self) -> ProvenanceStorageInterface: - self.open() - return self - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.close() - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) - def close(self) -> None: - self.db.client.close() - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) - def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool: - existing = { - x["sha1"]: x - for x in self.db.content.find( - {"sha1": {"$in": list(cnts)}}, {"sha1": 1, "ts": 1, "_id": 1} - ) - } - for sha1, date in cnts.items(): - ts = datetime.timestamp(date) - if sha1 in existing: - cnt = existing[sha1] - if ts < cnt["ts"]: - self.db.content.update_one( - {"_id": cnt["_id"]}, {"$set": {"ts": ts}} - ) - else: - self.db.content.insert_one( - { - "sha1": sha1, - "ts": ts, - "revision": {}, - "directory": {}, - } - ) - return True - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) - def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: - # get all the revisions - # iterate and find the earliest - content = self.db.content.find_one({"sha1": id}) - if not content: - return None - - occurs = [] - for revision in self.db.revision.find( - {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} - ): - if revision["preferred"] is not None: - origin = self.db.origin.find_one({"sha1": revision["preferred"]}) - else: - origin = {"url": None} - - assert origin is not None - - for path in content["revision"][str(revision["_id"])]: - occurs.append( - ProvenanceResult( - content=id, - revision=revision["sha1"], - date=datetime.fromtimestamp(revision["ts"], timezone.utc), - origin=origin["url"], - path=path, - ) - ) - return sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path))[0] - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"}) - def content_find_all( - self, id: Sha1Git, limit: Optional[int] = None - ) -> Generator[ProvenanceResult, None, None]: - content = self.db.content.find_one({"sha1": id}) - if not content: - return None - - occurs = [] - for revision in self.db.revision.find( - {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} - ): - if revision["preferred"] is not None: - origin = self.db.origin.find_one({"sha1": revision["preferred"]}) - else: - origin = {"url": None} - - assert origin is not None - - for path in content["revision"][str(revision["_id"])]: - occurs.append( - ProvenanceResult( - content=id, - revision=revision["sha1"], - date=datetime.fromtimestamp(revision["ts"], timezone.utc), - origin=origin["url"], - path=path, - ) - ) - for directory in self.db.directory.find( - {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["directory"]]}} - ): - for revision in self.db.revision.find( - {"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}} - ): - if revision["preferred"] is not None: - origin = self.db.origin.find_one({"sha1": revision["preferred"]}) - else: - origin = {"url": None} - 
- assert origin is not None - - for suffix in content["directory"][str(directory["_id"])]: - for prefix in directory["revision"][str(revision["_id"])]: - path = ( - os.path.join(prefix, suffix) - if prefix not in [b".", b""] - else suffix - ) - occurs.append( - ProvenanceResult( - content=id, - revision=revision["sha1"], - date=datetime.fromtimestamp( - revision["ts"], timezone.utc - ), - origin=origin["url"], - path=path, - ) - ) - yield from sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path)) - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_get"}) - def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: - return { - x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) - for x in self.db.content.find( - {"sha1": {"$in": list(ids)}}, {"sha1": 1, "ts": 1, "_id": 0} - ) - } - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) - def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool: - existing = { - x["sha1"]: x - for x in self.db.directory.find( - {"sha1": {"$in": list(dirs)}}, {"sha1": 1, "ts": 1, "flat": 1, "_id": 1} - ) - } - for sha1, info in dirs.items(): - ts = datetime.timestamp(info.date) - if sha1 in existing: - dir = existing[sha1] - if ts >= dir["ts"]: - ts = dir["ts"] - flat = info.flat or dir["flat"] - if ts != dir["ts"] or flat != dir["flat"]: - self.db.directory.update_one( - {"_id": dir["_id"]}, {"$set": {"ts": ts, "flat": flat}} - ) - else: - self.db.directory.insert_one( - {"sha1": sha1, "ts": ts, "revision": {}, "flat": info.flat} - ) - return True - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) - def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]: - return { - x["sha1"]: DirectoryData( - date=datetime.fromtimestamp(x["ts"], timezone.utc), flat=x["flat"] - ) - for x in self.db.directory.find( - {"sha1": {"$in": list(ids)}}, {"sha1": 1, "ts": 1, "flat": 1, "_id": 0} - ) - } - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) - def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: - return { - x["sha1"] - for x in self.db.get_collection(entity.value).find( - {}, {"sha1": 1, "_id": 0} - ) - } - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) - def location_add(self, paths: Iterable[bytes]) -> bool: - # TODO: implement this methods if path are to be stored in a separate collection - return True - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) - def location_get_all(self) -> Set[bytes]: - contents = self.db.content.find({}, {"revision": 1, "_id": 0, "directory": 1}) - paths: List[Iterable[bytes]] = [] - for content in contents: - paths.extend(value for _, value in content["revision"].items()) - paths.extend(value for _, value in content["directory"].items()) - - dirs = self.db.directory.find({}, {"revision": 1, "_id": 0}) - for each_dir in dirs: - paths.extend(value for _, value in each_dir["revision"].items()) - return set(sum(paths, [])) - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"}) - def open(self) -> None: - if self.engine == "mongomock": - from mongomock import MongoClient as MongoClient - else: # assume real MongoDB server by default - from pymongo import MongoClient - self.db: Database = MongoClient(**self.conn_args).get_database(self.dbname) - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) - def origin_add(self, orgs: 
Dict[Sha1Git, str]) -> bool: - existing = { - x["sha1"]: x - for x in self.db.origin.find( - {"sha1": {"$in": list(orgs)}}, {"sha1": 1, "url": 1, "_id": 1} - ) - } - for sha1, url in orgs.items(): - if sha1 not in existing: - # add new origin - self.db.origin.insert_one({"sha1": sha1, "url": url}) - return True - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_get"}) - def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: - return { - x["sha1"]: x["url"] - for x in self.db.origin.find( - {"sha1": {"$in": list(ids)}}, {"sha1": 1, "url": 1, "_id": 0} - ) - } - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"}) - def revision_add( - self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] - ) -> bool: - data = ( - revs - if isinstance(revs, dict) - else dict.fromkeys(revs, RevisionData(date=None, origin=None)) - ) - existing = { - x["sha1"]: x - for x in self.db.revision.find( - {"sha1": {"$in": list(data)}}, - {"sha1": 1, "ts": 1, "preferred": 1, "_id": 1}, - ) - } - for sha1, info in data.items(): - ts = datetime.timestamp(info.date) if info.date is not None else None - preferred = info.origin - if sha1 in existing: - rev = existing[sha1] - if ts is None or (rev["ts"] is not None and ts >= rev["ts"]): - ts = rev["ts"] - if preferred is None: - preferred = rev["preferred"] - if ts != rev["ts"] or preferred != rev["preferred"]: - self.db.revision.update_one( - {"_id": rev["_id"]}, - {"$set": {"ts": ts, "preferred": preferred}}, - ) - else: - self.db.revision.insert_one( - { - "sha1": sha1, - "preferred": preferred, - "origin": [], - "revision": [], - "ts": ts, - } - ) - return True - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_get"}) - def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: - return { - x["sha1"]: RevisionData( - date=datetime.fromtimestamp(x["ts"], timezone.utc) if x["ts"] else None, - origin=x["preferred"], - ) - for x in self.db.revision.find( - { - "sha1": {"$in": list(ids)}, - "$or": [{"preferred": {"$ne": None}}, {"ts": {"$ne": None}}], - }, - {"sha1": 1, "preferred": 1, "ts": 1, "_id": 0}, - ) - } - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_add"}) - def relation_add( - self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] - ) -> bool: - src_relation, *_, dst_relation = relation.value.split("_") - - dst_objs = { - x["sha1"]: x["_id"] - for x in self.db.get_collection(dst_relation).find( - { - "sha1": { - "$in": list({rel.dst for rels in data.values() for rel in rels}) - } - }, - {"_id": 1, "sha1": 1}, - ) - } - - denorm: Dict[Sha1Git, Any] = {} - for src, rels in data.items(): - for rel in rels: - if src_relation != "revision": - denorm.setdefault(src, {}).setdefault( - str(dst_objs[rel.dst]), [] - ).append(rel.path) - else: - denorm.setdefault(src, []).append(dst_objs[rel.dst]) - - src_objs = { - x["sha1"]: x - for x in self.db.get_collection(src_relation).find( - {"sha1": {"$in": list(denorm.keys())}} - ) - } - - for sha1, dsts in denorm.items(): - # update - if src_relation != "revision": - k = { - obj_id: list(set(paths + dsts.get(obj_id, []))) - for obj_id, paths in src_objs[sha1][dst_relation].items() - } - self.db.get_collection(src_relation).update_one( - {"_id": src_objs[sha1]["_id"]}, - {"$set": {dst_relation: dict(dsts, **k)}}, - ) - else: - self.db.get_collection(src_relation).update_one( - {"_id": src_objs[sha1]["_id"]}, - { - "$set": { - dst_relation: 
list(set(src_objs[sha1][dst_relation] + dsts)) - } - }, - ) - return True - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get"}) - def relation_get( - self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Dict[Sha1Git, Set[RelationData]]: - src, *_, dst = relation.value.split("_") - sha1s = set(ids) - if not reverse: - empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] - src_objs = { - x["sha1"]: x[dst] - for x in self.db.get_collection(src).find( - {"sha1": {"$in": list(sha1s)}, dst: {"$ne": empty}}, - {"_id": 0, "sha1": 1, dst: 1}, - ) - } - dst_ids = list( - {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} - ) - dst_objs = { - x["sha1"]: x["_id"] - for x in self.db.get_collection(dst).find( - {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} - ) - } - if src != "revision": - return { - src_sha1: { - RelationData(dst=dst_sha1, path=path) - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_str, paths in denorm.items() - for path in paths - if dst_obj_id == ObjectId(dst_obj_str) - } - for src_sha1, denorm in src_objs.items() - } - else: - return { - src_sha1: { - RelationData(dst=dst_sha1, path=None) - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_ref in denorm - if dst_obj_id == dst_obj_ref - } - for src_sha1, denorm in src_objs.items() - } - else: - dst_objs = { - x["sha1"]: x["_id"] - for x in self.db.get_collection(dst).find( - {"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1} - ) - } - src_objs = { - x["sha1"]: x[dst] - for x in self.db.get_collection(src).find( - {}, {"_id": 0, "sha1": 1, dst: 1} - ) - } - result: Dict[Sha1Git, Set[RelationData]] = {} - if src != "revision": - for dst_sha1, dst_obj_id in dst_objs.items(): - for src_sha1, denorm in src_objs.items(): - for dst_obj_str, paths in denorm.items(): - if dst_obj_id == ObjectId(dst_obj_str): - result.setdefault(src_sha1, set()).update( - RelationData(dst=dst_sha1, path=path) - for path in paths - ) - else: - for dst_sha1, dst_obj_id in dst_objs.items(): - for src_sha1, denorm in src_objs.items(): - if dst_obj_id in { - ObjectId(dst_obj_str) for dst_obj_str in denorm - }: - result.setdefault(src_sha1, set()).add( - RelationData(dst=dst_sha1, path=None) - ) - return result - - @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"}) - def relation_get_all( - self, relation: RelationType - ) -> Dict[Sha1Git, Set[RelationData]]: - src, *_, dst = relation.value.split("_") - empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] - src_objs = { - x["sha1"]: x[dst] - for x in self.db.get_collection(src).find( - {dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1} - ) - } - dst_ids = list( - {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} - ) - dst_objs = { - x["_id"]: x["sha1"] - for x in self.db.get_collection(dst).find( - {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} - ) - } - if src != "revision": - return { - src_sha1: { - RelationData(dst=dst_sha1, path=path) - for dst_obj_id, dst_sha1 in dst_objs.items() - for dst_obj_str, paths in denorm.items() - for path in paths - if dst_obj_id == ObjectId(dst_obj_str) - } - for src_sha1, denorm in src_objs.items() - } - else: - return { - src_sha1: { - RelationData(dst=dst_sha1, path=None) - for dst_obj_id, dst_sha1 in dst_objs.items() - for dst_obj_ref in denorm - if dst_obj_id == dst_obj_ref - } - for src_sha1, denorm in src_objs.items() - } - - 
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "with_path"}) - def with_path(self) -> bool: - return True diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py index a467bda..08c078b 100644 --- a/swh/provenance/tests/conftest.py +++ b/swh/provenance/tests/conftest.py @@ -1,176 +1,163 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timedelta, timezone from os import path from typing import Any, Dict, Generator, List from _pytest.fixtures import SubRequest -import mongomock.database import msgpack import psycopg2.extensions import pytest from pytest_postgresql.factories import postgresql from swh.journal.serializers import msgpack_ext_hook from swh.model.model import BaseModel from swh.provenance import get_provenance, get_provenance_storage from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface from swh.provenance.storage.archive import ArchiveStorage from swh.storage.interface import StorageInterface from swh.storage.replay import OBJECT_CONVERTERS, OBJECT_FIXERS, process_replay_objects @pytest.fixture( params=[ "with-path", "without-path", "with-path-denormalized", "without-path-denormalized", ] ) def provenance_postgresqldb( request: SubRequest, postgresql: psycopg2.extensions.connection, ) -> Dict[str, str]: """return a working and initialized provenance db""" from swh.core.db.db_utils import ( init_admin_extensions, populate_database_for_package, ) init_admin_extensions("swh.provenance", postgresql.dsn) populate_database_for_package( "swh.provenance", postgresql.dsn, flavor=request.param ) return postgresql.get_dsn_parameters() -@pytest.fixture(params=["mongodb", "postgresql", "rabbitmq"]) +@pytest.fixture(params=["postgresql", "rabbitmq"]) def provenance_storage( request: SubRequest, provenance_postgresqldb: Dict[str, str], - mongodb: mongomock.database.Database, ) -> Generator[ProvenanceStorageInterface, None, None]: """Return a working and initialized ProvenanceStorageInterface object""" - if request.param == "mongodb": - mongodb_params = { - "host": mongodb.client.address[0], - "port": mongodb.client.address[1], - "dbname": mongodb.name, - } - with get_provenance_storage( - cls=request.param, db=mongodb_params, engine="mongomock" - ) as storage: - yield storage - - elif request.param == "rabbitmq": + if request.param == "rabbitmq": from swh.provenance.api.server import ProvenanceStorageRabbitMQServer rabbitmq = request.getfixturevalue("rabbitmq") host = rabbitmq.args["host"] port = rabbitmq.args["port"] rabbitmq_params: Dict[str, Any] = { "url": f"amqp://guest:guest@{host}:{port}/%2f", "storage_config": { - "cls": "postgresql", # TODO: also test with underlying mongodb storage + "cls": "postgresql", "db": provenance_postgresqldb, "raise_on_commit": True, }, } server = ProvenanceStorageRabbitMQServer( url=rabbitmq_params["url"], storage_config=rabbitmq_params["storage_config"] ) server.start() with get_provenance_storage(cls=request.param, **rabbitmq_params) as storage: yield storage server.stop() else: # in test sessions, we DO want to raise any exception occurring at commit time with get_provenance_storage( cls=request.param, db=provenance_postgresqldb, raise_on_commit=True ) as storage: yield storage provenance_postgresql = 
postgresql("postgresql_proc", dbname="provenance_tests") @pytest.fixture def provenance( provenance_postgresql: psycopg2.extensions.connection, ) -> Generator[ProvenanceInterface, None, None]: """Return a working and initialized ProvenanceInterface object""" from swh.core.db.db_utils import ( init_admin_extensions, populate_database_for_package, ) init_admin_extensions("swh.provenance", provenance_postgresql.dsn) populate_database_for_package( "swh.provenance", provenance_postgresql.dsn, flavor="with-path" ) # in test sessions, we DO want to raise any exception occurring at commit time with get_provenance( cls="postgresql", db=provenance_postgresql.get_dsn_parameters(), raise_on_commit=True, ) as provenance: yield provenance @pytest.fixture def archive(swh_storage: StorageInterface) -> ArchiveInterface: """Return an ArchiveStorage-based ArchiveInterface object""" return ArchiveStorage(swh_storage) def fill_storage(storage: StorageInterface, data: Dict[str, List[dict]]) -> None: objects = { objtype: [objs_from_dict(objtype, d) for d in dicts] for objtype, dicts in data.items() } process_replay_objects(objects, storage=storage) def get_datafile(fname: str) -> str: return path.join(path.dirname(__file__), "data", fname) # TODO: this should return Dict[str, List[BaseModel]] directly, but it requires # refactoring several tests def load_repo_data(repo: str) -> Dict[str, List[dict]]: data: Dict[str, List[dict]] = {} with open(get_datafile(f"{repo}.msgpack"), "rb") as fobj: unpacker = msgpack.Unpacker( fobj, raw=False, ext_hook=msgpack_ext_hook, strict_map_key=False, timestamp=3, # convert Timestamp in datetime objects (tz UTC) ) for objtype, objd in unpacker: data.setdefault(objtype, []).append(objd) return data def objs_from_dict(object_type: str, dict_repr: dict) -> BaseModel: if object_type in OBJECT_FIXERS: dict_repr = OBJECT_FIXERS[object_type](dict_repr) obj = OBJECT_CONVERTERS[object_type](dict_repr) return obj # TODO: remove this function in favour of TimestampWithTimezone.to_datetime # from swh.model.model def ts2dt(ts: Dict[str, Any]) -> datetime: timestamp = datetime.fromtimestamp( ts["timestamp"]["seconds"], timezone(timedelta(minutes=ts["offset"])) ) return timestamp.replace(microsecond=ts["timestamp"]["microseconds"]) diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py index 42b9c8d..f36b469 100644 --- a/swh/provenance/tests/test_provenance_storage.py +++ b/swh/provenance/tests/test_provenance_storage.py @@ -1,468 +1,463 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import inspect import os from typing import Any, Dict, Iterable, Optional, Set, Tuple from swh.model.hashutil import hash_to_bytes from swh.model.model import Origin, Sha1Git from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ( DirectoryData, EntityType, ProvenanceInterface, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) from swh.provenance.model import OriginEntry, RevisionEntry -from swh.provenance.mongo.backend import ProvenanceStorageMongoDb from swh.provenance.origin import origin_add from swh.provenance.provenance import Provenance from swh.provenance.revision import revision_add from swh.provenance.tests.conftest import fill_storage, 
load_repo_data, ts2dt def test_provenance_storage_content( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests content methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Add all content present in the current repo to the storage, just assigning their # creation dates. Then check that the returned results when querying are the same. cnt_dates = { cnt["sha1_git"]: cnt["ctime"] for idx, cnt in enumerate(data["content"]) } assert provenance_storage.content_add(cnt_dates) assert provenance_storage.content_get(set(cnt_dates.keys())) == cnt_dates assert provenance_storage.entity_get_all(EntityType.CONTENT) == set( cnt_dates.keys() ) def test_provenance_storage_directory( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests directory methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Of all directories present in the current repo, only assign a date to those # containing blobs (picking the max date among the available ones). Then check that # the returned results when querying are the same. def getmaxdate( directory: Dict[str, Any], contents: Iterable[Dict[str, Any]] ) -> Optional[datetime]: dates = [ content["ctime"] for entry in directory["entries"] for content in contents if entry["type"] == "file" and entry["target"] == content["sha1_git"] ] return max(dates) if dates else None flat_values = (False, True) dir_dates = {} for idx, dir in enumerate(data["directory"]): date = getmaxdate(dir, data["content"]) if date is not None: dir_dates[dir["id"]] = DirectoryData(date=date, flat=flat_values[idx % 2]) assert provenance_storage.directory_add(dir_dates) assert provenance_storage.directory_get(set(dir_dates.keys())) == dir_dates assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == set( dir_dates.keys() ) def test_provenance_storage_location( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests location methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Add all names of entries present in the directories of the current repo as paths # to the storage. Then check that the returned results when querying are the same. paths = {entry["name"] for dir in data["directory"] for entry in dir["entries"]} assert provenance_storage.location_add(paths) - if isinstance(provenance_storage, ProvenanceStorageMongoDb): - # TODO: remove this when `location_add` is properly implemented for MongoDb. - return - if provenance_storage.with_path(): assert provenance_storage.location_get_all() == paths else: assert provenance_storage.location_get_all() == set() def test_provenance_storage_origin( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests origin methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Test origin methods. # Add all origins present in the current repo to the storage. Then check that the # returned results when querying are the same. 
orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]} assert orgs assert provenance_storage.origin_add(orgs) assert provenance_storage.origin_get(set(orgs.keys())) == orgs assert provenance_storage.entity_get_all(EntityType.ORIGIN) == set(orgs.keys()) def test_provenance_storage_revision( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests revision methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Test revision methods. # Add all revisions present in the current repo to the storage, assigning their # dates and an arbitrary origin to each one. Then check that the returned results # when querying are the same. origin = Origin(url=next(iter(data["origin"]))["url"]) # Origin must be inserted in advance. assert provenance_storage.origin_add({origin.id: origin.url}) revs = {rev["id"] for idx, rev in enumerate(data["revision"]) if idx % 6 == 0} rev_data = { rev["id"]: RevisionData( date=ts2dt(rev["date"]) if idx % 2 != 0 else None, origin=origin.id if idx % 3 != 0 else None, ) for idx, rev in enumerate(data["revision"]) if idx % 6 != 0 } assert revs assert provenance_storage.revision_add(revs) assert provenance_storage.revision_add(rev_data) assert provenance_storage.revision_get(set(rev_data.keys())) == rev_data assert provenance_storage.entity_get_all(EntityType.REVISION) == revs | set( rev_data.keys() ) def dircontent( data: Dict[str, Any], ref: Sha1Git, dir: Dict[str, Any], prefix: bytes = b"", ) -> Iterable[Tuple[Sha1Git, RelationData]]: content = { ( entry["target"], RelationData(dst=ref, path=os.path.join(prefix, entry["name"])), ) for entry in dir["entries"] if entry["type"] == "file" } for entry in dir["entries"]: if entry["type"] == "dir": child = next( subdir for subdir in data["directory"] if subdir["id"] == entry["target"] ) content.update( dircontent(data, ref, child, os.path.join(prefix, entry["name"])) ) return content def entity_add( storage: ProvenanceStorageInterface, entity: EntityType, ids: Set[Sha1Git] ) -> bool: now = datetime.now(tz=timezone.utc) if entity == EntityType.CONTENT: return storage.content_add({sha1: now for sha1 in ids}) elif entity == EntityType.DIRECTORY: return storage.directory_add( {sha1: DirectoryData(date=now, flat=False) for sha1 in ids} ) else: # entity == EntityType.REVISION: return storage.revision_add( {sha1: RevisionData(date=None, origin=None) for sha1 in ids} ) def relation_add_and_compare_result( storage: ProvenanceStorageInterface, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]], ) -> None: # Source, destinations and locations must be added in advance. 
src, *_, dst = relation.value.split("_") srcs = {sha1 for sha1 in data} if src != "origin": assert entity_add(storage, EntityType(src), srcs) dsts = {rel.dst for rels in data.values() for rel in rels} if dst != "origin": assert entity_add(storage, EntityType(dst), dsts) if storage.with_path(): assert storage.location_add( {rel.path for rels in data.values() for rel in rels if rel.path is not None} ) assert data assert storage.relation_add(relation, data) for src_sha1 in srcs: relation_compare_result( storage.relation_get(relation, [src_sha1]), {src_sha1: data[src_sha1]}, storage.with_path(), ) for dst_sha1 in dsts: relation_compare_result( storage.relation_get(relation, [dst_sha1], reverse=True), { src_sha1: { RelationData(dst=dst_sha1, path=rel.path) for rel in rels if dst_sha1 == rel.dst } for src_sha1, rels in data.items() if dst_sha1 in {rel.dst for rel in rels} }, storage.with_path(), ) relation_compare_result( storage.relation_get_all(relation), data, storage.with_path() ) def relation_compare_result( computed: Dict[Sha1Git, Set[RelationData]], expected: Dict[Sha1Git, Set[RelationData]], with_path: bool, ) -> None: assert { src_sha1: { RelationData(dst=rel.dst, path=rel.path if with_path else None) for rel in rels } for src_sha1, rels in expected.items() } == computed def test_provenance_storage_relation( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests relation methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Test content-in-revision relation. # Create flat models of every root directory for the revisions in the dataset. cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} for rev in data["revision"]: root = next( subdir for subdir in data["directory"] if subdir["id"] == rev["directory"] ) for cnt, rel in dircontent(data, rev["id"], root): cnt_in_rev.setdefault(cnt, set()).add(rel) relation_add_and_compare_result( provenance_storage, RelationType.CNT_EARLY_IN_REV, cnt_in_rev ) # Test content-in-directory relation. # Create flat models for every directory in the dataset. cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {} for dir in data["directory"]: for cnt, rel in dircontent(data, dir["id"], dir): cnt_in_dir.setdefault(cnt, set()).add(rel) relation_add_and_compare_result( provenance_storage, RelationType.CNT_IN_DIR, cnt_in_dir ) # Test content-in-directory relation. # Add root directories to their correspondent revision in the dataset. dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {} for rev in data["revision"]: dir_in_rev.setdefault(rev["directory"], set()).add( RelationData(dst=rev["id"], path=b".") ) relation_add_and_compare_result( provenance_storage, RelationType.DIR_IN_REV, dir_in_rev ) # Test revision-in-origin relation. # Origins must be inserted in advance (cannot be done by `entity_add` inside # `relation_add_and_compare_result`). orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]} assert provenance_storage.origin_add(orgs) # Add all revisions that are head of some snapshot branch to the corresponding # origin. 
rev_in_org: Dict[Sha1Git, Set[RelationData]] = {} for status in data["origin_visit_status"]: if status["snapshot"] is not None: for snapshot in data["snapshot"]: if snapshot["id"] == status["snapshot"]: for branch in snapshot["branches"].values(): if branch["target_type"] == "revision": rev_in_org.setdefault(branch["target"], set()).add( RelationData( dst=Origin(url=status["origin"]).id, path=None, ) ) relation_add_and_compare_result( provenance_storage, RelationType.REV_IN_ORG, rev_in_org ) # Test revision-before-revision relation. # For each revision in the data set add an entry for each parent to the relation. rev_before_rev: Dict[Sha1Git, Set[RelationData]] = {} for rev in data["revision"]: for parent in rev["parents"]: rev_before_rev.setdefault(parent, set()).add( RelationData(dst=rev["id"], path=None) ) relation_add_and_compare_result( provenance_storage, RelationType.REV_BEFORE_REV, rev_before_rev ) def test_provenance_storage_find( provenance: ProvenanceInterface, provenance_storage: ProvenanceStorageInterface, archive: ArchiveInterface, ) -> None: """Tests `content_find_first` and `content_find_all` methods for every `ProvenanceStorageInterface` implementation. """ # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") fill_storage(archive.storage, data) # Test content_find_first and content_find_all, first only executing the # revision-content algorithm, then adding the origin-revision layer. def adapt_result( result: Optional[ProvenanceResult], with_path: bool ) -> Optional[ProvenanceResult]: if result is not None: return ProvenanceResult( result.content, result.revision, result.date, result.origin, result.path if with_path else b"", ) return result # Execute the revision-content algorithm on both storages. revisions = [ RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"]) for rev in data["revision"] ] revision_add(provenance, archive, revisions) revision_add(Provenance(provenance_storage), archive, revisions) assert adapt_result( ProvenanceResult( content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"), revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), date=datetime.fromtimestamp(1000000000.0, timezone.utc), origin=None, path=b"A/B/C/a", ), provenance_storage.with_path(), ) == provenance_storage.content_find_first( hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494") ) for cnt in {cnt["sha1_git"] for cnt in data["content"]}: assert adapt_result( provenance.storage.content_find_first(cnt), provenance_storage.with_path() ) == provenance_storage.content_find_first(cnt) assert { adapt_result(occur, provenance_storage.with_path()) for occur in provenance.storage.content_find_all(cnt) } == set(provenance_storage.content_find_all(cnt)) # Execute the origin-revision algorithm on both storages. 
origins = [ OriginEntry(url=sta["origin"], snapshot=sta["snapshot"]) for sta in data["origin_visit_status"] if sta["snapshot"] is not None ] origin_add(provenance, archive, origins) origin_add(Provenance(provenance_storage), archive, origins) assert adapt_result( ProvenanceResult( content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"), revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), date=datetime.fromtimestamp(1000000000.0, timezone.utc), origin="https://cmdbts2", path=b"A/B/C/a", ), provenance_storage.with_path(), ) == provenance_storage.content_find_first( hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494") ) for cnt in {cnt["sha1_git"] for cnt in data["content"]}: assert adapt_result( provenance.storage.content_find_first(cnt), provenance_storage.with_path() ) == provenance_storage.content_find_first(cnt) assert { adapt_result(occur, provenance_storage.with_path()) for occur in provenance.storage.content_find_all(cnt) } == set(provenance_storage.content_find_all(cnt)) def test_types(provenance_storage: ProvenanceStorageInterface) -> None: """Checks all methods of ProvenanceStorageInterface are implemented by this backend, and that they have the same signature.""" # Create an instance of the protocol (which cannot be instantiated # directly, so this creates a subclass, then instantiates it) interface = type("_", (ProvenanceStorageInterface,), {})() assert "content_find_first" in dir(interface) missing_methods = [] for meth_name in dir(interface): if meth_name.startswith("_"): continue interface_meth = getattr(interface, meth_name) try: concrete_meth = getattr(provenance_storage, meth_name) except AttributeError: if not getattr(interface_meth, "deprecated_endpoint", False): # The backend is missing a (non-deprecated) endpoint missing_methods.append(meth_name) continue expected_signature = inspect.signature(interface_meth) actual_signature = inspect.signature(concrete_meth) assert expected_signature == actual_signature, meth_name assert missing_methods == [] # If all the assertions above succeed, then this one should too. # But there's no harm in double-checking. # And we could replace the assertions above by this one, but unlike # the assertions above, it doesn't explain what is missing. assert isinstance(provenance_storage, ProvenanceStorageInterface)
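For reference, a minimal sketch of how a provenance storage is selected once this patch is applied: only the `postgresql` (or deprecated `local`) and `rabbitmq` classes remain, and any other class name, including the removed `mongodb`, falls through to the final `raise ValueError` in `get_provenance_storage`. The host names, credentials and database names below are placeholders, not values taken from this patch.

```python
# Sketch only: the two storage configurations that remain after this change.
# Connection parameters are placeholders and must match an existing setup.
from swh.provenance import get_provenance_storage

postgresql_cfg = {
    "cls": "postgresql",  # "local" still works but emits a DeprecationWarning
    "db": {"host": "localhost", "user": "postgres", "dbname": "provenance"},
}

rabbitmq_cfg = {
    "cls": "rabbitmq",  # RabbitMQ front-end backed by the PostgreSQL storage
    "url": "amqp://localhost:5672/%2f",
    "storage_config": postgresql_cfg,
}

# Either dict is passed as get_provenance_storage(**cfg); the returned object
# is used as a context manager, as in swh/provenance/tests/conftest.py.

# The removed "mongodb" class is no longer recognized and raises ValueError.
try:
    get_provenance_storage(cls="mongodb", db={"dbname": "provenance"})
except ValueError:
    print("mongodb provenance storage is no longer available")
```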