Page Menu · Home · Software Heritage

D6334.id23268.diff
No One · Temporary

D6334.id23268.diff

diff --git a/requirements.txt b/requirements.txt
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@
click
iso8601
methodtools
+mongomock
pymongo
PyYAML
types-click
diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -72,8 +72,6 @@
:cls:`ValueError` if passed an unknown archive class.
"""
if cls in ["local", "postgresql"]:
- from swh.core.db import BaseDb
-
from .postgresql.provenance import ProvenanceStoragePostgreSql
if cls == "local":
@@ -83,17 +81,15 @@
DeprecationWarning,
)
- conn = BaseDb.connect(**kwargs["db"]).conn
raise_on_commit = kwargs.get("raise_on_commit", False)
- return ProvenanceStoragePostgreSql(conn, raise_on_commit)
+ return ProvenanceStoragePostgreSql(
+ raise_on_commit=raise_on_commit, **kwargs["db"]
+ )
elif cls == "mongodb":
- from pymongo import MongoClient
-
from .mongo.backend import ProvenanceStorageMongoDb
- dbname = kwargs["db"].pop("dbname")
- db = MongoClient(**kwargs["db"]).get_database(dbname)
- return ProvenanceStorageMongoDb(db)
+ engine = kwargs.get("engine", "pymongo")
+ return ProvenanceStorageMongoDb(engine=engine, **kwargs["db"])
raise ValueError
diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py
--- a/swh/provenance/cli.py
+++ b/swh/provenance/cli.py
@@ -145,19 +145,19 @@
from .revision import CSVRevisionIterator, revision_add
archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
- provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
revisions_provider = generate_revision_tuples(filename)
revisions = CSVRevisionIterator(revisions_provider, limit=limit)
- for revision in revisions:
- revision_add(
- provenance,
- archive,
- [revision],
- trackall=track_all,
- lower=reuse,
- mindepth=min_depth,
- )
+ with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+ for revision in revisions:
+ revision_add(
+ provenance,
+ archive,
+ [revision],
+ trackall=track_all,
+ lower=reuse,
+ mindepth=min_depth,
+ )
def generate_revision_tuples(
@@ -183,12 +183,12 @@
from .origin import CSVOriginIterator, origin_add
archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
- provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
origins_provider = generate_origin_tuples(filename)
origins = CSVOriginIterator(origins_provider, limit=limit)
- for origin in origins:
- origin_add(provenance, archive, [origin])
+ with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+ for origin in origins:
+ origin_add(provenance, archive, [origin])
def generate_origin_tuples(filename: str) -> Generator[Tuple[str, bytes], None, None]:
@@ -205,18 +205,18 @@
"""Find first occurrence of the requested blob."""
from . import get_provenance
- provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
- occur = provenance.content_find_first(hash_to_bytes(swhid))
- if occur is not None:
- print(
- f"swh:1:cnt:{hash_to_hex(occur.content)}, "
- f"swh:1:rev:{hash_to_hex(occur.revision)}, "
- f"{occur.date}, "
- f"{occur.origin}, "
- f"{os.fsdecode(occur.path)}"
- )
- else:
- print(f"Cannot find a content with the id {swhid}")
+ with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+ occur = provenance.content_find_first(hash_to_bytes(swhid))
+ if occur is not None:
+ print(
+ f"swh:1:cnt:{hash_to_hex(occur.content)}, "
+ f"swh:1:rev:{hash_to_hex(occur.revision)}, "
+ f"{occur.date}, "
+ f"{occur.origin}, "
+ f"{os.fsdecode(occur.path)}"
+ )
+ else:
+ print(f"Cannot find a content with the id {swhid}")
@cli.command(name="find-all")
@@ -227,12 +227,12 @@
"""Find all occurrences of the requested blob."""
from . import get_provenance
- provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
- for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit):
- print(
- f"swh:1:cnt:{hash_to_hex(occur.content)}, "
- f"swh:1:rev:{hash_to_hex(occur.revision)}, "
- f"{occur.date}, "
- f"{occur.origin}, "
- f"{os.fsdecode(occur.path)}"
- )
+ with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+ for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit):
+ print(
+ f"swh:1:cnt:{hash_to_hex(occur.content)}, "
+ f"swh:1:rev:{hash_to_hex(occur.revision)}, "
+ f"{occur.date}, "
+ f"{occur.origin}, "
+ f"{os.fsdecode(occur.path)}"
+ )
diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py
--- a/swh/provenance/interface.py
+++ b/swh/provenance/interface.py
@@ -3,10 +3,13 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from __future__ import annotations
+
from dataclasses import dataclass
from datetime import datetime
import enum
-from typing import Dict, Generator, Iterable, Optional, Set, Union
+from types import TracebackType
+from typing import Dict, Generator, Iterable, Optional, Set, Type, Union
from typing_extensions import Protocol, runtime_checkable
@@ -65,6 +68,22 @@
@runtime_checkable
class ProvenanceStorageInterface(Protocol):
+ def __enter__(self) -> ProvenanceStorageInterface:
+ ...
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ ...
+
+ @remote_api_endpoint("close")
+ def close(self) -> None:
+ """Close connection to the storage and release resources."""
+ ...
+
@remote_api_endpoint("content_add")
def content_add(
self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
@@ -129,6 +148,11 @@
This method is used only in tests."""
...
+ @remote_api_endpoint("open")
+ def open(self) -> None:
+ """Open connection to the storage and allocate necessary resources."""
+ ...
+
@remote_api_endpoint("origin_add")
def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
"""Add origins identified by sha1 ids, with their corresponding url (as paired
@@ -198,6 +222,21 @@
class ProvenanceInterface(Protocol):
storage: ProvenanceStorageInterface
+ def __enter__(self) -> ProvenanceInterface:
+ ...
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ ...
+
+ def close(self) -> None:
+ """Close connection to the underlying `storage` and release resources."""
+ ...
+
def flush(self) -> None:
"""Flush internal cache to the underlying `storage`."""
...
@@ -279,6 +318,12 @@
"""
...
+ def open(self) -> None:
+ """Open connection to the underlying `storage` and allocate necessary
+ resources.
+ """
+ ...
+
def origin_add(self, origin: OriginEntry) -> None:
"""Add `origin` to the provenance model."""
...
diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py
--- a/swh/provenance/mongo/backend.py
+++ b/swh/provenance/mongo/backend.py
@@ -3,18 +3,23 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from __future__ import annotations
+
from datetime import datetime, timezone
import os
-from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union
+from types import TracebackType
+from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Type, Union
from bson import ObjectId
-import pymongo.database
+import mongomock
+import pymongo
from swh.model.model import Sha1Git
from ..interface import (
EntityType,
ProvenanceResult,
+ ProvenanceStorageInterface,
RelationData,
RelationType,
RevisionData,
@@ -22,8 +27,25 @@
class ProvenanceStorageMongoDb:
- def __init__(self, db: pymongo.database.Database):
- self.db = db
+ def __init__(self, engine: str, **kwargs):
+ self.engine = engine
+ self.dbname = kwargs.pop("dbname")
+ self.conn_args = kwargs
+
+ def __enter__(self) -> ProvenanceStorageInterface:
+ self.open()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ def close(self) -> None:
+ self.db.client.close()
def content_add(
self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
@@ -203,6 +225,13 @@
paths.extend(value for _, value in each_dir["revision"].items())
return set(sum(paths, []))
+ def open(self) -> None:
+ if self.engine == "mongomock":
+ self.db = mongomock.MongoClient(**self.conn_args).get_database(self.dbname)
+ else:
+ # assume real MongoDB server by default
+ self.db = pymongo.MongoClient(**self.conn_args).get_database(self.dbname)
+
def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
existing = {
x["sha1"]: x
diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py
--- a/swh/provenance/postgresql/provenance.py
+++ b/swh/provenance/postgresql/provenance.py
@@ -3,11 +3,14 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from __future__ import annotations
+
from contextlib import contextmanager
from datetime import datetime
import itertools
import logging
-from typing import Dict, Generator, Iterable, List, Optional, Set, Union
+from types import TracebackType
+from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union
import psycopg2.extensions
import psycopg2.extras
@@ -19,6 +22,7 @@
from ..interface import (
EntityType,
ProvenanceResult,
+ ProvenanceStorageInterface,
RelationData,
RelationType,
RevisionData,
@@ -28,16 +32,23 @@
class ProvenanceStoragePostgreSql:
- def __init__(
- self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False
- ) -> None:
- BaseDb.adapt_conn(conn)
- self.conn = conn
- with self.transaction() as cursor:
- cursor.execute("SET timezone TO 'UTC'")
+ def __init__(self, raise_on_commit: bool = False, **kwargs) -> None:
+ self.conn_args = kwargs
self._flavor: Optional[str] = None
self.raise_on_commit = raise_on_commit
+ def __enter__(self) -> ProvenanceStorageInterface:
+ self.open()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
@contextmanager
def transaction(
self, readonly: bool = False
@@ -60,6 +71,9 @@
def denormalized(self) -> bool:
return "denormalized" in self.flavor
+ def close(self) -> None:
+ self.conn.close()
+
def content_add(
self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
) -> bool:
@@ -140,6 +154,12 @@
raise
return False
+ def open(self) -> None:
+ self.conn = BaseDb.connect(**self.conn_args).conn
+ BaseDb.adapt_conn(self.conn)
+ with self.transaction() as cursor:
+ cursor.execute("SET timezone TO 'UTC'")
+
def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
urls: Dict[Sha1Git, str] = {}
sha1s = tuple(ids)
diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py
--- a/swh/provenance/provenance.py
+++ b/swh/provenance/provenance.py
@@ -6,13 +6,15 @@
from datetime import datetime
import logging
import os
-from typing import Dict, Generator, Iterable, Optional, Set, Tuple
+from types import TracebackType
+from typing import Dict, Generator, Iterable, Optional, Set, Tuple, Type
from typing_extensions import Literal, TypedDict
from swh.model.model import Sha1Git
from .interface import (
+ ProvenanceInterface,
ProvenanceResult,
ProvenanceStorageInterface,
RelationData,
@@ -74,9 +76,24 @@
self.storage = storage
self.cache = new_cache()
+ def __enter__(self) -> ProvenanceInterface:
+ self.open()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
def clear_caches(self) -> None:
self.cache = new_cache()
+ def close(self) -> None:
+ self.storage.close()
+
def flush(self) -> None:
# Revision-content layer insertions ############################################
@@ -336,6 +353,9 @@
dates[sha1] = date
return dates
+ def open(self) -> None:
+ self.storage.open()
+
def origin_add(self, origin: OriginEntry) -> None:
self.cache["origin"]["data"][origin.id] = origin.url
self.cache["origin"]["added"].add(origin.id)
diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -5,12 +5,12 @@
from datetime import datetime, timedelta, timezone
from os import path
-from typing import Any, Dict, Iterable
+from typing import Any, Dict, Generator, Iterable
from _pytest.fixtures import SubRequest
+import mongomock.database
import msgpack
import psycopg2.extensions
-import pymongo.database
import pytest
from pytest_postgresql.factories import postgresql
@@ -48,20 +48,27 @@
def provenance_storage(
request: SubRequest,
provenance_postgresqldb: Dict[str, str],
- mongodb: pymongo.database.Database,
-) -> ProvenanceStorageInterface:
+ mongodb: mongomock.database.Database,
+) -> Generator[ProvenanceStorageInterface, None, None]:
"""Return a working and initialized ProvenanceStorageInterface object"""
if request.param == "mongodb":
- from swh.provenance.mongo.backend import ProvenanceStorageMongoDb
-
- return ProvenanceStorageMongoDb(mongodb)
+ mongodb_params = {
+ "host": mongodb.client.address[0],
+ "port": mongodb.client.address[1],
+ "dbname": mongodb.name,
+ }
+ with get_provenance_storage(
+ cls=request.param, db=mongodb_params, engine="mongomock"
+ ) as storage:
+ yield storage
else:
# in test sessions, we DO want to raise any exception occurring at commit time
- return get_provenance_storage(
+ with get_provenance_storage(
cls=request.param, db=provenance_postgresqldb, raise_on_commit=True
- )
+ ) as storage:
+ yield storage
provenance_postgresql = postgresql("postgresql_proc", dbname="provenance_tests")
@@ -70,7 +77,7 @@
@pytest.fixture
def provenance(
provenance_postgresql: psycopg2.extensions.connection,
-) -> ProvenanceInterface:
+) -> Generator[ProvenanceInterface, None, None]:
"""Return a working and initialized ProvenanceInterface object"""
from swh.core.cli.db import populate_database_for_package
@@ -79,11 +86,12 @@
"swh.provenance", provenance_postgresql.dsn, flavor="with-path"
)
# in test sessions, we DO want to raise any exception occurring at commit time
- return get_provenance(
+ with get_provenance(
cls="postgresql",
db=provenance_postgresql.get_dsn_parameters(),
raise_on_commit=True,
- )
+ ) as provenance:
+ yield provenance
@pytest.fixture

File Metadata

Mime Type: text/plain
Expires: Tue, Dec 17, 8:24 AM (2 w, 1 d ago)
Storage Engine: blob
Storage Format: Raw Data
Storage Handle: 3217132

Event Timeline