diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -582,17 +582,14 @@ def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.release_missing(releases) - def release_get( - self, releases: List[Sha1Git] - ) -> Iterable[Optional[Dict[str, Any]]]: + def release_get(self, releases: List[Sha1Git]) -> List[Optional[Release]]: rows = self._cql_runner.release_get(releases) - rels = {} + rels: Dict[Sha1Git, Release] = {} for row in rows: release = converters.release_from_db(row) - rels[row.id] = release.to_dict() + rels[row.id] = release - for rel_id in releases: - yield rels.get(rel_id) + return [rels.get(rel_id) for rel_id in releases] def release_get_random(self) -> Sha1Git: release = self._cql_runner.release_get_random() diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -588,17 +588,15 @@ ... @remote_api_endpoint("release") - def release_get( - self, releases: List[Sha1Git] - ) -> Iterable[Optional[Dict[str, Any]]]: + def release_get(self, releases: List[Sha1Git]) -> List[Optional[Release]]: """Given a list of sha1, return the releases's information Args: releases: list of sha1s - Yields: - dicts with the same keys as those given to `release_add` - (or ``None`` if a release does not exist) + Returns: + List of releases matching the identifiers or None if the release does + not exist. """ ... diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -686,13 +686,15 @@ yield obj[0] @timed - @db_transaction_generator(statement_timeout=500) + @db_transaction(statement_timeout=500) def release_get( self, releases: List[Sha1Git], db=None, cur=None - ) -> Iterable[Optional[Dict[str, Any]]]: + ) -> List[Optional[Release]]: + rels = [] for release in db.release_get_from_list(releases, cur): data = converters.db_to_release(dict(zip(db.release_get_cols, release))) - yield data.to_dict() if data else None + rels.append(data if data else None) + return rels @timed @db_transaction() diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -29,7 +29,6 @@ OriginVisit, OriginVisitStatus, Person, - Release, Revision, Snapshot, TargetType, @@ -1076,13 +1075,12 @@ swh_storage.release_add([release, release2]) # when - releases = list(swh_storage.release_get([release.id, release2.id])) - actual_releases = [Release.from_dict(r) for r in releases] + actual_releases = swh_storage.release_get([release.id, release2.id]) # then assert actual_releases == [release, release2] - unknown_releases = list(swh_storage.release_get([release3.id])) + unknown_releases = swh_storage.release_get([release3.id]) assert unknown_releases[0] is None def test_release_get_order(self, swh_storage, sample_data): @@ -1092,12 +1090,12 @@ assert add_result == {"release:add": 2} # order 1 - res1 = swh_storage.release_get([release.id, release2.id]) - assert list(res1) == [release.to_dict(), release2.to_dict()] + actual_releases = swh_storage.release_get([release.id, release2.id]) + assert actual_releases == [release, release2] # order 2 - res2 = swh_storage.release_get([release2.id, release.id]) - assert list(res2) == [release2.to_dict(), release.to_dict()] + actual_releases2 = swh_storage.release_get([release2.id, release.id]) + assert actual_releases2 == [release2, release] def test_release_get_random(self, swh_storage, sample_data): release, release2, release3 = sample_data.releases[:3] diff --git a/swh/storage/tests/test_retry.py b/swh/storage/tests/test_retry.py --- a/swh/storage/tests/test_retry.py +++ b/swh/storage/tests/test_retry.py @@ -689,16 +689,16 @@ """ sample_rel = sample_data.release - release = next(swh_storage.release_get([sample_rel.id])) - assert not release + release = swh_storage.release_get([sample_rel.id])[0] + assert release is None s = swh_storage.release_add([sample_rel]) assert s == { "release:add": 1, } - release = next(swh_storage.release_get([sample_rel.id])) - assert release["id"] == sample_rel.id + release = swh_storage.release_get([sample_rel.id])[0] + assert release == sample_rel def test_retrying_proxy_storage_release_add_with_retry( @@ -719,8 +719,8 @@ sample_rel = sample_data.release - release = next(swh_storage.release_get([sample_rel.id])) - assert not release + release = swh_storage.release_get([sample_rel.id])[0] + assert release is None s = swh_storage.release_add([sample_rel]) assert s == { @@ -743,8 +743,8 @@ sample_rel = sample_data.release - release = next(swh_storage.release_get([sample_rel.id])) - assert not release + release = swh_storage.release_get([sample_rel.id])[0] + assert release is None with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.release_add([sample_rel])