Page Menu
Home
Software Heritage
Search
Configure Global Search
Log In
Files
F7122860
D6334.id23268.diff
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
16 KB
Subscribers
None
D6334.id23268.diff
View Options
diff --git a/requirements.txt b/requirements.txt
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@
click
iso8601
methodtools
+mongomock
pymongo
PyYAML
types-click
diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -72,8 +72,6 @@
:cls:`ValueError` if passed an unknown archive class.
"""
if cls in ["local", "postgresql"]:
- from swh.core.db import BaseDb
-
from .postgresql.provenance import ProvenanceStoragePostgreSql
if cls == "local":
@@ -83,17 +81,15 @@
DeprecationWarning,
)
- conn = BaseDb.connect(**kwargs["db"]).conn
raise_on_commit = kwargs.get("raise_on_commit", False)
- return ProvenanceStoragePostgreSql(conn, raise_on_commit)
+ return ProvenanceStoragePostgreSql(
+ raise_on_commit=raise_on_commit, **kwargs["db"]
+ )
elif cls == "mongodb":
- from pymongo import MongoClient
-
from .mongo.backend import ProvenanceStorageMongoDb
- dbname = kwargs["db"].pop("dbname")
- db = MongoClient(**kwargs["db"]).get_database(dbname)
- return ProvenanceStorageMongoDb(db)
+ engine = kwargs.get("engine", "pymongo")
+ return ProvenanceStorageMongoDb(engine=engine, **kwargs["db"])
raise ValueError
diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py
--- a/swh/provenance/cli.py
+++ b/swh/provenance/cli.py
@@ -145,19 +145,19 @@
from .revision import CSVRevisionIterator, revision_add
archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
- provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
revisions_provider = generate_revision_tuples(filename)
revisions = CSVRevisionIterator(revisions_provider, limit=limit)
- for revision in revisions:
- revision_add(
- provenance,
- archive,
- [revision],
- trackall=track_all,
- lower=reuse,
- mindepth=min_depth,
- )
+ with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+ for revision in revisions:
+ revision_add(
+ provenance,
+ archive,
+ [revision],
+ trackall=track_all,
+ lower=reuse,
+ mindepth=min_depth,
+ )
def generate_revision_tuples(
@@ -183,12 +183,12 @@
from .origin import CSVOriginIterator, origin_add
archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
- provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
origins_provider = generate_origin_tuples(filename)
origins = CSVOriginIterator(origins_provider, limit=limit)
- for origin in origins:
- origin_add(provenance, archive, [origin])
+ with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+ for origin in origins:
+ origin_add(provenance, archive, [origin])
def generate_origin_tuples(filename: str) -> Generator[Tuple[str, bytes], None, None]:
@@ -205,18 +205,18 @@
"""Find first occurrence of the requested blob."""
from . import get_provenance
- provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
- occur = provenance.content_find_first(hash_to_bytes(swhid))
- if occur is not None:
- print(
- f"swh:1:cnt:{hash_to_hex(occur.content)}, "
- f"swh:1:rev:{hash_to_hex(occur.revision)}, "
- f"{occur.date}, "
- f"{occur.origin}, "
- f"{os.fsdecode(occur.path)}"
- )
- else:
- print(f"Cannot find a content with the id {swhid}")
+ with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+ occur = provenance.content_find_first(hash_to_bytes(swhid))
+ if occur is not None:
+ print(
+ f"swh:1:cnt:{hash_to_hex(occur.content)}, "
+ f"swh:1:rev:{hash_to_hex(occur.revision)}, "
+ f"{occur.date}, "
+ f"{occur.origin}, "
+ f"{os.fsdecode(occur.path)}"
+ )
+ else:
+ print(f"Cannot find a content with the id {swhid}")
@cli.command(name="find-all")
@@ -227,12 +227,12 @@
"""Find all occurrences of the requested blob."""
from . import get_provenance
- provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
- for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit):
- print(
- f"swh:1:cnt:{hash_to_hex(occur.content)}, "
- f"swh:1:rev:{hash_to_hex(occur.revision)}, "
- f"{occur.date}, "
- f"{occur.origin}, "
- f"{os.fsdecode(occur.path)}"
- )
+ with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
+ for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit):
+ print(
+ f"swh:1:cnt:{hash_to_hex(occur.content)}, "
+ f"swh:1:rev:{hash_to_hex(occur.revision)}, "
+ f"{occur.date}, "
+ f"{occur.origin}, "
+ f"{os.fsdecode(occur.path)}"
+ )
diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py
--- a/swh/provenance/interface.py
+++ b/swh/provenance/interface.py
@@ -3,10 +3,13 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from __future__ import annotations
+
from dataclasses import dataclass
from datetime import datetime
import enum
-from typing import Dict, Generator, Iterable, Optional, Set, Union
+from types import TracebackType
+from typing import Dict, Generator, Iterable, Optional, Set, Type, Union
from typing_extensions import Protocol, runtime_checkable
@@ -65,6 +68,22 @@
@runtime_checkable
class ProvenanceStorageInterface(Protocol):
+ def __enter__(self) -> ProvenanceStorageInterface:
+ ...
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ ...
+
+ @remote_api_endpoint("close")
+ def close(self) -> None:
+ """Close connection to the storage and release resources."""
+ ...
+
@remote_api_endpoint("content_add")
def content_add(
self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
@@ -129,6 +148,11 @@
This method is used only in tests."""
...
+ @remote_api_endpoint("open")
+ def open(self) -> None:
+ """Open connection to the storage and allocate necessary resources."""
+ ...
+
@remote_api_endpoint("origin_add")
def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
"""Add origins identified by sha1 ids, with their corresponding url (as paired
@@ -198,6 +222,21 @@
class ProvenanceInterface(Protocol):
storage: ProvenanceStorageInterface
+ def __enter__(self) -> ProvenanceInterface:
+ ...
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ ...
+
+ def close(self) -> None:
+ """Close connection to the underlying `storage` and release resources."""
+ ...
+
def flush(self) -> None:
"""Flush internal cache to the underlying `storage`."""
...
@@ -279,6 +318,12 @@
"""
...
+ def open(self) -> None:
+ """Open connection to the underlying `storage` and allocate necessary
+ resources.
+ """
+ ...
+
def origin_add(self, origin: OriginEntry) -> None:
"""Add `origin` to the provenance model."""
...
diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py
--- a/swh/provenance/mongo/backend.py
+++ b/swh/provenance/mongo/backend.py
@@ -3,18 +3,23 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from __future__ import annotations
+
from datetime import datetime, timezone
import os
-from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union
+from types import TracebackType
+from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Type, Union
from bson import ObjectId
-import pymongo.database
+import mongomock
+import pymongo
from swh.model.model import Sha1Git
from ..interface import (
EntityType,
ProvenanceResult,
+ ProvenanceStorageInterface,
RelationData,
RelationType,
RevisionData,
@@ -22,8 +27,25 @@
class ProvenanceStorageMongoDb:
- def __init__(self, db: pymongo.database.Database):
- self.db = db
+ def __init__(self, engine: str, **kwargs):
+ self.engine = engine
+ self.dbname = kwargs.pop("dbname")
+ self.conn_args = kwargs
+
+ def __enter__(self) -> ProvenanceStorageInterface:
+ self.open()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ def close(self) -> None:
+ self.db.client.close()
def content_add(
self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
@@ -203,6 +225,13 @@
paths.extend(value for _, value in each_dir["revision"].items())
return set(sum(paths, []))
+ def open(self) -> None:
+ if self.engine == "mongomock":
+ self.db = mongomock.MongoClient(**self.conn_args).get_database(self.dbname)
+ else:
+ # assume real MongoDB server by default
+ self.db = pymongo.MongoClient(**self.conn_args).get_database(self.dbname)
+
def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
existing = {
x["sha1"]: x
diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py
--- a/swh/provenance/postgresql/provenance.py
+++ b/swh/provenance/postgresql/provenance.py
@@ -3,11 +3,14 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from __future__ import annotations
+
from contextlib import contextmanager
from datetime import datetime
import itertools
import logging
-from typing import Dict, Generator, Iterable, List, Optional, Set, Union
+from types import TracebackType
+from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union
import psycopg2.extensions
import psycopg2.extras
@@ -19,6 +22,7 @@
from ..interface import (
EntityType,
ProvenanceResult,
+ ProvenanceStorageInterface,
RelationData,
RelationType,
RevisionData,
@@ -28,16 +32,23 @@
class ProvenanceStoragePostgreSql:
- def __init__(
- self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False
- ) -> None:
- BaseDb.adapt_conn(conn)
- self.conn = conn
- with self.transaction() as cursor:
- cursor.execute("SET timezone TO 'UTC'")
+ def __init__(self, raise_on_commit: bool = False, **kwargs) -> None:
+ self.conn_args = kwargs
self._flavor: Optional[str] = None
self.raise_on_commit = raise_on_commit
+ def __enter__(self) -> ProvenanceStorageInterface:
+ self.open()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
@contextmanager
def transaction(
self, readonly: bool = False
@@ -60,6 +71,9 @@
def denormalized(self) -> bool:
return "denormalized" in self.flavor
+ def close(self) -> None:
+ self.conn.close()
+
def content_add(
self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
) -> bool:
@@ -140,6 +154,12 @@
raise
return False
+ def open(self) -> None:
+ self.conn = BaseDb.connect(**self.conn_args).conn
+ BaseDb.adapt_conn(self.conn)
+ with self.transaction() as cursor:
+ cursor.execute("SET timezone TO 'UTC'")
+
def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
urls: Dict[Sha1Git, str] = {}
sha1s = tuple(ids)
diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py
--- a/swh/provenance/provenance.py
+++ b/swh/provenance/provenance.py
@@ -6,13 +6,15 @@
from datetime import datetime
import logging
import os
-from typing import Dict, Generator, Iterable, Optional, Set, Tuple
+from types import TracebackType
+from typing import Dict, Generator, Iterable, Optional, Set, Tuple, Type
from typing_extensions import Literal, TypedDict
from swh.model.model import Sha1Git
from .interface import (
+ ProvenanceInterface,
ProvenanceResult,
ProvenanceStorageInterface,
RelationData,
@@ -74,9 +76,24 @@
self.storage = storage
self.cache = new_cache()
+ def __enter__(self) -> ProvenanceInterface:
+ self.open()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
def clear_caches(self) -> None:
self.cache = new_cache()
+ def close(self) -> None:
+ self.storage.close()
+
def flush(self) -> None:
# Revision-content layer insertions ############################################
@@ -336,6 +353,9 @@
dates[sha1] = date
return dates
+ def open(self) -> None:
+ self.storage.open()
+
def origin_add(self, origin: OriginEntry) -> None:
self.cache["origin"]["data"][origin.id] = origin.url
self.cache["origin"]["added"].add(origin.id)
diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -5,12 +5,12 @@
from datetime import datetime, timedelta, timezone
from os import path
-from typing import Any, Dict, Iterable
+from typing import Any, Dict, Generator, Iterable
from _pytest.fixtures import SubRequest
+import mongomock.database
import msgpack
import psycopg2.extensions
-import pymongo.database
import pytest
from pytest_postgresql.factories import postgresql
@@ -48,20 +48,27 @@
def provenance_storage(
request: SubRequest,
provenance_postgresqldb: Dict[str, str],
- mongodb: pymongo.database.Database,
-) -> ProvenanceStorageInterface:
+ mongodb: mongomock.database.Database,
+) -> Generator[ProvenanceStorageInterface, None, None]:
"""Return a working and initialized ProvenanceStorageInterface object"""
if request.param == "mongodb":
- from swh.provenance.mongo.backend import ProvenanceStorageMongoDb
-
- return ProvenanceStorageMongoDb(mongodb)
+ mongodb_params = {
+ "host": mongodb.client.address[0],
+ "port": mongodb.client.address[1],
+ "dbname": mongodb.name,
+ }
+ with get_provenance_storage(
+ cls=request.param, db=mongodb_params, engine="mongomock"
+ ) as storage:
+ yield storage
else:
# in test sessions, we DO want to raise any exception occurring at commit time
- return get_provenance_storage(
+ with get_provenance_storage(
cls=request.param, db=provenance_postgresqldb, raise_on_commit=True
- )
+ ) as storage:
+ yield storage
provenance_postgresql = postgresql("postgresql_proc", dbname="provenance_tests")
@@ -70,7 +77,7 @@
@pytest.fixture
def provenance(
provenance_postgresql: psycopg2.extensions.connection,
-) -> ProvenanceInterface:
+) -> Generator[ProvenanceInterface, None, None]:
"""Return a working and initialized ProvenanceInterface object"""
from swh.core.cli.db import populate_database_for_package
@@ -79,11 +86,12 @@
"swh.provenance", provenance_postgresql.dsn, flavor="with-path"
)
# in test sessions, we DO want to raise any exception occurring at commit time
- return get_provenance(
+ with get_provenance(
cls="postgresql",
db=provenance_postgresql.get_dsn_parameters(),
raise_on_commit=True,
- )
+ ) as provenance:
+ yield provenance
@pytest.fixture
File Metadata
Details
Attached
Mime Type
text/plain
Expires
Tue, Dec 17, 8:24 AM (2 w, 1 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3217132
Attached To
D6334: Add `close` method to both `ProvenanceInterface` and `ProvenanceStorageInterface`
Event Timeline
Log In to Comment