diff --git a/requirements.txt b/requirements.txt
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@
 click
 iso8601
 methodtools
+mongomock
 pymongo
 PyYAML
 types-click
diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -72,8 +72,6 @@
         :cls:`ValueError` if passed an unknown archive class.
     """
     if cls in ["local", "postgresql"]:
-        from swh.core.db import BaseDb
-
         from .postgresql.provenance import ProvenanceStoragePostgreSql
 
         if cls == "local":
@@ -83,17 +81,17 @@
                 DeprecationWarning,
             )
 
-        conn = BaseDb.connect(**kwargs["db"]).conn
         raise_on_commit = kwargs.get("raise_on_commit", False)
-        return ProvenanceStoragePostgreSql(conn, raise_on_commit)
+        # the storage only records these settings; it connects in open()/__enter__
+        return ProvenanceStoragePostgreSql(
+            raise_on_commit=raise_on_commit, **kwargs["db"]
+        )
 
     elif cls == "mongodb":
-        from pymongo import MongoClient
-
         from .mongo.backend import ProvenanceStorageMongoDb
 
-        dbname = kwargs["db"].pop("dbname")
-        db = MongoClient(**kwargs["db"]).get_database(dbname)
-        return ProvenanceStorageMongoDb(db)
+        # "engine" selects the client library: "pymongo" (default) or "mongomock"
+        engine = kwargs.get("engine", "pymongo")
+        return ProvenanceStorageMongoDb(engine=engine, **kwargs["db"])
 
     raise ValueError
diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py
--- a/swh/provenance/cli.py
+++ b/swh/provenance/cli.py
@@ -145,19 +145,20 @@
     from .revision import CSVRevisionIterator, revision_add
 
     archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
-    provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
     revisions_provider = generate_revision_tuples(filename)
     revisions = CSVRevisionIterator(revisions_provider, limit=limit)
 
-    for revision in revisions:
-        revision_add(
-            provenance,
-            archive,
-            [revision],
-            trackall=track_all,
-            lower=reuse,
-            mindepth=min_depth,
-        )
+    # the `with` block opens the provenance storage and closes it on exit
+    with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+        for revision in revisions:
+            revision_add(
+                provenance,
+                archive,
+                [revision],
+                trackall=track_all,
+                lower=reuse,
+                mindepth=min_depth,
+            )
 
 
 def generate_revision_tuples(
@@ -183,12 +184,13 @@
     from .origin import CSVOriginIterator, origin_add
 
     archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
-    provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
     origins_provider = generate_origin_tuples(filename)
     origins = CSVOriginIterator(origins_provider, limit=limit)
 
-    for origin in origins:
-        origin_add(provenance, archive, [origin])
+    # the `with` block opens the provenance storage and closes it on exit
+    with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+        for origin in origins:
+            origin_add(provenance, archive, [origin])
 
 
 def generate_origin_tuples(filename: str) -> Generator[Tuple[str, bytes], None, None]:
@@ -205,18 +207,19 @@
     """Find first occurrence of the requested blob."""
     from . import get_provenance
 
-    provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
-    occur = provenance.content_find_first(hash_to_bytes(swhid))
-    if occur is not None:
-        print(
-            f"swh:1:cnt:{hash_to_hex(occur.content)}, "
-            f"swh:1:rev:{hash_to_hex(occur.revision)}, "
-            f"{occur.date}, "
-            f"{occur.origin}, "
-            f"{os.fsdecode(occur.path)}"
-        )
-    else:
-        print(f"Cannot find a content with the id {swhid}")
+    # the `with` block opens the provenance storage and closes it on exit
+    with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+        occur = provenance.content_find_first(hash_to_bytes(swhid))
+        if occur is not None:
+            print(
+                f"swh:1:cnt:{hash_to_hex(occur.content)}, "
+                f"swh:1:rev:{hash_to_hex(occur.revision)}, "
+                f"{occur.date}, "
+                f"{occur.origin}, "
+                f"{os.fsdecode(occur.path)}"
+            )
+        else:
+            print(f"Cannot find a content with the id {swhid}")
 
 
 @cli.command(name="find-all")
@@ -227,12 +230,13 @@
     """Find all occurrences of the requested blob."""
     from . import get_provenance
 
-    provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
-    for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit):
-        print(
-            f"swh:1:cnt:{hash_to_hex(occur.content)}, "
-            f"swh:1:rev:{hash_to_hex(occur.revision)}, "
-            f"{occur.date}, "
-            f"{occur.origin}, "
-            f"{os.fsdecode(occur.path)}"
-        )
+    # the `with` block opens the provenance storage and closes it on exit
+    with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+        for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit):
+            print(
+                f"swh:1:cnt:{hash_to_hex(occur.content)}, "
+                f"swh:1:rev:{hash_to_hex(occur.revision)}, "
+                f"{occur.date}, "
+                f"{occur.origin}, "
+                f"{os.fsdecode(occur.path)}"
+            )
diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py
--- a/swh/provenance/interface.py
+++ b/swh/provenance/interface.py
@@ -3,10 +3,13 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from __future__ import annotations
+
 from dataclasses import dataclass
 from datetime import datetime
 import enum
-from typing import Dict, Generator, Iterable, Optional, Set, Union
+from types import TracebackType
+from typing import Dict, Generator, Iterable, Optional, Set, Type, Union
 
 from typing_extensions import Protocol, runtime_checkable
 
@@ -65,6 +68,24 @@
 
 @runtime_checkable
 class ProvenanceStorageInterface(Protocol):
+    def __enter__(self) -> ProvenanceStorageInterface:
+        """Open the storage connection and return the storage itself."""
+        ...
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        """Close the storage connection on leaving the `with` block."""
+        ...
+
+    @remote_api_endpoint("close")
+    def close(self) -> None:
+        """Close connection to the storage and release resources."""
+        ...
+
     @remote_api_endpoint("content_add")
     def content_add(
         self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
@@ -129,6 +150,11 @@
         This method is used only in tests."""
         ...
 
+    @remote_api_endpoint("open")
+    def open(self) -> None:
+        """Open connection to the storage and allocate necessary resources."""
+        ...
+
     @remote_api_endpoint("origin_add")
     def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
         """Add origins identified by sha1 ids, with their corresponding url (as paired
@@ -198,6 +224,23 @@
 class ProvenanceInterface(Protocol):
     storage: ProvenanceStorageInterface
 
+    def __enter__(self) -> ProvenanceInterface:
+        """Open the underlying `storage` and return this object."""
+        ...
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        """Close the underlying `storage` on leaving the `with` block."""
+        ...
+
+    def close(self) -> None:
+        """Close connection to the underlying `storage` and release resources."""
+        ...
+
     def flush(self) -> None:
         """Flush internal cache to the underlying `storage`."""
         ...
@@ -279,6 +322,12 @@
         """
         ...
 
+    def open(self) -> None:
+        """Open connection to the underlying `storage` and allocate necessary
+        resources.
+        """
+        ...
+
     def origin_add(self, origin: OriginEntry) -> None:
         """Add `origin` to the provenance model."""
         ...
diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py
--- a/swh/provenance/mongo/backend.py
+++ b/swh/provenance/mongo/backend.py
@@ -3,18 +3,22 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from __future__ import annotations
+
 from datetime import datetime, timezone
 import os
-from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union
+from types import TracebackType
+from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Type, Union
 
 from bson import ObjectId
-import pymongo.database
+import pymongo
 
 from swh.model.model import Sha1Git
 
 from ..interface import (
     EntityType,
     ProvenanceResult,
+    ProvenanceStorageInterface,
     RelationData,
     RelationType,
     RevisionData,
@@ -22,8 +26,37 @@
 
 
 class ProvenanceStorageMongoDb:
-    def __init__(self, db: pymongo.database.Database):
-        self.db = db
+    def __init__(self, engine: str, **kwargs):
+        """Record connection settings; no connection is made until `open()`.
+
+        `engine` selects the client library used by `open()`: "mongomock" for
+        the in-memory mock, anything else for a real server through pymongo.
+        `kwargs` must contain "dbname"; every other entry is passed verbatim to
+        the client constructor.
+        """
+        self.engine = engine
+        self.dbname = kwargs.pop("dbname")
+        self.conn_args = kwargs
+
+    def __enter__(self) -> ProvenanceStorageInterface:
+        """Open the connection and return this storage."""
+        self.open()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        """Close the connection; exceptions from the block are not suppressed."""
+        self.close()
+
+    def close(self) -> None:
+        """Close the client connection; a no-op if `open()` was never called."""
+        db = getattr(self, "db", None)
+        if db is not None:
+            db.client.close()
 
     def content_add(
         self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
@@ -203,6 +236,20 @@
             paths.extend(value for _, value in each_dir["revision"].items())
         return set(sum(paths, []))
 
+    def open(self) -> None:
+        """Create the client and select the configured database.
+
+        `mongomock` is imported only when that engine is requested, so the
+        mock library is not needed to talk to a real MongoDB server.
+        """
+        if self.engine == "mongomock":
+            import mongomock
+
+            self.db = mongomock.MongoClient(**self.conn_args).get_database(self.dbname)
+        else:
+            # assume real MongoDB server by default
+            self.db = pymongo.MongoClient(**self.conn_args).get_database(self.dbname)
+
     def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
         existing = {
             x["sha1"]: x
diff --git a/swh/provenance/postgresql/provenance.py
b/swh/provenance/postgresql/provenance.py
--- a/swh/provenance/postgresql/provenance.py
+++ b/swh/provenance/postgresql/provenance.py
@@ -3,11 +3,14 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from __future__ import annotations
+
 from contextlib import contextmanager
 from datetime import datetime
 import itertools
 import logging
-from typing import Dict, Generator, Iterable, List, Optional, Set, Union
+from types import TracebackType
+from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union
 
 import psycopg2.extensions
 import psycopg2.extras
@@ -19,6 +22,7 @@
 from ..interface import (
     EntityType,
     ProvenanceResult,
+    ProvenanceStorageInterface,
     RelationData,
     RelationType,
     RevisionData,
@@ -28,16 +32,29 @@
 
 
 class ProvenanceStoragePostgreSql:
-    def __init__(
-        self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False
-    ) -> None:
-        BaseDb.adapt_conn(conn)
-        self.conn = conn
-        with self.transaction() as cursor:
-            cursor.execute("SET timezone TO 'UTC'")
+    def __init__(self, raise_on_commit: bool = False, **kwargs) -> None:
+        """Record connection settings; no connection is made until `open()`.
+
+        `kwargs` are the connection parameters later handed to `BaseDb.connect`.
+        """
+        self.conn_args = kwargs
         self._flavor: Optional[str] = None
         self.raise_on_commit = raise_on_commit
 
+    def __enter__(self) -> ProvenanceStorageInterface:
+        """Open the connection and return this storage."""
+        self.open()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        """Close the connection; exceptions from the block are not suppressed."""
+        self.close()
+
     @contextmanager
     def transaction(
         self, readonly: bool = False
@@ -60,6 +77,12 @@
     def denormalized(self) -> bool:
         return "denormalized" in self.flavor
 
+    def close(self) -> None:
+        """Close the database connection; a no-op if `open()` was never called."""
+        conn = getattr(self, "conn", None)
+        if conn is not None:
+            conn.close()
+
     def content_add(
         self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
     ) -> bool:
@@ -140,6 +163,22 @@
                 raise
         return False
 
+    def open(self) -> None:
+        """Connect to the database and prepare the session.
+
+        The connection is adapted with `BaseDb.adapt_conn` and its timezone is
+        set to UTC. If that setup fails the connection is closed before the
+        error propagates, so a failed `open()` does not leak it.
+        """
+        self.conn = BaseDb.connect(**self.conn_args).conn
+        try:
+            BaseDb.adapt_conn(self.conn)
+            with self.transaction() as cursor:
+                cursor.execute("SET timezone TO 'UTC'")
+        except BaseException:
+            self.conn.close()
+            raise
+
     def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
         urls: Dict[Sha1Git, str] = {}
         sha1s = tuple(ids)
diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py
--- a/swh/provenance/provenance.py
+++ b/swh/provenance/provenance.py
@@ -6,13 +6,15 @@
 from datetime import datetime
 import logging
 import os
-from typing import Dict, Generator, Iterable, Optional, Set, Tuple
+from types import TracebackType
+from typing import Dict, Generator, Iterable, Optional, Set, Tuple, Type
 
 from typing_extensions import Literal, TypedDict
 
 from swh.model.model import Sha1Git
 
 from .interface import (
+    ProvenanceInterface,
     ProvenanceResult,
     ProvenanceStorageInterface,
     RelationData,
@@ -74,9 +76,27 @@
         self.storage = storage
         self.cache = new_cache()
 
+    def __enter__(self) -> ProvenanceInterface:
+        """Open the underlying storage and return this object."""
+        self.open()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        """Close the underlying storage; exceptions are not suppressed."""
+        self.close()
+
     def clear_caches(self) -> None:
         self.cache = new_cache()
 
+    def close(self) -> None:
+        """Close the underlying storage. The cache is not flushed here."""
+        self.storage.close()
+
     def flush(self) -> None:
         # Revision-content layer insertions ############################################
 
@@ -336,6 +356,10 @@
                 dates[sha1] = date
         return dates
 
+    def open(self) -> None:
+        """Open the connection of the underlying storage."""
+        self.storage.open()
+
     def origin_add(self, origin: OriginEntry) -> None:
         self.cache["origin"]["data"][origin.id] = origin.url
         self.cache["origin"]["added"].add(origin.id)
diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -5,12 +5,12 @@
 
 from datetime import datetime, timedelta, timezone
 from os import path
-from typing import Any, Dict, Iterable
+from typing import Any, Dict, Generator, Iterable
 
 from _pytest.fixtures import SubRequest
+import mongomock.database
 import msgpack
 import psycopg2.extensions
-import pymongo.database
 import pytest
 from pytest_postgresql.factories import postgresql
 
@@ -48,20 +48,27 @@
 def provenance_storage(
     request: SubRequest,
     provenance_postgresqldb: Dict[str, str],
-    mongodb: pymongo.database.Database,
-) -> ProvenanceStorageInterface:
+    mongodb: mongomock.database.Database,
+) -> Generator[ProvenanceStorageInterface, None, None]:
     """Return a working and initialized ProvenanceStorageInterface object"""
 
     if request.param == "mongodb":
-        from swh.provenance.mongo.backend import ProvenanceStorageMongoDb
-
-        return ProvenanceStorageMongoDb(mongodb)
+        mongodb_params = {
+            "host": mongodb.client.address[0],
+            "port": mongodb.client.address[1],
+            "dbname": mongodb.name,
+        }
+        with get_provenance_storage(
+            cls=request.param, db=mongodb_params, engine="mongomock"
+        ) as storage:
+            yield storage
 
     else:
         # in test sessions, we DO want to raise any exception occurring at commit time
-        return get_provenance_storage(
+        with get_provenance_storage(
             cls=request.param, db=provenance_postgresqldb, raise_on_commit=True
-        )
+        ) as storage:
+            yield storage
 
 
 provenance_postgresql = postgresql("postgresql_proc", dbname="provenance_tests")
@@ -70,7 +77,7 @@
 @pytest.fixture
 def provenance(
     provenance_postgresql: psycopg2.extensions.connection,
-) -> ProvenanceInterface:
+) -> Generator[ProvenanceInterface, None, None]:
     """Return a working and initialized ProvenanceInterface object"""
 
     from swh.core.cli.db import populate_database_for_package
@@ -79,11 +86,12 @@
         "swh.provenance", provenance_postgresql.dsn, flavor="with-path"
     )
     # in test sessions, we DO want to raise any exception occurring at commit time
-    return get_provenance(
+    with get_provenance(
         cls="postgresql",
         db=provenance_postgresql.get_dsn_parameters(),
         raise_on_commit=True,
-    )
+    ) as provenance:
+        yield provenance
 
 
 @pytest.fixture