storage: make origin_get_by_sha1 return a List instead of an Iterable

The cassandra, in-memory and interface signatures of origin_get_by_sha1 are
re-annotated as returning List[Optional[Dict[str, Any]]]; the in-memory
implementation is turned from a generator into a list comprehension.

diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -730,9 +730,7 @@
         else:
             return None
 
-    def origin_get_by_sha1(
-        self, sha1s: List[bytes]
-    ) -> Iterable[Optional[Dict[str, Any]]]:
+    def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
         results = []
         for sha1 in sha1s:
             rows = self._cql_runner.origin_get_by_sha1(sha1)
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -705,11 +705,8 @@
     def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]:
         return [self.origin_get_one(origin_url) for origin_url in origins]
 
-    def origin_get_by_sha1(
-        self, sha1s: List[bytes]
-    ) -> Iterable[Optional[Dict[str, Any]]]:
-        for sha1 in sha1s:
-            yield self._convert_origin(self._origins_by_sha1.get(sha1))
+    def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
+        return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s]
 
     def origin_list(
         self, page_token: Optional[str] = None, limit: int = 100
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -1022,16 +1022,14 @@
         ...
 
     @remote_api_endpoint("origin/get_sha1")
-    def origin_get_by_sha1(
-        self, sha1s: List[bytes]
-    ) -> Iterable[Optional[Dict[str, Any]]]:
+    def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
         """Return origins, identified by the sha1 of their URLs.
 
         Args:
             sha1s: a list of sha1s
 
-        List:
-            Origins whose sha1 of their url match, None when the origins is not found.
+        Returns:
+            List of origin dicts whose url sha1 matches; None for each sha1 not found.
 
         """
         ...
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1107,15 +1107,15 @@
         return result
 
     @timed
-    @db_transaction_generator(statement_timeout=500)
+    @db_transaction(statement_timeout=500)
     def origin_get_by_sha1(
         self, sha1s: List[bytes], db=None, cur=None
-    ) -> Iterable[Optional[Dict[str, Any]]]:
-        for line in db.origin_get_by_sha1(sha1s, cur):
-            if line[0] is not None:
-                yield dict(zip(db.origin_cols, line))
-            else:
-                yield None
+    ) -> List[Optional[Dict[str, Any]]]:
+        return [
+            # Only a None first column marks a miss; other falsy values are hits.
+            dict(zip(db.origin_cols, row)) if row[0] is not None else None
+            for row in db.origin_get_by_sha1(sha1s, cur)
+        ]
 
     @timed
     @db_transaction_generator()
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1595,9 +1595,9 @@
         assert origins[0]["url"] == origin.url
 
     def test_origin_get_by_sha1_not_found(self, swh_storage, sample_data):
-        origin = sample_data.origin
-        assert swh_storage.origin_get([origin.url])[0] is None
-        origins = list(swh_storage.origin_get_by_sha1([sha1(origin.url)]))
+        unknown_origin = sample_data.origin
+        assert swh_storage.origin_get([unknown_origin.url])[0] is None
+        origins = list(swh_storage.origin_get_by_sha1([sha1(unknown_origin.url)]))
         assert len(origins) == 1
         assert origins[0] is None
 