diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -478,11 +478,9 @@ def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.revision_missing(revisions) - def revision_get( - self, revisions: List[Sha1Git] - ) -> Iterable[Optional[Dict[str, Any]]]: - rows = self._cql_runner.revision_get(revisions) - revs = {} + def revision_get(self, revision_ids: List[Sha1Git]) -> List[Optional[Revision]]: + rows = self._cql_runner.revision_get(revision_ids) + revisions: Dict[Sha1Git, Revision] = {} for row in rows: # TODO: use a single query to get all parents? # (it might have lower latency, but requires more code and more @@ -492,10 +490,9 @@ # parent_rank is the clustering key, so results are already # sorted by rank. rev = converters.revision_from_db(row, parents=parents) - revs[rev.id] = rev.to_dict() + revisions[rev.id] = rev - for rev_id in revisions: - yield revs.get(rev_id) + return [revisions.get(rev_id) for rev_id in revision_ids] def _get_parent_revs( self, diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -490,16 +490,14 @@ ... @remote_api_endpoint("revision") - def revision_get( - self, revisions: List[Sha1Git] - ) -> Iterable[Optional[Dict[str, Any]]]: + def revision_get(self, revision_ids: List[Sha1Git]) -> List[Optional[Revision]]: """Get revisions from storage Args: - revisions: revision ids + revision_ids: revision ids - Yields: - revisions as dictionaries (or None if the revision doesn't exist) + Returns: + list of Revision objects, in the same order as revision_ids (None for each id with no matching revision) """ ...
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -610,16 +610,16 @@ yield obj[0] @timed - @db_transaction_generator(statement_timeout=1000) + @db_transaction(statement_timeout=1000) def revision_get( - self, revisions: List[Sha1Git], db=None, cur=None - ) -> Iterable[Optional[Dict[str, Any]]]: - for line in db.revision_get_from_list(revisions, cur): - data = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) - if not data: - yield None - continue - yield data.to_dict() + self, revision_ids: List[Sha1Git], db=None, cur=None + ) -> List[Optional[Revision]]: + revisions = [] + for line in db.revision_get_from_list(revision_ids, cur): + revision = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) + revisions.append(revision) + + return revisions @timed @db_transaction_generator(statement_timeout=2000) diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -902,13 +902,12 @@ assert add_result == {"revision:add": 2} # order 1 - res1 = swh_storage.revision_get([revision.id, revision2.id]) - - assert [Revision.from_dict(r) for r in res1] == [revision, revision2] + actual_revisions = swh_storage.revision_get([revision.id, revision2.id]) + assert actual_revisions == [revision, revision2] # order 2 - res2 = swh_storage.revision_get([revision2.id, revision.id]) - assert [Revision.from_dict(r) for r in res2] == [revision2, revision] + actual_revisions2 = swh_storage.revision_get([revision2.id, revision.id]) + assert actual_revisions2 == [revision2, revision] def test_revision_log(self, swh_storage, sample_data): revision1, revision2, revision3, revision4 = sample_data.revisions[:4] @@ -973,21 +972,19 @@ swh_storage.revision_add([revision]) - actual_revisions = list(swh_storage.revision_get([revision.id, revision2.id])) + 
actual_revisions = swh_storage.revision_get([revision.id, revision2.id]) assert len(actual_revisions) == 2 - assert Revision.from_dict(actual_revisions[0]) == revision - assert actual_revisions[1] is None + assert actual_revisions == [revision, None] def test_revision_get_no_parents(self, swh_storage, sample_data): revision = sample_data.revision swh_storage.revision_add([revision]) - get = list(swh_storage.revision_get([revision.id])) + actual_revision = swh_storage.revision_get([revision.id])[0] - assert len(get) == 1 assert revision.parents == () - assert tuple(get[0]["parents"]) == () # no parents on this one + assert actual_revision.parents == () # no parents on this one def test_revision_get_random(self, swh_storage, sample_data): revision1, revision2, revision3 = sample_data.revisions[:3] @@ -2464,10 +2461,10 @@ swh_storage.revision_add([revision, revision2]) # when getting added revisions - revisions = list(swh_storage.revision_get([revision.id, revision2.id])) + revisions = swh_storage.revision_get([revision.id, revision2.id]) # then check committers are the same - assert revisions[0]["committer"] == revisions[1]["committer"] + assert revisions[0].committer == revisions[1].committer def test_snapshot_add_get_empty(self, swh_storage, sample_data): empty_snapshot = sample_data.snapshots[1] diff --git a/swh/storage/tests/test_filter.py b/swh/storage/tests/test_filter.py --- a/swh/storage/tests/test_filter.py +++ b/swh/storage/tests/test_filter.py @@ -94,15 +94,15 @@ def test_filtering_proxy_storage_revision(swh_storage, sample_data): sample_revision = sample_data.revision - revision = next(swh_storage.revision_get([sample_revision.id])) - assert not revision + revision = swh_storage.revision_get([sample_revision.id])[0] + assert revision is None s = swh_storage.revision_add([sample_revision]) assert s == { "revision:add": 1, } - revision = next(swh_storage.revision_get([sample_revision.id])) + revision = swh_storage.revision_get([sample_revision.id])[0] 
assert revision is not None s = swh_storage.revision_add([sample_revision]) diff --git a/swh/storage/tests/test_retry.py b/swh/storage/tests/test_retry.py --- a/swh/storage/tests/test_retry.py +++ b/swh/storage/tests/test_retry.py @@ -620,16 +620,16 @@ """ sample_rev = sample_data.revision - revision = next(swh_storage.revision_get([sample_rev.id])) - assert not revision + revision = swh_storage.revision_get([sample_rev.id])[0] + assert revision is None s = swh_storage.revision_add([sample_rev]) assert s == { "revision:add": 1, } - revision = next(swh_storage.revision_get([sample_rev.id])) - assert revision["id"] == sample_rev.id + revision = swh_storage.revision_get([sample_rev.id])[0] + assert revision == sample_rev def test_retrying_proxy_storage_revision_add_with_retry( @@ -650,8 +650,8 @@ sample_rev = sample_data.revision - revision = next(swh_storage.revision_get([sample_rev.id])) - assert not revision + revision = swh_storage.revision_get([sample_rev.id])[0] + assert revision is None s = swh_storage.revision_add([sample_rev]) assert s == { @@ -674,8 +674,8 @@ sample_rev = sample_data.revision - revision = next(swh_storage.revision_get([sample_rev.id])) - assert not revision + revision = swh_storage.revision_get([sample_rev.id])[0] + assert revision is None with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.revision_add([sample_rev]) diff --git a/swh/storage/tests/test_revision_bw_compat.py b/swh/storage/tests/test_revision_bw_compat.py --- a/swh/storage/tests/test_revision_bw_compat.py +++ b/swh/storage/tests/test_revision_bw_compat.py @@ -6,7 +6,6 @@ import attr from swh.core.utils import decode_with_escape -from swh.model.model import Revision from swh.storage import get_storage from swh.storage.tests.test_postgresql import db_transaction @@ -44,4 +43,4 @@ assert metadata == bw_rev.metadata # check the Revision build from revision_get is the original, "new style", Revision - assert [Revision.from_dict(x) for x in 
storage.revision_get([rev.id])] == [rev] + assert storage.revision_get([rev.id]) == [rev]