diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py --- a/swh/storage/cassandra/converters.py +++ b/swh/storage/cassandra/converters.py @@ -42,7 +42,7 @@ return db_revision -def revision_from_db(db_revision: Row, parents: Tuple[Sha1Git]) -> Revision: +def revision_from_db(db_revision: Row, parents: Tuple[Sha1Git, ...]) -> Revision: revision = db_revision._asdict() # type: ignore metadata = json.loads(revision.pop("metadata", None)) extra_headers = revision.pop("extra_headers", ()) diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -9,7 +9,7 @@ import json import random import re -from typing import Any, Dict, List, Iterable, Optional, Tuple, Union +from typing import Any, Dict, List, Iterable, Optional, Set, Tuple, Union import attr @@ -462,10 +462,12 @@ return {"revision:add": len(revisions)} - def revision_missing(self, revisions): + def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.revision_missing(revisions) - def revision_get(self, revisions): + def revision_get( + self, revisions: List[Sha1Git] + ) -> Iterable[Optional[Dict[str, Any]]]: rows = self._cql_runner.revision_get(revisions) revs = {} for row in rows: @@ -483,7 +485,16 @@ for rev_id in revisions: yield revs.get(rev_id) - def _get_parent_revs(self, rev_ids, seen, limit, short): + def _get_parent_revs( + self, + rev_ids: Iterable[Sha1Git], + seen: Set[Sha1Git], + limit: Optional[int], + short: bool, + ) -> Union[ + Iterable[Optional[Dict[str, Any]]], + Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]], + ]: if limit and len(seen) >= limit: return rev_ids = [id_ for id_ in rev_ids if id_ not in seen] @@ -517,12 +528,16 @@ yield rev.to_dict() yield from self._get_parent_revs(parents, seen, limit, short) - def revision_log(self, revisions, limit=None): - seen = set() + def revision_log( + self, revisions: 
List[Sha1Git], limit: Optional[int] = None + ) -> Iterable[Optional[Dict[str, Any]]]: + seen: Set[Sha1Git] = set() yield from self._get_parent_revs(revisions, seen, limit, False) - def revision_shortlog(self, revisions, limit=None): - seen = set() + def revision_shortlog( + self, revisions: List[Sha1Git], limit: Optional[int] = None + ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: + seen: Set[Sha1Git] = set() yield from self._get_parent_revs(revisions, seen, limit, True) def revision_get_random(self) -> Sha1Git: diff --git a/swh/storage/converters.py b/swh/storage/converters.py --- a/swh/storage/converters.py +++ b/swh/storage/converters.py @@ -1,11 +1,11 @@ -# Copyright (C) 2015 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime -from typing import Optional, Dict +from typing import Any, Optional, Dict from swh.core.utils import encode_with_unescape from swh.model import identifiers @@ -184,7 +184,7 @@ } -def db_to_revision(db_revision): +def db_to_revision(db_revision: Dict[str, Any]) -> Dict[str, Any]: """Convert a database representation of a revision to its swh-model representation.""" diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -24,6 +24,7 @@ Iterator, List, Optional, + Set, Tuple, TypeVar, Union, @@ -522,19 +523,23 @@ return {"revision:add": count} - def revision_missing(self, revisions): + def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: for id in revisions: if id not in self._revisions: yield id - def revision_get(self, revisions): + def revision_get( + self, revisions: List[Sha1Git] + ) -> Iterable[Optional[Dict[str, Any]]]: for id in revisions: if id in self._revisions: yield 
self._revisions.get(id).to_dict() else: yield None - def _get_parent_revs(self, rev_id, seen, limit): + def _get_parent_revs( + self, rev_id: Sha1Git, seen: Set[Sha1Git], limit: Optional[int] + ) -> Iterable[Optional[Dict[str, Any]]]: if limit and len(seen) >= limit: return if rev_id in seen or rev_id not in self._revisions: @@ -544,14 +549,19 @@ for parent in self._revisions[rev_id].parents: yield from self._get_parent_revs(parent, seen, limit) - def revision_log(self, revisions, limit=None): - seen = set() + def revision_log( + self, revisions: List[Sha1Git], limit: Optional[int] = None + ) -> Iterable[Optional[Dict[str, Any]]]: + seen: Set[Sha1Git] = set() for rev_id in revisions: yield from self._get_parent_revs(rev_id, seen, limit) - def revision_shortlog(self, revisions, limit=None): + def revision_shortlog( + self, revisions: List[Sha1Git], limit: Optional[int] = None + ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: yield from ( - (rev["id"], rev["parents"]) for rev in self.revision_log(revisions, limit) + (rev["id"], rev["parents"]) if rev else None + for rev in self.revision_log(revisions, limit) ) def revision_get_random(self) -> Sha1Git: diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -508,11 +508,11 @@ ... @remote_api_endpoint("revision/missing") - def revision_missing(self, revisions): + def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: """List revisions missing from storage Args: - revisions (iterable): revision ids + revisions: revision ids Yields: missing revision ids @@ -521,35 +521,40 @@ ... 
@remote_api_endpoint("revision") - def revision_get(self, revisions): - """Get all revisions from storage + def revision_get( + self, revisions: List[Sha1Git] + ) -> Iterable[Optional[Dict[str, Any]]]: + """Get revisions from storage Args: - revisions: an iterable of revision ids + revisions: revision ids - Returns: - iterable: an iterable of revisions as dictionaries (or None if the - revision doesn't exist) + Yields: + revisions as dictionaries (or None if the revision doesn't exist) """ ... @remote_api_endpoint("revision/log") - def revision_log(self, revisions, limit=None): + def revision_log( + self, revisions: List[Sha1Git], limit: Optional[int] = None + ) -> Iterable[Optional[Dict[str, Any]]]: """Fetch revision entry from the given root revisions. Args: - revisions: array of root revision to lookup + revisions: array of root revisions to lookup limit: limitation on the output result. Default to None. Yields: - List of revision log from such revisions root. + revision log entries from the given root revisions """ ... @remote_api_endpoint("revision/shortlog") - def revision_shortlog(self, revisions, limit=None): + def revision_shortlog( + self, revisions: List[Sha1Git], limit: Optional[int] = None + ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: """Fetch the shortlog for the given revisions Args: @@ -557,7 +562,7 @@ limit: depth limitation for the output Yields: - a list of (id, parents) tuples. + (id, parents) tuples """ ... 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -604,16 +604,20 @@ @timed @db_transaction_generator() - def revision_missing(self, revisions, db=None, cur=None): + def revision_missing( + self, revisions: List[Sha1Git], db=None, cur=None + ) -> Iterable[Sha1Git]: if not revisions: - return + return None for obj in db.revision_missing_from_list(revisions, cur): yield obj[0] @timed @db_transaction_generator(statement_timeout=1000) - def revision_get(self, revisions, db=None, cur=None): + def revision_get( + self, revisions: List[Sha1Git], db=None, cur=None + ) -> Iterable[Optional[Dict[str, Any]]]: for line in db.revision_get_from_list(revisions, cur): data = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) if not data["type"]: @@ -623,7 +627,9 @@ @timed @db_transaction_generator(statement_timeout=2000) - def revision_log(self, revisions, limit=None, db=None, cur=None): + def revision_log( + self, revisions: List[Sha1Git], limit: Optional[int] = None, db=None, cur=None + ) -> Iterable[Optional[Dict[str, Any]]]: for line in db.revision_log(revisions, limit, cur): data = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) if not data["type"]: @@ -633,8 +639,9 @@ @timed @db_transaction_generator(statement_timeout=2000) - def revision_shortlog(self, revisions, limit=None, db=None, cur=None): - + def revision_shortlog( + self, revisions: List[Sha1Git], limit: Optional[int] = None, db=None, cur=None + ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: yield from db.revision_shortlog(revisions, limit, cur) @timed