diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -692,14 +692,14 @@ else: return None - def origin_get_by_sha1(self, sha1s): + def origin_get_by_sha1( + self, sha1s: List[bytes] + ) -> Iterable[Optional[Dict[str, Any]]]: results = [] for sha1 in sha1s: rows = self._cql_runner.origin_get_by_sha1(sha1) - if rows: - results.append({"url": rows.one().url}) - else: - results.append(None) + origin = {"url": rows.one().url} if rows else None + results.append(origin) return results def origin_list( diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -672,8 +672,11 @@ def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: return [self.origin_get_one(origin_url) for origin_url in origins] - def origin_get_by_sha1(self, sha1s): - return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s] + def origin_get_by_sha1( + self, sha1s: List[bytes] + ) -> Iterable[Optional[Dict[str, Any]]]: + for sha1 in sha1s: + yield self._convert_origin(self._origins_by_sha1.get(sha1)) def origin_list( self, page_token: Optional[str] = None, limit: int = 100 diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -1002,16 +1002,16 @@ ... @remote_api_endpoint("origin/get_sha1") - def origin_get_by_sha1(self, sha1s): + def origin_get_by_sha1( + self, sha1s: List[bytes] + ) -> Iterable[Optional[Dict[str, Any]]]: """Return origins, identified by the sha1 of their URLs. Args: - sha1s (list[bytes]): a list of sha1s + sha1s: a list of sha1s - Yields: - dicts containing origin information as returned - by :meth:`swh.storage.storage.Storage.origin_get`, or None if an - origin matching the sha1 is not found. + Returns: + Origins whose URL sha1 matches one of ``sha1s``, or None when no origin is found. """ ... 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1084,7 +1084,9 @@ @timed @db_transaction_generator(statement_timeout=500) - def origin_get_by_sha1(self, sha1s, db=None, cur=None): + def origin_get_by_sha1( + self, sha1s: List[bytes], db=None, cur=None + ) -> Iterable[Optional[Dict[str, Any]]]: for line in db.origin_get_by_sha1(sha1s, cur): if line[0] is not None: yield dict(zip(db.origin_cols, line))