diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -731,9 +731,7 @@
         else:
             return None
 
-    def origin_get_by_sha1(
-        self, sha1s: List[bytes]
-    ) -> Iterable[Optional[Dict[str, Any]]]:
+    def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
         results = []
         for sha1 in sha1s:
             rows = self._cql_runner.origin_get_by_sha1(sha1)
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -705,11 +705,8 @@
     def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]:
         return [self.origin_get_one(origin_url) for origin_url in origins]
 
-    def origin_get_by_sha1(
-        self, sha1s: List[bytes]
-    ) -> Iterable[Optional[Dict[str, Any]]]:
-        for sha1 in sha1s:
-            yield self._convert_origin(self._origins_by_sha1.get(sha1))
+    def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
+        return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s]
 
     def origin_list(
         self, page_token: Optional[str] = None, limit: int = 100
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -1022,16 +1022,14 @@
         ...
 
     @remote_api_endpoint("origin/get_sha1")
-    def origin_get_by_sha1(
-        self, sha1s: List[bytes]
-    ) -> Iterable[Optional[Dict[str, Any]]]:
+    def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
         """Return origins, identified by the sha1 of their URLs.
 
         Args:
             sha1s: a list of sha1s
 
-        List:
-            Origins whose sha1 of their url match, None when the origins is not found.
+        Returns:
+            List of origin dicts whose url sha1 matches; None for sha1s not found.
         """
         ...
 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1107,15 +1107,16 @@
         return result
 
     @timed
-    @db_transaction_generator(statement_timeout=500)
+    @db_transaction(statement_timeout=500)
     def origin_get_by_sha1(
         self, sha1s: List[bytes], db=None, cur=None
-    ) -> Iterable[Optional[Dict[str, Any]]]:
-        for line in db.origin_get_by_sha1(sha1s, cur):
-            if line[0] is not None:
-                yield dict(zip(db.origin_cols, line))
-            else:
-                yield None
+    ) -> List[Optional[Dict[str, Any]]]:
+        # Rows for sha1s with no matching origin have None in the first column;
+        # a tuple of Nones is still truthy, so check that column, not the row.
+        return [
+            dict(zip(db.origin_cols, row)) if row[0] is not None else None
+            for row in db.origin_get_by_sha1(sha1s, cur)
+        ]
 
     @timed
     @db_transaction_generator()