diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ click iso8601 methodtools +mongomock pymongo PyYAML types-click diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -72,8 +72,6 @@ :cls:`ValueError` if passed an unknown archive class. """ if cls in ["local", "postgresql"]: - from swh.core.db import BaseDb - from .postgresql.provenance import ProvenanceStoragePostgreSql if cls == "local": @@ -83,18 +81,16 @@ DeprecationWarning, ) - conn = BaseDb.connect(**kwargs["db"]).conn raise_on_commit = kwargs.get("raise_on_commit", False) - return ProvenanceStoragePostgreSql(conn, raise_on_commit) + return ProvenanceStoragePostgreSql( + raise_on_commit=raise_on_commit, **kwargs["db"] + ) elif cls == "mongodb": - from pymongo import MongoClient - from .mongo.backend import ProvenanceStorageMongoDb - dbname = kwargs["db"].pop("dbname") - db = MongoClient(**kwargs["db"]).get_database(dbname) - return ProvenanceStorageMongoDb(db) + engine = kwargs.get("engine", "pymongo") + return ProvenanceStorageMongoDb(engine=engine, **kwargs["db"]) elif cls in ["remote", "rpcapi"]: from .api.client import ProvenanceStorageRPCClient diff --git a/swh/provenance/api/server.py b/swh/provenance/api/server.py --- a/swh/provenance/api/server.py +++ b/swh/provenance/api/server.py @@ -23,6 +23,7 @@ global storage if storage is None: storage = get_provenance_storage(**app.config["provenance"]["storage"]) + storage.open() # XXX: nobody is closing this storage! return storage diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py --- a/swh/provenance/cli.py +++ b/swh/provenance/cli.py @@ -152,6 +152,7 @@ revisions_provider = generate_revision_tuples(filename) revisions = CSVRevisionIterator(revisions_provider, limit=limit) + provenance.open() for revision in revisions: revision_add( provenance, @@ -161,6 +162,7 @@ lower=reuse, mindepth=min_depth, ) + provenance.close() def generate_revision_tuples( @@ -190,8 +192,10 @@ origins_provider = generate_origin_tuples(filename) origins = CSVOriginIterator(origins_provider, limit=limit) + provenance.open() for origin in origins: origin_add(provenance, archive, [origin]) + provenance.close() def generate_origin_tuples(filename: str) -> Generator[Tuple[str, bytes], None, None]: @@ -209,6 +213,8 @@ from . import get_provenance provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) + + provenance.open() occur = provenance.content_find_first(hash_to_bytes(swhid)) if occur is not None: print( @@ -220,6 +226,7 @@ ) else: print(f"Cannot find a content with the id {swhid}") + provenance.close() @cli.command(name="find-all") @@ -231,6 +238,8 @@ from . import get_provenance provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"]) + + provenance.open() for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit): print( f"swh:1:cnt:{hash_to_hex(occur.content)}, " @@ -239,3 +248,4 @@ f"{occur.origin}, " f"{os.fsdecode(occur.path)}" ) + provenance.close() diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py --- a/swh/provenance/interface.py +++ b/swh/provenance/interface.py @@ -65,6 +65,11 @@ @runtime_checkable class ProvenanceStorageInterface(Protocol): + @remote_api_endpoint("close") + def close(self) -> None: + """Close connection to the storage and release resources.""" + ... + @remote_api_endpoint("content_add") def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] @@ -129,6 +134,11 @@ This method is used only in tests.""" ... + @remote_api_endpoint("open") + def open(self) -> None: + """Open connection to the storage and allocate necessary resources.""" + ... + @remote_api_endpoint("origin_add") def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: """Add origins identified by sha1 ids, with their corresponding url (as paired @@ -198,6 +208,10 @@ class ProvenanceInterface(Protocol): storage: ProvenanceStorageInterface + def close(self) -> None: + """Close connection to the underlying `storage` and release resources.""" + ... + def flush(self) -> None: """Flush internal cache to the underlying `storage`.""" ... @@ -279,6 +293,12 @@ """ ... + def open(self) -> None: + """Open connection to the underlying `storage` and allocate necessary + resources. + """ + ... + def origin_add(self, origin: OriginEntry) -> None: """Add `origin` to the provenance model.""" ... diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py --- a/swh/provenance/mongo/backend.py +++ b/swh/provenance/mongo/backend.py @@ -8,7 +8,8 @@ from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union from bson import ObjectId -import pymongo.database +import mongomock +import pymongo from swh.model.model import Sha1Git @@ -22,8 +23,13 @@ class ProvenanceStorageMongoDb: - def __init__(self, db: pymongo.database.Database): - self.db = db + def __init__(self, engine: str, **kwargs): + self.engine = engine + self.dbname = kwargs.pop("dbname") + self.conn_args = kwargs + + def close(self) -> None: + self.db.client.close() def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] @@ -203,6 +209,13 @@ paths.extend(value for _, value in each_dir["revision"].items()) return set(sum(paths, [])) + def open(self) -> None: + if self.engine == "mongomock": + self.db = mongomock.MongoClient(**self.conn_args).get_database(self.dbname) + else: + # assume real MongoDB server by default + self.db = pymongo.MongoClient(**self.conn_args).get_database(self.dbname) + def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: existing = { x["sha1"]: x diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -28,13 +28,8 @@ class ProvenanceStoragePostgreSql: - def __init__( - self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False - ) -> None: - BaseDb.adapt_conn(conn) - self.conn = conn - with self.transaction() as cursor: - cursor.execute("SET timezone TO 'UTC'") + def __init__(self, raise_on_commit: bool = False, **kwargs) -> None: + self.conn_args = kwargs self._flavor: Optional[str] = None self.raise_on_commit = raise_on_commit @@ -60,6 +55,9 @@ def denormalized(self) -> bool: return "denormalized" in self.flavor + def close(self) -> None: + self.conn.close() + def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: @@ -140,6 +138,12 @@ raise return False + def open(self) -> None: + self.conn = BaseDb.connect(**self.conn_args).conn + BaseDb.adapt_conn(self.conn) + with self.transaction() as cursor: + cursor.execute("SET timezone TO 'UTC'") + def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: urls: Dict[Sha1Git, str] = {} sha1s = tuple(ids) diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -77,6 +77,9 @@ def clear_caches(self) -> None: self.cache = new_cache() + def close(self) -> None: + self.storage.close() + def flush(self) -> None: # Revision-content layer insertions ############################################ @@ -336,6 +339,9 @@ dates[sha1] = date return dates + def open(self) -> None: + self.storage.open() + def origin_add(self, origin: OriginEntry) -> None: self.cache["origin"]["data"][origin.id] = origin.url self.cache["origin"]["added"].add(origin.id) diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py --- a/swh/provenance/tests/conftest.py +++ b/swh/provenance/tests/conftest.py @@ -5,12 +5,12 @@ from datetime import datetime, timedelta, timezone from os import path -from typing import Any, Dict, Iterable, Iterator +from typing import Any, Dict, Generator, Iterable, Iterator from _pytest.fixtures import SubRequest +import mongomock.database import msgpack import psycopg2.extensions -import pymongo.database import pytest from pytest_postgresql.factories import postgresql @@ -55,7 +55,9 @@ server.storage = get_provenance_storage( cls="postgresql", db=provenance_postgresqldb ) + server.storage.open() yield server.app + server.storage.close() # the RPCClient class used as client used in these tests @@ -68,25 +70,38 @@ def provenance_storage( request: SubRequest, provenance_postgresqldb: Dict[str, str], - mongodb: pymongo.database.Database, + mongodb: mongomock.database.Database, swh_rpc_client: ProvenanceStorageRPCClient, -) -> ProvenanceStorageInterface: +) -> Generator[ProvenanceStorageInterface, None, None]: """Return a working and initialized ProvenanceStorageInterface object""" if request.param == "rpcapi": assert isinstance(swh_rpc_client, ProvenanceStorageInterface) - return swh_rpc_client + swh_rpc_client.open() + yield swh_rpc_client + swh_rpc_client.close() elif request.param == "mongodb": - from swh.provenance.mongo.backend import ProvenanceStorageMongoDb - - return ProvenanceStorageMongoDb(mongodb) + mongodb_params = { + "host": mongodb.client.address[0], + "port": mongodb.client.address[1], + "dbname": mongodb.name, + } + mongodb_storage = get_provenance_storage( + cls=request.param, db=mongodb_params, engine="mongomock" + ) + mongodb_storage.open() + yield mongodb_storage + mongodb_storage.close() else: # in test sessions, we DO want to raise any exception occurring at commit time - return get_provenance_storage( + storage = get_provenance_storage( cls=request.param, db=provenance_postgresqldb, raise_on_commit=True ) + storage.open() + yield storage + storage.close() provenance_postgresql = postgresql("postgresql_proc", dbname="provenance_tests") @@ -95,7 +110,7 @@ @pytest.fixture def provenance( provenance_postgresql: psycopg2.extensions.connection, -) -> ProvenanceInterface: +) -> Generator[ProvenanceInterface, None, None]: """Return a working and initialized ProvenanceInterface object""" from swh.core.cli.db import populate_database_for_package @@ -104,11 +119,14 @@ "swh.provenance", provenance_postgresql.dsn, flavor="with-path" ) # in test sessions, we DO want to raise any exception occurring at commit time - return get_provenance( + provenance = get_provenance( cls="postgresql", db=provenance_postgresql.get_dsn_parameters(), raise_on_commit=True, ) + provenance.open() + yield provenance + provenance.close() @pytest.fixture