diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -1,28 +1,38 @@ -from .archive import ArchiveInterface -from .postgresql.archive import ArchivePostgreSQL +from typing import TYPE_CHECKING + from .postgresql.db_utils import connect -from .storage.archive import ArchiveStorage -from .provenance import ProvenanceInterface + +if TYPE_CHECKING: + from swh.provenance.archive import ArchiveInterface + from swh.provenance.provenance import ProvenanceInterface -def get_archive(cls: str, **kwargs) -> ArchiveInterface: +def get_archive(cls: str, **kwargs) -> "ArchiveInterface": if cls == "api": + from swh.provenance.storage.archive import ArchiveStorage + return ArchiveStorage(**kwargs["storage"]) elif cls == "direct": + from swh.provenance.postgresql.archive import ArchivePostgreSQL + conn = connect(kwargs["db"]) return ArchivePostgreSQL(conn) else: raise NotImplementedError -def get_provenance(cls: str, **kwargs) -> ProvenanceInterface: +def get_provenance(cls: str, **kwargs) -> "ProvenanceInterface": if cls == "local": conn = connect(kwargs["db"]) if kwargs.get("with_path", True): - from .postgresql.provenance_with_path import ProvenanceWithPathDB + from swh.provenance.postgresql.provenancedb_with_path import ( + ProvenanceWithPathDB, + ) return ProvenanceWithPathDB(conn) else: - from .postgresql.provenance_without_path import ProvenanceWithoutPathDB + from swh.provenance.postgresql.provenancedb_without_path import ( + ProvenanceWithoutPathDB, + ) return ProvenanceWithoutPathDB(conn) else: raise NotImplementedError diff --git a/swh/provenance/archive.py b/swh/provenance/archive.py --- a/swh/provenance/archive.py +++ b/swh/provenance/archive.py @@ -1,27 +1,27 @@ from typing import Any, Dict, List +from typing_extensions import Protocol, runtime_checkable -class ArchiveInterface: - def __init__(self, **kwargs): - raise NotImplementedError +@runtime_checkable +class ArchiveInterface(Protocol): def directory_ls(self, id: bytes) -> List[Dict[str, Any]]: - raise NotImplementedError + ... def iter_origins(self): - raise NotImplementedError + ... def iter_origin_visits(self, origin: str): - raise NotImplementedError + ... def iter_origin_visit_statuses(self, origin: str, visit: int): - raise NotImplementedError + ... def release_get(self, ids: List[bytes]): - raise NotImplementedError + ... def revision_get(self, ids: List[bytes]): - raise NotImplementedError + ... def snapshot_get_all_branches(self, snapshot: bytes): - raise NotImplementedError + ... diff --git a/swh/provenance/postgresql/provenancedb_base.py b/swh/provenance/postgresql/provenancedb_base.py --- a/swh/provenance/postgresql/provenancedb_base.py +++ b/swh/provenance/postgresql/provenancedb_base.py @@ -1,18 +1,17 @@ +from datetime import datetime import itertools import logging +from typing import Any, Dict, List, Optional + import psycopg2 import psycopg2.extras from ..model import DirectoryEntry, FileEntry from ..origin import OriginEntry -from ..provenance import ProvenanceInterface from ..revision import RevisionEntry -from datetime import datetime -from typing import Any, Dict, List, Optional - -class ProvenanceDBBase(ProvenanceInterface): +class ProvenanceDBBase: def __init__(self, conn: psycopg2.extensions.connection): # TODO: consider adding a mutex for thread safety conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -1,7 +1,9 @@ -import os from datetime import datetime +import os from typing import Dict, Generator, List, Optional, Tuple +from typing_extensions import Protocol, runtime_checkable + from .archive import ArchiveInterface from .model import DirectoryEntry, FileEntry, TreeEntry from .origin import OriginEntry @@ -13,100 +15,103 @@ return path != prefix and os.path.dirname(path) == prefix -class ProvenanceInterface: - def __init__(self, **kwargs): - raise NotImplementedError - +@runtime_checkable +class ProvenanceInterface(Protocol): def commit(self): - raise NotImplementedError + """Commit currently ongoing transactions in the backend DB""" + ... def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes - ): - raise NotImplementedError + ) -> None: + ... def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes - ): - raise NotImplementedError + ) -> None: + ... def content_find_first( self, blobid: bytes ) -> Optional[Tuple[bytes, bytes, datetime, bytes]]: - raise NotImplementedError + ... def content_find_all( self, blobid: bytes ) -> Generator[Tuple[bytes, bytes, datetime, bytes], None, None]: - raise NotImplementedError + ... def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: - raise NotImplementedError + ... def content_get_early_dates(self, blobs: List[FileEntry]) -> Dict[bytes, datetime]: - raise NotImplementedError + ... - def content_set_early_date(self, blob: FileEntry, date: datetime): - raise NotImplementedError + def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: + ... def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes - ): - raise NotImplementedError + ) -> None: + ... def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: - raise NotImplementedError + ... def directory_get_dates_in_isochrone_frontier( self, dirs: List[DirectoryEntry] ) -> Dict[bytes, datetime]: - raise NotImplementedError + ... - def directory_invalidate_in_isochrone_frontier(self, directory: DirectoryEntry): - raise NotImplementedError + def directory_invalidate_in_isochrone_frontier( + self, directory: DirectoryEntry + ) -> None: + ... def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime - ): - raise NotImplementedError + ) -> None: + ... def origin_get_id(self, origin: OriginEntry) -> int: - raise NotImplementedError + ... - def revision_add(self, revision: RevisionEntry): - raise NotImplementedError + def revision_add(self, revision: RevisionEntry) -> None: + ... def revision_add_before_revision( self, relative: RevisionEntry, revision: RevisionEntry - ): - raise NotImplementedError + ) -> None: + ... - def revision_add_to_origin(self, origin: OriginEntry, revision: RevisionEntry): - raise NotImplementedError + def revision_add_to_origin( + self, origin: OriginEntry, revision: RevisionEntry + ) -> None: + ... def revision_get_early_date(self, revision: RevisionEntry) -> Optional[datetime]: - raise NotImplementedError + ... def revision_get_preferred_origin(self, revision: RevisionEntry) -> int: - raise NotImplementedError + ... def revision_in_history(self, revision: RevisionEntry) -> bool: - raise NotImplementedError + ... def revision_set_preferred_origin( self, origin: OriginEntry, revision: RevisionEntry - ): - raise NotImplementedError + ) -> None: + ... def revision_visited(self, revision: RevisionEntry) -> bool: - raise NotImplementedError + ... def directory_process_content( provenance: ProvenanceInterface, directory: DirectoryEntry, relative: DirectoryEntry -): +) -> None: stack = [(directory, b"")] while stack: current, prefix = stack.pop() @@ -119,7 +124,7 @@ stack.append((child, os.path.join(prefix, child.name))) -def origin_add(provenance: ProvenanceInterface, origin: OriginEntry): +def origin_add(provenance: ProvenanceInterface, origin: OriginEntry) -> None: # TODO: refactor to iterate over origin visit statuses and commit only once # per status. origin.id = provenance.origin_get_id(origin) @@ -131,7 +136,7 @@ def origin_add_revision( provenance: ProvenanceInterface, origin: OriginEntry, revision: RevisionEntry -): +) -> None: stack: List[Tuple[Optional[RevisionEntry], RevisionEntry]] = [(None, revision)] while stack: @@ -183,7 +188,7 @@ def revision_add( provenance: ProvenanceInterface, archive: ArchiveInterface, revision: RevisionEntry -): +) -> None: assert revision.date is not None assert revision.root is not None # Processed content starting from the revision's root directory. @@ -218,7 +223,7 @@ def build_isochrone_graph( provenance: ProvenanceInterface, revision: RevisionEntry, directory: DirectoryEntry -): +) -> IsochroneNode: assert revision.date is not None # Build the nodes structure root = IsochroneNode(directory) @@ -290,6 +295,7 @@ stack = [(build_isochrone_graph(provenance, revision, root), root.name)] while stack: current, path = stack.pop() + assert isinstance(current.entry, DirectoryEntry) if current.date is not None: assert current.date < revision.date # Current directory is an outer isochrone frontier for a previously diff --git a/swh/provenance/storage/archive.py b/swh/provenance/storage/archive.py --- a/swh/provenance/storage/archive.py +++ b/swh/provenance/storage/archive.py @@ -1,19 +1,17 @@ -from typing import List +from typing import Any, Dict, List # from functools import lru_cache from methodtools import lru_cache from swh.storage import get_storage -from ..archive import ArchiveInterface - -class ArchiveStorage(ArchiveInterface): +class ArchiveStorage: def __init__(self, cls: str, **kwargs): self.storage = get_storage(cls, **kwargs) @lru_cache(maxsize=1000000) - def directory_ls(self, id: bytes): + def directory_ls(self, id: bytes) -> List[Dict[str, Any]]: # TODO: filter unused fields return [entry for entry in self.storage.directory_ls(id)]