diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py index 7368f72..5ff322f 100644 --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -1,38 +1,38 @@ from typing import TYPE_CHECKING from .postgresql.db_utils import connect if TYPE_CHECKING: from swh.provenance.archive import ArchiveInterface from swh.provenance.provenance import ProvenanceInterface def get_archive(cls: str, **kwargs) -> "ArchiveInterface": if cls == "api": from swh.provenance.storage.archive import ArchiveStorage + from swh.storage import get_storage - return ArchiveStorage(**kwargs["storage"]) + return ArchiveStorage(get_storage(**kwargs["storage"])) elif cls == "direct": from swh.provenance.postgresql.archive import ArchivePostgreSQL - conn = connect(kwargs["db"]) - return ArchivePostgreSQL(conn) + return ArchivePostgreSQL(connect(kwargs["db"])) else: raise NotImplementedError def get_provenance(cls: str, **kwargs) -> "ProvenanceInterface": if cls == "local": conn = connect(kwargs["db"]) if kwargs.get("with_path", True): from swh.provenance.postgresql.provenancedb_with_path import ( ProvenanceWithPathDB, ) return ProvenanceWithPathDB(conn) else: from swh.provenance.postgresql.provenancedb_without_path import ( ProvenanceWithoutPathDB, ) return ProvenanceWithoutPathDB(conn) else: raise NotImplementedError diff --git a/swh/provenance/storage/archive.py b/swh/provenance/storage/archive.py index e31de7c..06b7ce5 100644 --- a/swh/provenance/storage/archive.py +++ b/swh/provenance/storage/archive.py @@ -1,47 +1,47 @@ from typing import Any, Dict, List # from functools import lru_cache from methodtools import lru_cache -from swh.storage import get_storage +from swh.storage.interface import StorageInterface class ArchiveStorage: - def __init__(self, cls: str, **kwargs): - self.storage = get_storage(cls, **kwargs) + def __init__(self, storage: StorageInterface): + self.storage = storage @lru_cache(maxsize=1000000) def directory_ls(self, id: bytes) -> List[Dict[str, Any]]: # TODO: filter unused fields return [entry for entry in self.storage.directory_ls(id)] def iter_origins(self): from swh.storage.algos.origin import iter_origins yield from iter_origins(self.storage) def iter_origin_visits(self, origin: str): from swh.storage.algos.origin import iter_origin_visits # TODO: filter unused fields yield from iter_origin_visits(self.storage, origin) def iter_origin_visit_statuses(self, origin: str, visit: int): from swh.storage.algos.origin import iter_origin_visit_statuses # TODO: filter unused fields yield from iter_origin_visit_statuses(self.storage, origin, visit) def release_get(self, ids: List[bytes]): # TODO: filter unused fields yield from self.storage.release_get(ids) def revision_get(self, ids: List[bytes]): # TODO: filter unused fields yield from self.storage.revision_get(ids) def snapshot_get_all_branches(self, snapshot: bytes): from swh.storage.algos.snapshot import snapshot_get_all_branches # TODO: filter unused fields return snapshot_get_all_branches(self.storage, snapshot)