diff --git a/requirements.txt b/requirements.txt index 199ddc7..3b87e8d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ # Add here external Python modules dependencies, one per line. Module names # should match https://pypi.python.org/pypi names. For the full spec of # dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html click iso8601 methodtools +mongomock pymongo PyYAML types-click types-PyYAML diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py index 0c43d40..de47753 100644 --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -1,99 +1,95 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from typing import TYPE_CHECKING import warnings if TYPE_CHECKING: from .archive import ArchiveInterface from .interface import ProvenanceInterface, ProvenanceStorageInterface def get_archive(cls: str, **kwargs) -> ArchiveInterface: """Get an archive object of class ``cls`` with arguments ``kwargs``. Args: cls: archive's class, either 'api' or 'direct' kwargs: dictionary of arguments passed to the archive class constructor Returns: an instance of archive object (either using swh.storage API or direct queries to the archive's database) Raises: :exc:`ValueError` if passed an unknown archive class. """ if cls == "api": from swh.storage import get_storage from .storage.archive import ArchiveStorage return ArchiveStorage(get_storage(**kwargs["storage"])) elif cls == "direct": from swh.core.db import BaseDb from .postgresql.archive import ArchivePostgreSQL return ArchivePostgreSQL(BaseDb.connect(**kwargs["db"]).conn) else: raise ValueError def get_provenance(**kwargs) -> ProvenanceInterface: """Get a provenance object with arguments ``kwargs``. Args: kwargs: dictionary of arguments to retrieve a swh.provenance.storage class (see :func:`get_provenance_storage` for details) Returns: an instance of provenance object """ from .provenance import Provenance return Provenance(get_provenance_storage(**kwargs)) def get_provenance_storage(cls: str, **kwargs) -> ProvenanceStorageInterface: """Get a provenance storage object of class ``cls`` with arguments ``kwargs``. Args: cls: storage's class, either 'postgresql' or 'mongodb' ('local' is deprecated) kwargs: dictionary of arguments passed to the storage class constructor Returns: an instance of storage object Raises: :exc:`ValueError` if passed an unknown storage class. 
""" if cls in ["local", "postgresql"]: - from swh.core.db import BaseDb - from .postgresql.provenance import ProvenanceStoragePostgreSql if cls == "local": warnings.warn( '"local" class is deprecated for provenance storage, please ' 'use "postgresql" class instead.', DeprecationWarning, ) - conn = BaseDb.connect(**kwargs["db"]).conn raise_on_commit = kwargs.get("raise_on_commit", False) - return ProvenanceStoragePostgreSql(conn, raise_on_commit) + return ProvenanceStoragePostgreSql( + raise_on_commit=raise_on_commit, **kwargs["db"] + ) elif cls == "mongodb": - from pymongo import MongoClient - from .mongo.backend import ProvenanceStorageMongoDb - dbname = kwargs["db"].pop("dbname") - db = MongoClient(**kwargs["db"]).get_database(dbname) - return ProvenanceStorageMongoDb(db) + engine = kwargs.get("engine", "pymongo") + return ProvenanceStorageMongoDb(engine=engine, **kwargs["db"]) raise ValueError diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py index a4f8905..96359e3 100644 --- a/swh/provenance/cli.py +++ b/swh/provenance/cli.py @@ -1,238 +1,238 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # WARNING: do not import unnecessary things here to keep cli startup time under # control from datetime import datetime, timezone import os from typing import Any, Dict, Generator, Optional, Tuple import click import iso8601 import yaml from swh.core import config from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group from swh.model.hashutil import hash_to_bytes, hash_to_hex from swh.model.model import Sha1Git # All generic config code should reside in swh.core.config CONFIG_ENVVAR = "SWH_CONFIG_FILENAME" DEFAULT_PATH = os.environ.get(CONFIG_ENVVAR, None) DEFAULT_CONFIG: Dict[str, Any] = { "provenance": { "archive": { # Storage API based Archive object # "cls": "api", # "storage": { # "cls": "remote", # "url": "http://uffizi.internal.softwareheritage.org:5002", # } # Direct access Archive object "cls": "direct", "db": { "host": "db.internal.softwareheritage.org", "dbname": "softwareheritage", "user": "guest", }, }, "storage": { # Local PostgreSQL Storage "cls": "postgresql", "db": { "host": "localhost", "user": "postgres", "password": "postgres", "dbname": "provenance", }, # Local MongoDB Storage # "cls": "mongodb", # "db": { # "dbname": "provenance", # }, }, } } CONFIG_FILE_HELP = f""" \b Configuration can be loaded from a yaml file given either as --config-file option or the {CONFIG_ENVVAR} environment variable. 
If no configuration file is specified, use the following default configuration:: \b {yaml.dump(DEFAULT_CONFIG)}""" PROVENANCE_HELP = f"""Software Heritage provenance index database tools {CONFIG_FILE_HELP} """ @swh_cli_group.group( name="provenance", context_settings=CONTEXT_SETTINGS, help=PROVENANCE_HELP ) @click.option( "-C", "--config-file", default=None, type=click.Path(exists=True, dir_okay=False, path_type=str), help="""YAML configuration file.""", ) @click.option( "-P", "--profile", default=None, type=click.Path(exists=False, dir_okay=False, path_type=str), help="""Enable profiling to specified file.""", ) @click.pass_context def cli(ctx: click.core.Context, config_file: Optional[str], profile: str) -> None: if ( config_file is None and DEFAULT_PATH is not None and config.config_exists(DEFAULT_PATH) ): config_file = DEFAULT_PATH if config_file is None: conf = DEFAULT_CONFIG else: # read_raw_config do not fail on ENOENT if not os.path.exists(config_file): raise FileNotFoundError(config_file) conf = yaml.safe_load(open(config_file, "rb")) ctx.ensure_object(dict) ctx.obj["config"] = conf if profile: import atexit import cProfile print("Profiling...") pr = cProfile.Profile() pr.enable() def exit() -> None: pr.disable() pr.dump_stats(profile) atexit.register(exit) @cli.command(name="iter-revisions") @click.argument("filename") @click.option("-a", "--track-all", default=True, type=bool) @click.option("-l", "--limit", type=int) @click.option("-m", "--min-depth", default=1, type=int) @click.option("-r", "--reuse", default=True, type=bool) @click.pass_context def iter_revisions( ctx: click.core.Context, filename: str, track_all: bool, limit: Optional[int], min_depth: int, reuse: bool, ) -> None: # TODO: add file size filtering """Process a provided list of revisions.""" from . import get_archive, get_provenance from .revision import CSVRevisionIterator, revision_add archive = get_archive(**ctx.obj["config"]["provenance"]["archive"]) - provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) revisions_provider = generate_revision_tuples(filename) revisions = CSVRevisionIterator(revisions_provider, limit=limit) - for revision in revisions: - revision_add( - provenance, - archive, - [revision], - trackall=track_all, - lower=reuse, - mindepth=min_depth, - ) + with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance: + for revision in revisions: + revision_add( + provenance, + archive, + [revision], + trackall=track_all, + lower=reuse, + mindepth=min_depth, + ) def generate_revision_tuples( filename: str, ) -> Generator[Tuple[Sha1Git, datetime, Sha1Git], None, None]: for line in open(filename, "r"): if line.strip(): revision, date, root = line.strip().split(",") yield ( hash_to_bytes(revision), iso8601.parse_date(date, default_timezone=timezone.utc), hash_to_bytes(root), ) @cli.command(name="iter-origins") @click.argument("filename") @click.option("-l", "--limit", type=int) @click.pass_context def iter_origins(ctx: click.core.Context, filename: str, limit: Optional[int]) -> None: """Process a provided list of origins.""" from . 
import get_archive, get_provenance from .origin import CSVOriginIterator, origin_add archive = get_archive(**ctx.obj["config"]["provenance"]["archive"]) - provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) origins_provider = generate_origin_tuples(filename) origins = CSVOriginIterator(origins_provider, limit=limit) - for origin in origins: - origin_add(provenance, archive, [origin]) + with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance: + for origin in origins: + origin_add(provenance, archive, [origin]) def generate_origin_tuples(filename: str) -> Generator[Tuple[str, bytes], None, None]: for line in open(filename, "r"): if line.strip(): url, snapshot = line.strip().split(",") yield (url, hash_to_bytes(snapshot)) @cli.command(name="find-first") @click.argument("swhid") @click.pass_context def find_first(ctx: click.core.Context, swhid: str) -> None: """Find first occurrence of the requested blob.""" from . import get_provenance - provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) - occur = provenance.content_find_first(hash_to_bytes(swhid)) - if occur is not None: - print( - f"swh:1:cnt:{hash_to_hex(occur.content)}, " - f"swh:1:rev:{hash_to_hex(occur.revision)}, " - f"{occur.date}, " - f"{occur.origin}, " - f"{os.fsdecode(occur.path)}" - ) - else: - print(f"Cannot find a content with the id {swhid}") + with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance: + occur = provenance.content_find_first(hash_to_bytes(swhid)) + if occur is not None: + print( + f"swh:1:cnt:{hash_to_hex(occur.content)}, " + f"swh:1:rev:{hash_to_hex(occur.revision)}, " + f"{occur.date}, " + f"{occur.origin}, " + f"{os.fsdecode(occur.path)}" + ) + else: + print(f"Cannot find a content with the id {swhid}") @cli.command(name="find-all") @click.argument("swhid") @click.option("-l", "--limit", type=int) @click.pass_context def find_all(ctx: click.core.Context, swhid: str, limit: Optional[int]) -> None: """Find all occurrences of the requested blob.""" from . 
import get_provenance - provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) - for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit): - print( - f"swh:1:cnt:{hash_to_hex(occur.content)}, " - f"swh:1:rev:{hash_to_hex(occur.revision)}, " - f"{occur.date}, " - f"{occur.origin}, " - f"{os.fsdecode(occur.path)}" - ) + with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance: + for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit): + print( + f"swh:1:cnt:{hash_to_hex(occur.content)}, " + f"swh:1:rev:{hash_to_hex(occur.revision)}, " + f"{occur.date}, " + f"{occur.origin}, " + f"{os.fsdecode(occur.path)}" + ) diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py index d04e858..430876c 100644 --- a/swh/provenance/interface.py +++ b/swh/provenance/interface.py @@ -1,331 +1,376 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from __future__ import annotations + from dataclasses import dataclass from datetime import datetime import enum -from typing import Dict, Generator, Iterable, Optional, Set, Union +from types import TracebackType +from typing import Dict, Generator, Iterable, Optional, Set, Type, Union from typing_extensions import Protocol, runtime_checkable from swh.core.api import remote_api_endpoint from swh.model.model import Sha1Git from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry class EntityType(enum.Enum): CONTENT = "content" DIRECTORY = "directory" REVISION = "revision" ORIGIN = "origin" class RelationType(enum.Enum): CNT_EARLY_IN_REV = "content_in_revision" CNT_IN_DIR = "content_in_directory" DIR_IN_REV = "directory_in_revision" REV_IN_ORG = "revision_in_origin" REV_BEFORE_REV = "revision_before_revision" @dataclass(eq=True, frozen=True) class ProvenanceResult: content: Sha1Git revision: Sha1Git date: datetime origin: Optional[str] path: bytes @dataclass(eq=True, frozen=True) class RevisionData: """Object representing the data associated to a revision in the provenance model, where `date` is the optional date of the revision (specifying it acknowledges that the revision was already processed by the revision-content algorithm); and `origin` identifies the preferred origin for the revision, if any. """ date: Optional[datetime] origin: Optional[Sha1Git] @dataclass(eq=True, frozen=True) class RelationData: """Object representing a relation entry in the provenance model, where `src` and `dst` are the sha1 ids of the entities being related, and `path` is optional depending on the relation being represented. """ dst: Sha1Git path: Optional[bytes] @runtime_checkable class ProvenanceStorageInterface(Protocol): + def __enter__(self) -> ProvenanceStorageInterface: + ... + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + ... + + @remote_api_endpoint("close") + def close(self) -> None: + """Close connection to the storage and release resources.""" + ... + @remote_api_endpoint("content_add") def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: """Add blobs identified by sha1 ids, with an optional associated date (as paired in `cnts`) to the provenance storage. 
Return a boolean stating whether the information was successfully stored. """ ... @remote_api_endpoint("content_find_first") def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: """Retrieve the first occurrence of the blob identified by `id`.""" ... @remote_api_endpoint("content_find_all") def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: """Retrieve all the occurrences of the blob identified by `id`.""" ... @remote_api_endpoint("content_get") def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: """Retrieve the associated date for each blob sha1 in `ids`. If some blob has no associated date, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("directory_add") def directory_add( self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: """Add directories identified by sha1 ids, with an optional associated date (as paired in `dirs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("directory_get") def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: """Retrieve the associated date for each directory sha1 in `ids`. If some directory has no associated date, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("entity_get_all") def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: """Retrieve all sha1 ids for entities of type `entity` present in the provenance model. This method is used only in tests. """ ... @remote_api_endpoint("location_add") def location_add(self, paths: Iterable[bytes]) -> bool: """Register the given `paths` in the storage.""" ... @remote_api_endpoint("location_get_all") def location_get_all(self) -> Set[bytes]: """Retrieve all paths present in the provenance model. This method is used only in tests.""" ... + @remote_api_endpoint("open") + def open(self) -> None: + """Open connection to the storage and allocate necessary resources.""" + ... + @remote_api_endpoint("origin_add") def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: """Add origins identified by sha1 ids, with their corresponding url (as paired in `orgs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("origin_get") def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: """Retrieve the associated url for each origin sha1 in `ids`.""" ... @remote_api_endpoint("revision_add") def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: """Add revisions identified by sha1 ids, with optional associated date or origin (as paired in `revs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("revision_get") def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: """Retrieve the associated date and origin for each revision sha1 in `ids`. If some revision has no associated date nor origin, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("relation_add") def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: """Add entries in the selected `relation`. This method assumes all entities being related are already registered in the storage. See `content_add`, `directory_add`, `origin_add`, and `revision_add`. """ ... 
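# For illustration (hypothetical values, not part of this diff): the mapping
# passed to `relation_add` pairs each source sha1 with the set of
# `RelationData` destinations, e.g. relating a content blob to a revision
# under a given path:
#
#     data = {cnt_sha1: {RelationData(dst=rev_sha1, path=b"src/main.c")}}
#     storage.relation_add(RelationType.CNT_EARLY_IN_REV, data)
#
# where `cnt_sha1` and `rev_sha1` are Sha1Git ids previously registered via
# `content_add` and `revision_add`.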
@remote_api_endpoint("relation_get") def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` whose source entities are identified by some sha1 id in `ids`. If `reverse` is set, destination entities are matched instead. """ ... @remote_api_endpoint("relation_get_all") def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` that are present in the provenance model. This method is used only in tests. """ ... @remote_api_endpoint("with_path") def with_path(self) -> bool: ... @runtime_checkable class ProvenanceInterface(Protocol): storage: ProvenanceStorageInterface + def __enter__(self) -> ProvenanceInterface: + ... + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + ... + + def close(self) -> None: + """Close connection to the underlying `storage` and release resources.""" + ... + def flush(self) -> None: """Flush internal cache to the underlying `storage`.""" ... def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes ) -> None: """Associate `blob` with `directory` in the provenance model. `prefix` is the relative path from `directory` to `blob` (excluding `blob`'s name). """ ... def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes ) -> None: """Associate `blob` with `revision` in the provenance model. `prefix` is the absolute path from `revision`'s root directory to `blob` (excluding `blob`'s name). """ ... def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: """Retrieve the first occurrence of the blob identified by `id`.""" ... def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: """Retrieve all the occurrences of the blob identified by `id`.""" ... def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: """Retrieve the earliest known date of `blob`.""" ... def content_get_early_dates( self, blobs: Iterable[FileEntry] ) -> Dict[Sha1Git, datetime]: """Retrieve the earliest known date for each blob in `blobs`. If some blob has no associated date, it is not present in the resulting dictionary. """ ... def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: """Associate `date` to `blob` as it's earliest known date.""" ... def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes ) -> None: """Associate `directory` with `revision` in the provenance model. `path` is the absolute path from `revision`'s root directory to `directory` (including `directory`'s name). """ ... def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: """Retrieve the earliest known date of `directory` as an isochrone frontier in the provenance model. """ ... def directory_get_dates_in_isochrone_frontier( self, dirs: Iterable[DirectoryEntry] ) -> Dict[Sha1Git, datetime]: """Retrieve the earliest known date for each directory in `dirs` as isochrone frontiers provenance model. If some directory has no associated date, it is not present in the resulting dictionary. """ ... 
def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime ) -> None: """Associate `date` to `directory` as its earliest known date as an isochrone frontier in the provenance model. """ ... + def open(self) -> None: + """Open connection to the underlying `storage` and allocate necessary + resources. + """ + ... + def origin_add(self, origin: OriginEntry) -> None: """Add `origin` to the provenance model.""" ... def revision_add(self, revision: RevisionEntry) -> None: """Add `revision` to the provenance model. This implies storing `revision`'s date in the model, thus `revision.date` must be a valid date. """ ... def revision_add_before_revision( self, head: RevisionEntry, revision: RevisionEntry ) -> None: """Associate `revision` to `head` as an ancestor of the latter.""" ... def revision_add_to_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: """Associate `revision` to `origin` as a head revision of the latter (ie. the target of a snapshot for `origin` in the archive).""" ... def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]: """Retrieve the date associated to `revision`.""" ... def revision_get_preferred_origin( self, revision: RevisionEntry ) -> Optional[Sha1Git]: """Retrieve the preferred origin associated to `revision`.""" ... def revision_in_history(self, revision: RevisionEntry) -> bool: """Check if `revision` is known to be an ancestor of some head revision in the provenance model. """ ... def revision_set_preferred_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: """Associate `origin` as the preferred origin for `revision`.""" ... def revision_visited(self, revision: RevisionEntry) -> bool: """Check if `revision` is known to be a head revision for some origin in the provenance model. """ ... 
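The practical effect of the new protocol is that `get_provenance_storage` no longer connects eagerly: backends now keep their connection arguments at construction time and only connect in `open()`, which `__enter__` calls. A minimal lifecycle sketch, assuming a reachable PostgreSQL instance (connection parameters are hypothetical):

    from swh.provenance import get_provenance_storage

    db = {"host": "localhost", "dbname": "provenance", "user": "postgres"}
    with get_provenance_storage(cls="postgresql", db=db) as storage:
        # the connection is opened by __enter__/open() and released by
        # __exit__/close(), even if the body raises
        dates = storage.content_get([])  # empty query, returns {}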
diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py index 664fa45..0f4cbb5 100644 --- a/swh/provenance/mongo/backend.py +++ b/swh/provenance/mongo/backend.py @@ -1,460 +1,489 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from __future__ import annotations + from datetime import datetime, timezone import os -from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union +from types import TracebackType +from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Type, Union from bson import ObjectId -import pymongo.database +import mongomock +import pymongo from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, + ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) class ProvenanceStorageMongoDb: - def __init__(self, db: pymongo.database.Database): - self.db = db + def __init__(self, engine: str, **kwargs): + self.engine = engine + self.dbname = kwargs.pop("dbname") + self.conn_args = kwargs + + def __enter__(self) -> ProvenanceStorageInterface: + self.open() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + def close(self) -> None: + self.db.client.close() def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: data = cnts if isinstance(cnts, dict) else dict.fromkeys(cnts) existing = { x["sha1"]: x for x in self.db.content.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} ) } for sha1, date in data.items(): ts = datetime.timestamp(date) if date is not None else None if sha1 in existing: cnt = existing[sha1] if ts is not None and (cnt["ts"] is None or ts < cnt["ts"]): self.db.content.update_one( {"_id": cnt["_id"]}, {"$set": {"ts": ts}} ) else: self.db.content.insert_one( { "sha1": sha1, "ts": ts, "revision": {}, "directory": {}, } ) return True def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: # get all the revisions # iterate and find the earliest content = self.db.content.find_one({"sha1": id}) if not content: return None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) return sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path))[0] def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: content = self.db.content.find_one({"sha1": id}) if not content: return None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], 
date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) for directory in self.db.directory.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["directory"]]}} ): for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for suffix in content["directory"][str(directory["_id"])]: for prefix in directory["revision"][str(revision["_id"])]: path = ( os.path.join(prefix, suffix) if prefix not in [b".", b""] else suffix ) occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp( revision["ts"], timezone.utc ), origin=origin["url"], path=path, ) ) yield from sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path)) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.content.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } def directory_add( self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: data = dirs if isinstance(dirs, dict) else dict.fromkeys(dirs) existing = { x["sha1"]: x for x in self.db.directory.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} ) } for sha1, date in data.items(): ts = datetime.timestamp(date) if date is not None else None if sha1 in existing: dir = existing[sha1] if ts is not None and (dir["ts"] is None or ts < dir["ts"]): self.db.directory.update_one( {"_id": dir["_id"]}, {"$set": {"ts": ts}} ) else: self.db.directory.insert_one({"sha1": sha1, "ts": ts, "revision": {}}) return True def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.directory.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: return { x["sha1"] for x in self.db.get_collection(entity.value).find( {}, {"sha1": 1, "_id": 0} ) } def location_add(self, paths: Iterable[bytes]) -> bool: # TODO: implement this methods if path are to be stored in a separate collection return True def location_get_all(self) -> Set[bytes]: contents = self.db.content.find({}, {"revision": 1, "_id": 0, "directory": 1}) paths: List[Iterable[bytes]] = [] for content in contents: paths.extend(value for _, value in content["revision"].items()) paths.extend(value for _, value in content["directory"].items()) dirs = self.db.directory.find({}, {"revision": 1, "_id": 0}) for each_dir in dirs: paths.extend(value for _, value in each_dir["revision"].items()) return set(sum(paths, [])) + def open(self) -> None: + if self.engine == "mongomock": + self.db = mongomock.MongoClient(**self.conn_args).get_database(self.dbname) + else: + # assume real MongoDB server by default + self.db = pymongo.MongoClient(**self.conn_args).get_database(self.dbname) + def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: existing = { x["sha1"]: x for x in self.db.origin.find( {"sha1": {"$in": list(orgs)}}, {"sha1": 1, "url": 1, "_id": 1} ) } for sha1, url in orgs.items(): if sha1 not in existing: # add new origin self.db.origin.insert_one({"sha1": sha1, "url": url}) return True def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: return { x["sha1"]: x["url"] 
for x in self.db.origin.find( {"sha1": {"$in": list(ids)}}, {"sha1": 1, "url": 1, "_id": 0} ) } def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: data = ( revs if isinstance(revs, dict) else dict.fromkeys(revs, RevisionData(date=None, origin=None)) ) existing = { x["sha1"]: x for x in self.db.revision.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "preferred": 1, "_id": 1}, ) } for sha1, info in data.items(): ts = datetime.timestamp(info.date) if info.date is not None else None preferred = info.origin if sha1 in existing: rev = existing[sha1] if ts is None or (rev["ts"] is not None and ts >= rev["ts"]): ts = rev["ts"] if preferred is None: preferred = rev["preferred"] if ts != rev["ts"] or preferred != rev["preferred"]: self.db.revision.update_one( {"_id": rev["_id"]}, {"$set": {"ts": ts, "preferred": preferred}}, ) else: self.db.revision.insert_one( { "sha1": sha1, "preferred": preferred, "origin": [], "revision": [], "ts": ts, } ) return True def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: return { x["sha1"]: RevisionData( date=datetime.fromtimestamp(x["ts"], timezone.utc) if x["ts"] else None, origin=x["preferred"], ) for x in self.db.revision.find( { "sha1": {"$in": list(ids)}, "$or": [{"preferred": {"$ne": None}}, {"ts": {"$ne": None}}], }, {"sha1": 1, "preferred": 1, "ts": 1, "_id": 0}, ) } def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: src_relation, *_, dst_relation = relation.value.split("_") dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst_relation).find( { "sha1": { "$in": list({rel.dst for rels in data.values() for rel in rels}) } }, {"_id": 1, "sha1": 1}, ) } denorm: Dict[Sha1Git, Any] = {} for src, rels in data.items(): for rel in rels: if src_relation != "revision": denorm.setdefault(src, {}).setdefault( str(dst_objs[rel.dst]), [] ).append(rel.path) else: denorm.setdefault(src, []).append(dst_objs[rel.dst]) src_objs = { x["sha1"]: x for x in self.db.get_collection(src_relation).find( {"sha1": {"$in": list(denorm.keys())}} ) } for sha1, dsts in denorm.items(): # update if src_relation != "revision": k = { obj_id: list(set(paths + dsts.get(obj_id, []))) for obj_id, paths in src_objs[sha1][dst_relation].items() } self.db.get_collection(src_relation).update_one( {"_id": src_objs[sha1]["_id"]}, {"$set": {dst_relation: dict(dsts, **k)}}, ) else: self.db.get_collection(src_relation).update_one( {"_id": src_objs[sha1]["_id"]}, { "$set": { dst_relation: list(set(src_objs[sha1][dst_relation] + dsts)) } }, ) return True def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") sha1s = set(ids) if not reverse: empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {"sha1": {"$in": list(sha1s)}, dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1}, ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } if src != "revision": return { src_sha1: { RelationData(dst=dst_sha1, path=path) for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_str, paths in denorm.items() for path in paths if dst_obj_id == ObjectId(dst_obj_str) } for src_sha1, denorm in 
src_objs.items() } else: return { src_sha1: { RelationData(dst=dst_sha1, path=None) for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_ref in denorm if dst_obj_id == dst_obj_ref } for src_sha1, denorm in src_objs.items() } else: dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1} ) } src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {}, {"_id": 0, "sha1": 1, dst: 1} ) } result: Dict[Sha1Git, Set[RelationData]] = {} if src != "revision": for dst_sha1, dst_obj_id in dst_objs.items(): for src_sha1, denorm in src_objs.items(): for dst_obj_str, paths in denorm.items(): if dst_obj_id == ObjectId(dst_obj_str): result.setdefault(src_sha1, set()).update( RelationData(dst=dst_sha1, path=path) for path in paths ) else: for dst_sha1, dst_obj_id in dst_objs.items(): for src_sha1, denorm in src_objs.items(): if dst_obj_id in { ObjectId(dst_obj_str) for dst_obj_str in denorm }: result.setdefault(src_sha1, set()).add( RelationData(dst=dst_sha1, path=None) ) return result def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1} ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) dst_objs = { x["_id"]: x["sha1"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } if src != "revision": return { src_sha1: { RelationData(dst=dst_sha1, path=path) for dst_obj_id, dst_sha1 in dst_objs.items() for dst_obj_str, paths in denorm.items() for path in paths if dst_obj_id == ObjectId(dst_obj_str) } for src_sha1, denorm in src_objs.items() } else: return { src_sha1: { RelationData(dst=dst_sha1, path=None) for dst_obj_id, dst_sha1 in dst_objs.items() for dst_obj_ref in denorm if dst_obj_id == dst_obj_ref } for src_sha1, denorm in src_objs.items() } def with_path(self) -> bool: return True diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py index 358af8f..dd25bc1 100644 --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -1,322 +1,342 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from __future__ import annotations + from contextlib import contextmanager from datetime import datetime import itertools import logging -from typing import Dict, Generator, Iterable, List, Optional, Set, Union +from types import TracebackType +from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union import psycopg2.extensions import psycopg2.extras from typing_extensions import Literal from swh.core.db import BaseDb from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, + ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) LOGGER = logging.getLogger(__name__) class ProvenanceStoragePostgreSql: - def __init__( - self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False - ) -> None: - BaseDb.adapt_conn(conn) - self.conn = conn - with self.transaction() as cursor: - cursor.execute("SET timezone TO 'UTC'") + def 
__init__(self, raise_on_commit: bool = False, **kwargs) -> None: + self.conn_args = kwargs self._flavor: Optional[str] = None self.raise_on_commit = raise_on_commit + def __enter__(self) -> ProvenanceStorageInterface: + self.open() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + @contextmanager def transaction( self, readonly: bool = False ) -> Generator[psycopg2.extensions.cursor, None, None]: self.conn.set_session(readonly=readonly) with self.conn: with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: yield cur @property def flavor(self) -> str: if self._flavor is None: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT swh_get_dbflavor() AS flavor") self._flavor = cursor.fetchone()["flavor"] assert self._flavor is not None return self._flavor @property def denormalized(self) -> bool: return "denormalized" in self.flavor + def close(self) -> None: + self.conn.close() + def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: return self._entity_set_date("content", cnts) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: sql = "SELECT * FROM swh_provenance_content_find_first(%s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id,)) row = cursor.fetchone() return ProvenanceResult(**row) if row is not None else None def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id, limit)) yield from (ProvenanceResult(**row) for row in cursor) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("content", ids) def directory_add( self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: return self._entity_set_date("directory", dirs) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("directory", ids) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: with self.transaction(readonly=True) as cursor: cursor.execute(f"SELECT sha1 FROM {entity.value}") return {row["sha1"] for row in cursor} def location_add(self, paths: Iterable[bytes]) -> bool: if not self.with_path(): return True try: values = [(path,) for path in paths] if values: sql = """ INSERT INTO location(path) VALUES %s ON CONFLICT DO NOTHING """ with self.transaction() as cursor: psycopg2.extras.execute_values(cursor, sql, argslist=values) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def location_get_all(self) -> Set[bytes]: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT location.path AS path FROM location") return {row["path"] for row in cursor} def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: try: if orgs: sql = """ INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=orgs.items() ) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False + 
def open(self) -> None: + self.conn = BaseDb.connect(**self.conn_args).conn + BaseDb.adapt_conn(self.conn) + with self.transaction() as cursor: + cursor.execute("SET timezone TO 'UTC'") + def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: urls: Dict[Sha1Git, str] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, url FROM origin WHERE sha1 IN ({values}) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) urls.update((row["sha1"], row["url"]) for row in cursor) return urls def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: if isinstance(revs, dict): data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] else: data = [(sha1, None, None) for sha1 in revs] try: if data: sql = """ INSERT INTO revision(sha1, date, origin) (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin FROM (VALUES %s) AS V(rev, date, org) LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date, revision.date), origin=COALESCE(EXCLUDED.origin, revision.origin) """ with self.transaction() as cursor: psycopg2.extras.execute_values(cur=cursor, sql=sql, argslist=data) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: result: Dict[Sha1Git, RevisionData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT R.sha1, R.date, O.sha1 AS origin FROM revision AS R LEFT JOIN origin AS O ON (O.id=R.origin) WHERE R.sha1 IN ({values}) AND (R.date is not NULL OR O.sha1 is not NULL) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) result.update( (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) for row in cursor ) return result def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts] try: if rows: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") # Put the next three queries in a manual single transaction: # they use the same temp table with self.transaction() as cursor: cursor.execute("SELECT swh_mktemp_relation_add()") psycopg2.extras.execute_values( cur=cursor, sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", argslist=rows, ) sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, ids, reverse) def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, None) def _entity_get_date( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: dates: 
Dict[Sha1Git, datetime] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date FROM {entity} WHERE sha1 IN ({values}) AND date IS NOT NULL """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) dates.update((row["sha1"], row["date"]) for row in cursor) return dates def _entity_set_date( self, entity: Literal["content", "directory"], dates: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]], ) -> bool: data = dates if isinstance(dates, dict) else dict.fromkeys(dates) try: if data: sql = f""" INSERT INTO {entity}(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) """ with self.transaction() as cursor: psycopg2.extras.execute_values(cursor, sql, argslist=data.items()) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def _relation_get( self, relation: RelationType, ids: Optional[Iterable[Sha1Git]], reverse: bool = False, ) -> Dict[Sha1Git, Set[RelationData]]: result: Dict[Sha1Git, Set[RelationData]] = {} sha1s: List[Sha1Git] if ids is not None: sha1s = list(ids) filter = "filter-src" if not reverse else "filter-dst" else: sha1s = [] filter = "no-filter" if filter == "no-filter" or sha1s: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute( query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s) ) for row in cursor: src = row.pop("src") result.setdefault(src, set()).add(RelationData(**row)) return result def with_path(self) -> bool: return "with-path" in self.flavor diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py index 006419a..81a033e 100644 --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -1,392 +1,412 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime import logging import os -from typing import Dict, Generator, Iterable, Optional, Set, Tuple +from types import TracebackType +from typing import Dict, Generator, Iterable, Optional, Set, Tuple, Type from typing_extensions import Literal, TypedDict from swh.model.model import Sha1Git from .interface import ( + ProvenanceInterface, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry LOGGER = logging.getLogger(__name__) class DatetimeCache(TypedDict): data: Dict[Sha1Git, Optional[datetime]] added: Set[Sha1Git] class OriginCache(TypedDict): data: Dict[Sha1Git, str] added: Set[Sha1Git] class RevisionCache(TypedDict): data: Dict[Sha1Git, Sha1Git] added: Set[Sha1Git] class ProvenanceCache(TypedDict): content: DatetimeCache directory: DatetimeCache revision: DatetimeCache # below are insertion caches only content_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] content_in_directory: Set[Tuple[Sha1Git, Sha1Git, bytes]] directory_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] # these two are for the origin layer origin: OriginCache revision_origin: 
RevisionCache revision_before_revision: Dict[Sha1Git, Set[Sha1Git]] revision_in_origin: Set[Tuple[Sha1Git, Sha1Git]] def new_cache() -> ProvenanceCache: return ProvenanceCache( content=DatetimeCache(data={}, added=set()), directory=DatetimeCache(data={}, added=set()), revision=DatetimeCache(data={}, added=set()), content_in_revision=set(), content_in_directory=set(), directory_in_revision=set(), origin=OriginCache(data={}, added=set()), revision_origin=RevisionCache(data={}, added=set()), revision_before_revision={}, revision_in_origin=set(), ) class Provenance: def __init__(self, storage: ProvenanceStorageInterface) -> None: self.storage = storage self.cache = new_cache() + def __enter__(self) -> ProvenanceInterface: + self.open() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + def clear_caches(self) -> None: self.cache = new_cache() + def close(self) -> None: + self.storage.close() + def flush(self) -> None: # Revision-content layer insertions ############################################ # After relations, dates for the entities can be safely set, acknowledging that # these entities won't need to be reprocessed in case of failure. cnts = { src for src, _, _ in self.cache["content_in_revision"] | self.cache["content_in_directory"] } if cnts: while not self.storage.content_add(cnts): LOGGER.warning( "Unable to write content entities to the storage. Retrying..." ) dirs = {dst for _, dst, _ in self.cache["content_in_directory"]} if dirs: while not self.storage.directory_add(dirs): LOGGER.warning( "Unable to write directory entities to the storage. Retrying..." ) revs = { dst for _, dst, _ in self.cache["content_in_revision"] | self.cache["directory_in_revision"] } if revs: while not self.storage.revision_add(revs): LOGGER.warning( "Unable to write revision entities to the storage. Retrying..." ) paths = { path for _, _, path in self.cache["content_in_revision"] | self.cache["content_in_directory"] | self.cache["directory_in_revision"] } if paths: while not self.storage.location_add(paths): LOGGER.warning( "Unable to write locations entities to the storage. Retrying..." ) # For this layer, relations need to be inserted first so that, in case of # failure, reprocessing the input does not generated an inconsistent database. if self.cache["content_in_revision"]: cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} for src, dst, path in self.cache["content_in_revision"]: cnt_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path)) while not self.storage.relation_add( RelationType.CNT_EARLY_IN_REV, cnt_in_rev ): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.CNT_EARLY_IN_REV, ) if self.cache["content_in_directory"]: cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {} for src, dst, path in self.cache["content_in_directory"]: cnt_in_dir.setdefault(src, set()).add(RelationData(dst=dst, path=path)) while not self.storage.relation_add(RelationType.CNT_IN_DIR, cnt_in_dir): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.CNT_IN_DIR, ) if self.cache["directory_in_revision"]: dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {} for src, dst, path in self.cache["directory_in_revision"]: dir_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path)) while not self.storage.relation_add(RelationType.DIR_IN_REV, dir_in_rev): LOGGER.warning( "Unable to write %s rows to the storage. 
Retrying...", RelationType.DIR_IN_REV, ) # After relations, dates for the entities can be safely set, acknowledging that # these entities won't need to be reprocessed in case of failure. cnt_dates = { sha1: date for sha1, date in self.cache["content"]["data"].items() if sha1 in self.cache["content"]["added"] and date is not None } if cnt_dates: while not self.storage.content_add(cnt_dates): LOGGER.warning( "Unable to write content dates to the storage. Retrying..." ) dir_dates = { sha1: date for sha1, date in self.cache["directory"]["data"].items() if sha1 in self.cache["directory"]["added"] and date is not None } if dir_dates: while not self.storage.directory_add(dir_dates): LOGGER.warning( "Unable to write directory dates to the storage. Retrying..." ) rev_dates = { sha1: RevisionData(date=date, origin=None) for sha1, date in self.cache["revision"]["data"].items() if sha1 in self.cache["revision"]["added"] and date is not None } if rev_dates: while not self.storage.revision_add(rev_dates): LOGGER.warning( "Unable to write revision dates to the storage. Retrying..." ) # Origin-revision layer insertions ############################################# # Origins and revisions should be inserted first so that internal ids' # resolution works properly. urls = { sha1: url for sha1, url in self.cache["origin"]["data"].items() if sha1 in self.cache["origin"]["added"] } if urls: while not self.storage.origin_add(urls): LOGGER.warning( "Unable to write origins urls to the storage. Retrying..." ) rev_orgs = { # Destinations in this relation should match origins in the next one **{ src: RevisionData(date=None, origin=None) for src in self.cache["revision_before_revision"] }, **{ # This relation comes second so that non-None origins take precedence src: RevisionData(date=None, origin=org) for src, org in self.cache["revision_in_origin"] }, } if rev_orgs: while not self.storage.revision_add(rev_orgs): LOGGER.warning( "Unable to write revision entities to the storage. Retrying..." ) # Second, flat models for revisions' histories (ie. revision-before-revision). if self.cache["revision_before_revision"]: rev_before_rev = { src: {RelationData(dst=dst, path=None) for dst in dsts} for src, dsts in self.cache["revision_before_revision"].items() } while not self.storage.relation_add( RelationType.REV_BEFORE_REV, rev_before_rev ): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.REV_BEFORE_REV, ) # Heads (ie. revision-in-origin entries) should be inserted once flat models for # their histories were already added. This is to guarantee consistent results if # something needs to be reprocessed due to a failure: already inserted heads # won't get reprocessed in such a case. if self.cache["revision_in_origin"]: rev_in_org: Dict[Sha1Git, Set[RelationData]] = {} for src, dst in self.cache["revision_in_origin"]: rev_in_org.setdefault(src, set()).add(RelationData(dst=dst, path=None)) while not self.storage.relation_add(RelationType.REV_IN_ORG, rev_in_org): LOGGER.warning( "Unable to write %s rows to the storage. 
Retrying...", RelationType.REV_IN_ORG, ) # clear local cache ############################################################ self.clear_caches() def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes ) -> None: self.cache["content_in_directory"].add( (blob.id, directory.id, normalize(os.path.join(prefix, blob.name))) ) def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes ) -> None: self.cache["content_in_revision"].add( (blob.id, revision.id, normalize(os.path.join(prefix, blob.name))) ) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: return self.storage.content_find_first(id) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: yield from self.storage.content_find_all(id, limit=limit) def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: return self.get_dates("content", [blob.id]).get(blob.id) def content_get_early_dates( self, blobs: Iterable[FileEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("content", [blob.id for blob in blobs]) def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: self.cache["content"]["data"][blob.id] = date self.cache["content"]["added"].add(blob.id) def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes ) -> None: self.cache["directory_in_revision"].add( (directory.id, revision.id, normalize(path)) ) def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: return self.get_dates("directory", [directory.id]).get(directory.id) def directory_get_dates_in_isochrone_frontier( self, dirs: Iterable[DirectoryEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("directory", [directory.id for directory in dirs]) def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime ) -> None: self.cache["directory"]["data"][directory.id] = date self.cache["directory"]["added"].add(directory.id) def get_dates( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: cache = self.cache[entity] missing_ids = set(id for id in ids if id not in cache) if missing_ids: if entity == "revision": updated = { id: rev.date for id, rev in self.storage.revision_get(missing_ids).items() } else: updated = getattr(self.storage, f"{entity}_get")(missing_ids) cache["data"].update(updated) dates: Dict[Sha1Git, datetime] = {} for sha1 in ids: date = cache["data"].setdefault(sha1, None) if date is not None: dates[sha1] = date return dates + def open(self) -> None: + self.storage.open() + def origin_add(self, origin: OriginEntry) -> None: self.cache["origin"]["data"][origin.id] = origin.url self.cache["origin"]["added"].add(origin.id) def revision_add(self, revision: RevisionEntry) -> None: self.cache["revision"]["data"][revision.id] = revision.date self.cache["revision"]["added"].add(revision.id) def revision_add_before_revision( self, head: RevisionEntry, revision: RevisionEntry ) -> None: self.cache["revision_before_revision"].setdefault(revision.id, set()).add( head.id ) def revision_add_to_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: self.cache["revision_in_origin"].add((revision.id, origin.id)) def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]: return self.get_dates("revision", [revision.id]).get(revision.id) def revision_get_preferred_origin( self, revision: 
RevisionEntry ) -> Optional[Sha1Git]: cache = self.cache["revision_origin"]["data"] if revision.id not in cache: ret = self.storage.revision_get([revision.id]) if revision.id in ret: origin = ret[revision.id].origin if origin is not None: cache[revision.id] = origin return cache.get(revision.id) def revision_in_history(self, revision: RevisionEntry) -> bool: return revision.id in self.cache["revision_before_revision"] or bool( self.storage.relation_get(RelationType.REV_BEFORE_REV, [revision.id]) ) def revision_set_preferred_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: self.cache["revision_origin"]["data"][revision.id] = origin.id self.cache["revision_origin"]["added"].add(revision.id) def revision_visited(self, revision: RevisionEntry) -> bool: return revision.id in dict(self.cache["revision_in_origin"]) or bool( self.storage.relation_get(RelationType.REV_IN_ORG, [revision.id]) ) def normalize(path: bytes) -> bytes: return path[2:] if path.startswith(bytes("." + os.path.sep, "utf-8")) else path diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py index f30f34b..f58c543 100644 --- a/swh/provenance/tests/conftest.py +++ b/swh/provenance/tests/conftest.py @@ -1,128 +1,136 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timedelta, timezone from os import path -from typing import Any, Dict, Iterable +from typing import Any, Dict, Generator, Iterable from _pytest.fixtures import SubRequest +import mongomock.database import msgpack import psycopg2.extensions -import pymongo.database import pytest from pytest_postgresql.factories import postgresql from swh.journal.serializers import msgpack_ext_hook from swh.provenance import get_provenance, get_provenance_storage from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface from swh.provenance.storage.archive import ArchiveStorage from swh.storage.interface import StorageInterface from swh.storage.replay import process_replay_objects @pytest.fixture( params=[ "with-path", "without-path", "with-path-denormalized", "without-path-denormalized", ] ) def provenance_postgresqldb( request: SubRequest, postgresql: psycopg2.extensions.connection, ) -> Dict[str, str]: """return a working and initialized provenance db""" from swh.core.cli.db import populate_database_for_package populate_database_for_package( "swh.provenance", postgresql.dsn, flavor=request.param ) return postgresql.get_dsn_parameters() @pytest.fixture(params=["mongodb", "postgresql"]) def provenance_storage( request: SubRequest, provenance_postgresqldb: Dict[str, str], - mongodb: pymongo.database.Database, -) -> ProvenanceStorageInterface: + mongodb: mongomock.database.Database, +) -> Generator[ProvenanceStorageInterface, None, None]: """Return a working and initialized ProvenanceStorageInterface object""" if request.param == "mongodb": - from swh.provenance.mongo.backend import ProvenanceStorageMongoDb - - return ProvenanceStorageMongoDb(mongodb) + mongodb_params = { + "host": mongodb.client.address[0], + "port": mongodb.client.address[1], + "dbname": mongodb.name, + } + with get_provenance_storage( + cls=request.param, db=mongodb_params, engine="mongomock" + ) as storage: + yield storage else: # in test sessions, we DO want to raise any 
exception occurring at commit time - return get_provenance_storage( + with get_provenance_storage( cls=request.param, db=provenance_postgresqldb, raise_on_commit=True - ) + ) as storage: + yield storage provenance_postgresql = postgresql("postgresql_proc", dbname="provenance_tests") @pytest.fixture def provenance( provenance_postgresql: psycopg2.extensions.connection, -) -> ProvenanceInterface: +) -> Generator[ProvenanceInterface, None, None]: """Return a working and initialized ProvenanceInterface object""" from swh.core.cli.db import populate_database_for_package populate_database_for_package( "swh.provenance", provenance_postgresql.dsn, flavor="with-path" ) # in test sessions, we DO want to raise any exception occurring at commit time - return get_provenance( + with get_provenance( cls="postgresql", db=provenance_postgresql.get_dsn_parameters(), raise_on_commit=True, - ) + ) as provenance: + yield provenance @pytest.fixture def archive(swh_storage: StorageInterface) -> ArchiveInterface: """Return an ArchiveStorage-based ArchiveInterface object""" return ArchiveStorage(swh_storage) def get_datafile(fname: str) -> str: return path.join(path.dirname(__file__), "data", fname) def load_repo_data(repo: str) -> Dict[str, Any]: data: Dict[str, Any] = {} with open(get_datafile(f"{repo}.msgpack"), "rb") as fobj: unpacker = msgpack.Unpacker( fobj, raw=False, ext_hook=msgpack_ext_hook, strict_map_key=False, timestamp=3, # convert Timestamp in datetime objects (tz UTC) ) for objtype, objd in unpacker: data.setdefault(objtype, []).append(objd) return data def filter_dict(d: Dict[Any, Any], keys: Iterable[Any]) -> Dict[Any, Any]: return {k: v for (k, v) in d.items() if k in keys} def fill_storage(storage: StorageInterface, data: Dict[str, Any]) -> None: process_replay_objects(data, storage=storage) # TODO: remove this function in favour of TimestampWithTimezone.to_datetime # from swh.model.model def ts2dt(ts: Dict[str, Any]) -> datetime: timestamp = datetime.fromtimestamp( ts["timestamp"]["seconds"], timezone(timedelta(minutes=ts["offset"])) ) return timestamp.replace(microsecond=ts["timestamp"]["microseconds"])
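The `engine` keyword introduced in this change is what lets the fixture above exercise the MongoDB backend against mongomock instead of a live server. A hedged sketch of the same pattern outside pytest (host, port, and dbname are hypothetical):

    from swh.provenance import get_provenance_storage

    mongo_params = {"host": "localhost", "port": 27017, "dbname": "provenance"}
    with get_provenance_storage(
        cls="mongodb", db=mongo_params, engine="mongomock"
    ) as storage:
        # mongomock keeps collections in memory, so no MongoDB server is needed
        sha1 = bytes(20)  # placeholder Sha1Git value
        storage.origin_add({sha1: "https://example.com/repo"})
        assert storage.origin_get([sha1]) == {sha1: "https://example.com/repo"}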