diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -5,7 +7,20 @@ from .provenance import ProvenanceInterface, ProvenanceStorageInterface -def get_archive(cls: str, **kwargs) -> "ArchiveInterface": +def get_archive(cls: str, **kwargs) -> ArchiveInterface: + """Get an archive object of class ``cls`` with arguments ``kwargs``. + + Args: + cls: archive's class, either 'api' or 'direct' + kwargs: dictionary of arguments passed to the archive class constructor + + Returns: + an instance of archive object (either using swh.storage API or direct + queries to the archive's database) + + Raises: + :exc:`ValueError` if passed an unknown archive class. + """ if cls == "api": from swh.storage import get_storage @@ -19,24 +34,44 @@ return ArchivePostgreSQL(BaseDb.connect(**kwargs["db"]).conn) else: - raise NotImplementedError + raise ValueError + +def get_provenance(**kwargs) -> ProvenanceInterface: + """Get a provenance object with arguments ``kwargs``. -def get_provenance(**kwargs) -> "ProvenanceInterface": + Args: + kwargs: dictionary of arguments to retrieve a swh.provenance.storage + class (see :func:`get_provenance_storage` for details) + + Returns: + an instance of provenance object + """ from .backend import ProvenanceBackend return ProvenanceBackend(get_provenance_storage(**kwargs)) -def get_provenance_storage(cls: str, **kwargs) -> "ProvenanceStorageInterface": +def get_provenance_storage(cls: str, **kwargs) -> ProvenanceStorageInterface: + """Get a provenance storage object of class ``cls`` with arguments ``kwargs``. + + Args: + cls: storage's class, only 'local' is currently supported + kwargs: dictionary of arguments passed to the storage class constructor + + Returns: + an instance of storage object + + Raises: + :exc:`ValueError` if passed an unknown storage class. 
+ """ if cls == "local": from swh.core.db import BaseDb from .postgresql.provenancedb_base import ProvenanceDBBase conn = BaseDb.connect(**kwargs["db"]).conn - flavor = ProvenanceDBBase(conn).flavor - if flavor == "with-path": + if ProvenanceDBBase(conn).flavor == "with-path": from .postgresql.provenancedb_with_path import ProvenanceWithPathDB return ProvenanceWithPathDB(conn) @@ -45,4 +80,4 @@ return ProvenanceWithoutPathDB(conn) else: - raise NotImplementedError + raise ValueError diff --git a/swh/provenance/backend.py b/swh/provenance/backend.py --- a/swh/provenance/backend.py +++ b/swh/provenance/backend.py @@ -58,7 +58,7 @@ # TODO: maybe move this to a separate file class ProvenanceBackend: - def __init__(self, storage: ProvenanceStorageInterface): + def __init__(self, storage: ProvenanceStorageInterface) -> None: self.storage = storage self.cache = new_cache() diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py --- a/swh/provenance/cli.py +++ b/swh/provenance/cli.py @@ -77,7 +77,7 @@ help="""Enable profiling to specified file.""", ) @click.pass_context -def cli(ctx, config_file: Optional[str], profile: str) -> None: +def cli(ctx: click.core.Context, config_file: Optional[str], profile: str) -> None: if config_file is None and config.config_exists(DEFAULT_PATH): config_file = DEFAULT_PATH @@ -116,7 +116,7 @@ @click.option("-r", "--reuse", default=True, type=bool) @click.pass_context def iter_revisions( - ctx, + ctx: click.core.Context, filename: str, track_all: bool, limit: Optional[int], @@ -161,7 +161,7 @@ @click.argument("filename") @click.option("-l", "--limit", type=int) @click.pass_context -def iter_origins(ctx, filename: str, limit: Optional[int]) -> None: +def iter_origins(ctx: click.core.Context, filename: str, limit: Optional[int]) -> None: """Process a provided list of origins.""" from . 
import get_archive, get_provenance from .origin import CSVOriginIterator, origin_add @@ -185,7 +185,7 @@ @cli.command(name="find-first") @click.argument("swhid") @click.pass_context -def find_first(ctx, swhid: str) -> None: +def find_first(ctx: click.core.Context, swhid: str) -> None: """Find first occurrence of the requested blob.""" from . import get_provenance @@ -208,7 +208,7 @@ @click.argument("swhid") @click.option("-l", "--limit", type=int) @click.pass_context -def find_all(ctx, swhid: str, limit: Optional[int]) -> None: +def find_all(ctx: click.core.Context, swhid: str, limit: Optional[int]) -> None: """Find all occurrences of the requested blob.""" from . import get_provenance diff --git a/swh/provenance/graph.py b/swh/provenance/graph.py --- a/swh/provenance/graph.py +++ b/swh/provenance/graph.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from datetime import datetime, timezone import logging import os -from typing import Dict, Optional, Set +from typing import Any, Dict, Optional, Set from swh.model.model import Sha1Git @@ -15,7 +17,7 @@ class HistoryNode: def __init__( self, entry: RevisionEntry, visited: bool = False, in_history: bool = False - ): + ) -> None: self.entry = entry # A revision is `visited` if it is directly pointed by an origin (ie. 
a head # revision for some snapshot) @@ -27,21 +29,21 @@ def add_parent( self, parent: RevisionEntry, visited: bool = False, in_history: bool = False - ) -> "HistoryNode": + ) -> HistoryNode: node = HistoryNode(parent, visited=visited, in_history=in_history) self.parents.add(node) return node - def __str__(self): + def __str__(self) -> str: return ( f"<{self.entry}: visited={self.visited}, in_history={self.in_history}, " f"parents=[{', '.join(str(parent) for parent in self.parents)}]>" ) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, HistoryNode) and self.__dict__ == other.__dict__ - def __hash__(self): + def __hash__(self) -> int: return hash((self.entry, self.visited, self.in_history)) @@ -86,7 +88,7 @@ dbdate: Optional[datetime] = None, depth: int = 0, prefix: bytes = b"", - ): + ) -> None: self.entry = entry self.depth = depth @@ -105,11 +107,11 @@ self.children: Set[IsochroneNode] = set() @property - def dbdate(self): + def dbdate(self) -> Optional[datetime]: # use a property to make this attribute (mostly) read-only return self._dbdate - def invalidate(self): + def invalidate(self) -> None: self._dbdate = None self.maxdate = None self.known = False @@ -117,7 +119,7 @@ def add_directory( self, child: DirectoryEntry, date: Optional[datetime] = None - ) -> "IsochroneNode": + ) -> IsochroneNode: # we should not be processing this node (ie add subdirectories or files) if it's # actually known by the provenance DB assert self.dbdate is None @@ -125,18 +127,18 @@ self.children.add(node) return node - def __str__(self): + def __str__(self) -> str: return ( f"<{self.entry}: depth={self.depth}, " f"dbdate={self.dbdate}, maxdate={self.maxdate}, " - f"known={self.known}, invalid={self.invalid}, path={self.path}, " + f"known={self.known}, invalid={self.invalid}, path={self.path!r}, " f"children=[{', '.join(str(child) for child in self.children)}]>" ) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return 
isinstance(other, IsochroneNode) and self.__dict__ == other.__dict__ - def __hash__(self): + def __hash__(self) -> int: # only immutable attributes are considered to compute hash return hash((self.entry, self.depth, self.path)) diff --git a/swh/provenance/model.py b/swh/provenance/model.py --- a/swh/provenance/model.py +++ b/swh/provenance/model.py @@ -3,6 +3,8 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from __future__ import annotations + from datetime import datetime from typing import Iterable, Iterator, List, Optional @@ -14,20 +16,20 @@ class OriginEntry: - def __init__(self, url: str, snapshot: Sha1Git): + def __init__(self, url: str, snapshot: Sha1Git) -> None: self.url = url self.id: Sha1Git = hash_to_bytes(origin_identifier({"url": self.url})) self.snapshot = snapshot self._revisions: Optional[List[RevisionEntry]] = None - def retrieve_revisions(self, archive: ArchiveInterface): + def retrieve_revisions(self, archive: ArchiveInterface) -> None: if self._revisions is None: self._revisions = [ RevisionEntry(rev) for rev in archive.snapshot_get_heads(self.snapshot) ] @property - def revisions(self) -> Iterator["RevisionEntry"]: + def revisions(self) -> Iterator[RevisionEntry]: if self._revisions is None: raise RuntimeError( "Revisions of this node has not yet been retrieved. 
" @@ -35,7 +37,7 @@ ) return (x for x in self._revisions) - def __str__(self): + def __str__(self) -> str: return f"" @@ -46,7 +48,7 @@ date: Optional[datetime] = None, root: Optional[Sha1Git] = None, parents: Optional[Iterable[Sha1Git]] = None, - ): + ) -> None: self.id = id self.date = date assert self.date is None or self.date.tzinfo is not None @@ -54,14 +56,14 @@ self._parents_ids = parents self._parents_entries: Optional[List[RevisionEntry]] = None - def retrieve_parents(self, archive: ArchiveInterface): + def retrieve_parents(self, archive: ArchiveInterface) -> None: if self._parents_entries is None: if self._parents_ids is None: self._parents_ids = archive.revision_get_parents(self.id) self._parents_entries = [RevisionEntry(id) for id in self._parents_ids] @property - def parents(self) -> Iterator["RevisionEntry"]: + def parents(self) -> Iterator[RevisionEntry]: if self._parents_entries is None: raise RuntimeError( "Parents of this node has not yet been retrieved. " @@ -69,24 +71,24 @@ ) return (x for x in self._parents_entries) - def __str__(self): + def __str__(self) -> str: return f"" - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, RevisionEntry) and self.id == other.id - def __hash__(self): + def __hash__(self) -> int: return hash(self.id) class DirectoryEntry: - def __init__(self, id: Sha1Git, name: bytes = b""): + def __init__(self, id: Sha1Git, name: bytes = b"") -> None: self.id = id self.name = name self._files: Optional[List[FileEntry]] = None self._dirs: Optional[List[DirectoryEntry]] = None - def retrieve_children(self, archive: ArchiveInterface): + def retrieve_children(self, archive: ArchiveInterface) -> None: if self._files is None and self._dirs is None: self._files = [] self._dirs = [] @@ -99,7 +101,7 @@ self._files.append(FileEntry(child["target"], child["name"])) @property - def files(self) -> Iterator["FileEntry"]: + def files(self) -> Iterator[FileEntry]: if self._files is None: raise RuntimeError( 
"Children of this node has not yet been retrieved. " @@ -108,7 +110,7 @@ return (x for x in self._files) @property - def dirs(self) -> Iterator["DirectoryEntry"]: + def dirs(self) -> Iterator[DirectoryEntry]: if self._dirs is None: raise RuntimeError( "Children of this node has not yet been retrieved. " @@ -116,32 +118,32 @@ ) return (x for x in self._dirs) - def __str__(self): - return f"" + def __str__(self) -> str: + return f"" - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, DirectoryEntry) and (self.id, self.name) == ( other.id, other.name, ) - def __hash__(self): + def __hash__(self) -> int: return hash((self.id, self.name)) class FileEntry: - def __init__(self, id: Sha1Git, name: bytes): + def __init__(self, id: Sha1Git, name: bytes) -> None: self.id = id self.name = name - def __str__(self): - return f"" + def __str__(self) -> str: + return f"" - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, FileEntry) and (self.id, self.name) == ( other.id, other.name, ) - def __hash__(self): + def __hash__(self) -> int: return hash((self.id, self.name)) diff --git a/swh/provenance/postgresql/archive.py b/swh/provenance/postgresql/archive.py --- a/swh/provenance/postgresql/archive.py +++ b/swh/provenance/postgresql/archive.py @@ -8,16 +8,16 @@ class ArchivePostgreSQL: - def __init__(self, conn: psycopg2.extensions.connection): + def __init__(self, conn: psycopg2.extensions.connection) -> None: self.conn = conn self.storage = Storage(conn, objstorage={"cls": "memory"}) def directory_ls(self, id: Sha1Git) -> Iterable[Dict[str, Any]]: - entries = self.directory_ls_internal(id) + entries = self._directory_ls(id) yield from entries @lru_cache(maxsize=100000) - def directory_ls_internal(self, id: Sha1Git) -> List[Dict[str, Any]]: + def _directory_ls(self, id: Sha1Git) -> List[Dict[str, Any]]: # TODO: add file size filtering with self.conn.cursor() as cursor: cursor.execute( diff --git 
a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -9,6 +9,14 @@ from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry +class RelationType(enum.Enum): + CNT_EARLY_IN_REV = "content_in_revision" + CNT_IN_DIR = "content_in_directory" + DIR_IN_REV = "directory_in_revision" + REV_IN_ORG = "revision_in_origin" + REV_BEFORE_REV = "revision_before_revision" + + class ProvenanceResult: def __init__( self, @@ -25,9 +33,103 @@ self.path = path +@runtime_checkable +class ProvenanceStorageInterface(Protocol): + raise_on_commit: bool = False + + def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: + """Retrieve the first occurrence of the blob identified by `id`.""" + ... + + def content_find_all( + self, id: Sha1Git, limit: Optional[int] = None + ) -> Generator[ProvenanceResult, None, None]: + """Retrieve all the occurrences of the blob identified by `id`.""" + ... + + def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: + """Associate dates to blobs identified by sha1 ids, as paired in `dates`. Return + a boolean stating whether the information was successfully stored. + """ + ... + + def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: + """Retrieve the associated date for each blob sha1 in `ids`. If some blob has + no associated date, it is not present in the resulting dictionary. + """ + ... + + def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: + """Associate dates to directories identified by sha1 ids, as paired in + `dates`. Return a boolean stating whether the information was successfully + stored. + """ + ... + + def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: + """Retrieve the associated date for each directory sha1 in `ids`. If some + directory has no associated date, it is not present in the resulting dictionary. + """ + ... 
+ + def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: + """Associate urls to origins identified by sha1 ids, as paired in `urls`. Return + a boolean stating whether the information was successfully stored. + """ + ... + + def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: + """Retrieve the associated url for each origin sha1 in `ids`. If some origin has + no associated url, it is not present in the resulting dictionary. + """ + ... + + def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: + """Associate dates to revisions identified by sha1 ids, as paired in `dates`. + Return a boolean stating whether the information was successfully stored. + """ + ... + + def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool: + """Associate origins to revisions identified by sha1 ids, as paired in + `origins` (revision ids are keys and origin ids, values). Return a boolean + stating whether the information was successfully stored. + """ + ... + + def revision_get( + self, ids: Iterable[Sha1Git] + ) -> Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]]: + """Retrieve the associated date and origin for each revision sha1 in `ids`. If + some revision has no associated date nor origin, it is not present in the + resulting dictionary. + """ + ... + + def relation_add( + self, + relation: RelationType, + data: Iterable[Tuple[Sha1Git, Sha1Git, Optional[bytes]]], + ) -> bool: + """Add entries in the selected `relation`. Each tuple in `data` is of the form + (`src`, `dst`, `path`), where `src` and `dst` are the sha1 ids of the entities + being related, and `path` is optional depending on the selected `relation`. + """ + ... + + def relation_get( + self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False + ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]: + """Retrieve all tuples in the selected `relation` whose source entities are + identified by some sha1 id in `ids`. 
If `reverse` is set, destination entities + are matched instead. + """ + ... + + @runtime_checkable class ProvenanceInterface(Protocol): - storage: "ProvenanceStorageInterface" + storage: ProvenanceStorageInterface def flush(self) -> None: """Flush internal cache to the underlying `storage`.""" @@ -160,105 +262,3 @@ provenance model. """ ... - - -class RelationType(enum.Enum): - CNT_EARLY_IN_REV = "content_in_revision" - CNT_IN_DIR = "content_in_directory" - DIR_IN_REV = "directory_in_revision" - REV_IN_ORG = "revision_in_origin" - REV_BEFORE_REV = "revision_before_revision" - - -@runtime_checkable -class ProvenanceStorageInterface(Protocol): - raise_on_commit: bool = False - - def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: - """Retrieve the first occurrence of the blob identified by `id`.""" - ... - - def content_find_all( - self, id: Sha1Git, limit: Optional[int] = None - ) -> Generator[ProvenanceResult, None, None]: - """Retrieve all the occurrences of the blob identified by `id`.""" - ... - - def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - """Associate dates to blobs identified by sha1 ids, as paired in `dates`. Return - a boolean stating whether the information was successfully stored. - """ - ... - - def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: - """Retrieve the associated date for each blob sha1 in `ids`. If some blob has - no associated date, it is not present in the resulting dictionary. - """ - ... - - def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - """Associate dates to directories identified by sha1 ids, as paired in - `dates`. Return a boolean stating whether the information was successfully - stored. - """ - ... - - def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: - """Retrieve the associated date for each directory sha1 in `ids`. If some - directory has no associated date, it is not present in the resulting dictionary. 
- """ - ... - - def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: - """Associate urls to origins identified by sha1 ids, as paired in `urls`. Return - a boolean stating whether the information was successfully stored. - """ - ... - - def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: - """Retrieve the associated url for each origin sha1 in `ids`. If some origin has - no associated date, it is not present in the resulting dictionary. - """ - ... - - def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - """Associate dates to revisions identified by sha1 ids, as paired in `dates`. - Return a boolean stating whether the information was successfully stored. - """ - ... - - def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool: - """Associate origins to revisions identified by sha1 ids, as paired in - `origins` (revision ids are keys and origin ids, values). Return a boolean - stating whether the information was successfully stored. - """ - ... - - def revision_get( - self, ids: Iterable[Sha1Git] - ) -> Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]]: - """Retrieve the associated date and origin for each revision sha1 in `ids`. If - some revision has no associated date nor origin, it is not present in the - resulting dictionary. - """ - ... - - def relation_add( - self, - relation: RelationType, - data: Iterable[Tuple[Sha1Git, Sha1Git, Optional[bytes]]], - ) -> bool: - """Add entries in the selected `relation`. Each tuple in `data` is of the from - (`src`, `dst`, `path`), where `src` and `dst` are the sha1 ids of the entities - being related, and `path` is optional depending on the selected `relation`. - """ - ... - - def relation_get( - self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]: - """Retrieve all tuples in the selected `relation` whose source entities are - identified by some sha1 id in `ids`. 
If `reverse` is set, destination entities - are matched instead. - """ - ... diff --git a/swh/provenance/revision.py b/swh/provenance/revision.py --- a/swh/provenance/revision.py +++ b/swh/provenance/revision.py @@ -1,5 +1,4 @@ from datetime import datetime, timezone -from itertools import islice import logging import os import time @@ -33,6 +32,8 @@ ) -> None: self.revisions: Iterator[Tuple[Sha1Git, datetime, Sha1Git]] if limit is not None: + from itertools import islice + self.revisions = islice(revisions, limit) else: self.revisions = iter(revisions)