diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -607,21 +607,23 @@ return {"snapshot:add": len(snapshots)} - def snapshot_missing(self, snapshots): + def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.snapshot_missing(snapshots) - def snapshot_get(self, snapshot_id): + def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: return self.snapshot_get_branches(snapshot_id) - def snapshot_get_by_origin_visit(self, origin, visit): + def snapshot_get_by_origin_visit( + self, origin: str, visit: int + ) -> Optional[Dict[str, Any]]: visit_status = self.origin_visit_status_get_latest( origin, visit, require_snapshot=True ) - if not visit_status: - return None - return self.snapshot_get(visit_status.snapshot) + if visit_status and visit_status.snapshot: + return self.snapshot_get(visit_status.snapshot) + return None - def snapshot_count_branches(self, snapshot_id): + def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Optional[Dict[str, int]]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. @@ -635,14 +637,18 @@ return counts def snapshot_get_branches( - self, snapshot_id, branches_from=b"", branches_count=1000, target_types=None - ): + self, + snapshot_id: Sha1Git, + branches_from: bytes = b"", + branches_count: int = 1000, + target_types: Optional[List[str]] = None, + ) -> Optional[Dict[str, Any]]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. return None - branches = [] + branches: List = [] while len(branches) < branches_count + 1: new_branches = list( self._cql_runner.snapshot_branch_get( @@ -675,7 +681,7 @@ else: last_branch = None - branches = { + branches_d = { branch.name: {"target": branch.target, "target_type": branch.target_type,} if branch.target else None @@ -684,7 +690,7 @@ return { "id": snapshot_id, - "branches": branches, + "branches": branches_d, "next_branch": last_branch, } diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -608,32 +608,34 @@ return {"snapshot:add": count} - def snapshot_missing(self, snapshots): + def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: for id in snapshots: if id not in self._snapshots: yield id - def snapshot_get(self, snapshot_id): + def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: return self.snapshot_get_branches(snapshot_id) - def snapshot_get_by_origin_visit(self, origin, visit): + def snapshot_get_by_origin_visit( + self, origin: str, visit: int + ) -> Optional[Dict[str, Any]]: origin_url = self._get_origin_url(origin) if not origin_url: - return + return None if origin_url not in self._origins or visit > len( self._origin_visits[origin_url] ): return None - visit = self._origin_visit_get_updated(origin_url, visit) - snapshot_id = visit["snapshot"] + visit_d = self._origin_visit_get_updated(origin_url, visit) + snapshot_id = visit_d["snapshot"] if snapshot_id: return self.snapshot_get(snapshot_id) else: return None - def snapshot_count_branches(self, snapshot_id): + def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Optional[Dict[str, int]]: snapshot = self._snapshots[snapshot_id] return collections.Counter( branch.target_type.value if branch else None @@ -641,8 +643,12 @@ ) def snapshot_get_branches( - self, snapshot_id, branches_from=b"", branches_count=1000, target_types=None - ): + self, + snapshot_id: Sha1Git, + branches_from: bytes = b"", + branches_count: int = 1000, + target_types: Optional[List[str]] = None, + ) -> Optional[Dict[str, Any]]: snapshot = self._snapshots.get(snapshot_id) if snapshot is None: return None @@ -651,7 +657,7 @@ from_index = bisect.bisect_left(sorted_branch_names, branches_from) if target_types: next_branch = None - branches = {} + branches: Dict = {} for (branch_name, branch) in sorted_branches: if branch_name in sorted_branch_names[from_index:]: if branch and branch.target_type.value in target_types: diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -677,11 +677,11 @@ ... @remote_api_endpoint("snapshot/missing") - def snapshot_missing(self, snapshots): + def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: """List snapshots missing from storage Args: - snapshots (iterable): an iterable of snapshot ids + snapshots: snapshot ids Yields: missing snapshot ids @@ -690,7 +690,7 @@ ... @remote_api_endpoint("snapshot") - def snapshot_get(self, snapshot_id): + def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: """Get the content, possibly partial, of a snapshot with the given id The branches of the snapshot are iterated in the lexicographical @@ -702,7 +702,8 @@ should be used instead. Args: - snapshot_id (bytes): identifier of the snapshot + snapshot_id: snapshot identifier + Returns: dict: a dict with three keys: * **id**: identifier of the snapshot @@ -715,7 +716,9 @@ ... @remote_api_endpoint("snapshot/by_origin_visit") - def snapshot_get_by_origin_visit(self, origin, visit): + def snapshot_get_by_origin_visit( + self, origin: str, visit: int + ) -> Optional[Dict[str, Any]]: """Get the content, possibly partial, of a snapshot for the given origin visit The branches of the snapshot are iterated in the lexicographical @@ -727,8 +730,9 @@ should be used instead. Args: - origin (int): the origin identifier - visit (int): the visit identifier + origin: origin identifier (url) + visit: the visit identifier + Returns: dict: None if the snapshot does not exist; a dict with three keys otherwise: @@ -743,37 +747,43 @@ ... @remote_api_endpoint("snapshot/count_branches") - def snapshot_count_branches(self, snapshot_id): + def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Optional[Dict[str, int]]: """Count the number of branches in the snapshot with the given id Args: - snapshot_id (bytes): identifier of the snapshot + snapshot_id: snapshot identifier Returns: - dict: A dict whose keys are the target types of branches and - values their corresponding amount + A dict whose keys are the target types of branches and values their + corresponding amount + """ ... @remote_api_endpoint("snapshot/get_branches") def snapshot_get_branches( - self, snapshot_id, branches_from=b"", branches_count=1000, target_types=None - ): + self, + snapshot_id: Sha1Git, + branches_from: bytes = b"", + branches_count: int = 1000, + target_types: Optional[List[str]] = None, + ) -> Optional[Dict[str, Any]]: """Get the content, possibly partial, of a snapshot with the given id The branches of the snapshot are iterated in the lexicographical order of their names. Args: - snapshot_id (bytes): identifier of the snapshot - branches_from (bytes): optional parameter used to skip branches + snapshot_id: identifier of the snapshot + branches_from: optional parameter used to skip branches whose name is lesser than it before returning them - branches_count (int): optional parameter used to restrain + branches_count: optional parameter used to restrain the amount of returned branches - target_types (list): optional parameter used to filter the + target_types: optional parameter used to filter the target types of branch to return (possible values that can be contained in that list are `'content', 'directory', 'revision', 'release', 'snapshot', 'alias'`) + Returns: dict: None if the snapshot does not exist; a dict with three keys otherwise: diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -742,19 +742,24 @@ @timed @db_transaction_generator() - def snapshot_missing(self, snapshots, db=None, cur=None): + def snapshot_missing( + self, snapshots: List[Sha1Git], db=None, cur=None + ) -> Iterable[Sha1Git]: for obj in db.snapshot_missing_from_list(snapshots, cur): yield obj[0] @timed @db_transaction(statement_timeout=2000) - def snapshot_get(self, snapshot_id, db=None, cur=None): - + def snapshot_get( + self, snapshot_id: Sha1Git, db=None, cur=None + ) -> Optional[Dict[str, Any]]: return self.snapshot_get_branches(snapshot_id, db=db, cur=cur) @timed @db_transaction(statement_timeout=2000) - def snapshot_get_by_origin_visit(self, origin, visit, db=None, cur=None): + def snapshot_get_by_origin_visit( + self, origin: str, visit: int, db=None, cur=None + ) -> Optional[Dict[str, Any]]: snapshot_id = db.snapshot_get_by_origin_visit(origin, visit, cur) if snapshot_id: @@ -764,20 +769,22 @@ @timed @db_transaction(statement_timeout=2000) - def snapshot_count_branches(self, snapshot_id, db=None, cur=None): + def snapshot_count_branches( + self, snapshot_id: Sha1Git, db=None, cur=None + ) -> Optional[Dict[str, int]]: return dict([bc for bc in db.snapshot_count_branches(snapshot_id, cur)]) @timed @db_transaction(statement_timeout=2000) def snapshot_get_branches( self, - snapshot_id, - branches_from=b"", - branches_count=1000, - target_types=None, + snapshot_id: Sha1Git, + branches_from: bytes = b"", + branches_count: int = 1000, + target_types: Optional[List[str]] = None, db=None, cur=None, - ): + ) -> Optional[Dict[str, Any]]: if snapshot_id == EMPTY_SNAPSHOT_ID: return { "id": snapshot_id,