diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -299,7 +299,7 @@ ) def _execute_many_with_retries( self, statement, args_list: List[Tuple] - ) -> Iterable[BaseRow]: + ) -> Iterable[Dict[str, Any]]: for res in execute_concurrent_with_args(self._session, statement, args_list): yield from res.result_or_exc @@ -424,8 +424,11 @@ @_prepared_select_statement( ContentRow, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) = ?" ) - def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]: - return map(ContentRow.from_dict, self._execute_with_retries(statement, [token])) + def content_get_from_tokens(self, tokens, *, statement) -> Iterable[ContentRow]: + return map( + ContentRow.from_dict, + self._execute_many_with_retries(statement, [(token,) for token in tokens]), + ) @_prepared_select_statement( ContentRow, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) > ? LIMIT 1" diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -163,19 +163,17 @@ to get tokens. Then, looks up the main table (content) to get all contents with that token, and filters out contents whose hash doesn't match.""" - found_tokens = self._cql_runner.content_get_tokens_from_single_algo( - algo, hashes + found_tokens = list( + self._cql_runner.content_get_tokens_from_single_algo(algo, hashes) ) + assert all(isinstance(token, int) for token in found_tokens) - for token in found_tokens: - assert isinstance(token, int), found_tokens - # Query the main table ('content'). - res = self._cql_runner.content_get_from_token(token) - - for row in res: - # re-check the the hash (in case of murmur3 collision) - if getattr(row, algo) in hashes: - yield row + # Query the main table ('content'). + rows = self._cql_runner.content_get_from_tokens(found_tokens) + for row in rows: + # re-check the the hash (in case of murmur3 collision) + if getattr(row, algo) in hashes: + yield row def _content_add(self, contents: List[Content], with_data: bool) -> Dict[str, int]: # Filter-out content already in the database. diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -198,8 +198,8 @@ primary_key = self._contents.primary_key_from_dict(content_hashes) return self._contents.get_from_primary_key(primary_key) - def content_get_from_token(self, token: int) -> Iterable[ContentRow]: - return self._contents.get_from_token(token) + def content_get_from_tokens(self, tokens: List[int]) -> Iterable[ContentRow]: + return itertools.chain.from_iterable(map(self._contents.get_from_token, tokens)) def content_get_random(self) -> Optional[ContentRow]: return self._contents.get_random() diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py --- a/swh/storage/tests/test_cassandra.py +++ b/swh/storage/tests/test_cassandra.py @@ -287,7 +287,7 @@ ) # For all tokens, always return cont - def mock_cgft(token): + def mock_cgft(tokens): nonlocal called called += 1 return [ @@ -300,7 +300,7 @@ ] mocker.patch.object( - swh_storage._cql_runner, "content_get_from_token", mock_cgft + swh_storage._cql_runner, "content_get_from_tokens", mock_cgft ) actual_result = swh_storage.content_add([cont2]) @@ -337,7 +337,7 @@ # For all tokens, always return cont and cont2 cols = list(set(cont.to_dict()) - {"data"}) - def mock_cgft(token): + def mock_cgft(tokens): nonlocal called called += 1 return [ @@ -346,7 +346,7 @@ ] mocker.patch.object( - swh_storage._cql_runner, "content_get_from_token", mock_cgft + swh_storage._cql_runner, "content_get_from_tokens", mock_cgft ) actual_result = swh_storage.content_get([cont.sha1]) @@ -382,7 +382,7 @@ # For all tokens, always return cont and cont2 cols = list(set(cont.to_dict()) - {"data"}) - def mock_cgft(token): + def mock_cgft(tokens): nonlocal called called += 1 return [ @@ -391,7 +391,7 @@ ] mocker.patch.object( - swh_storage._cql_runner, "content_get_from_token", mock_cgft + swh_storage._cql_runner, "content_get_from_tokens", mock_cgft ) expected_content = attr.evolve(cont, data=None) diff --git a/swh/storage/tests/test_cassandra_migration.py b/swh/storage/tests/test_cassandra_migration.py --- a/swh/storage/tests/test_cassandra_migration.py +++ b/swh/storage/tests/test_cassandra_migration.py @@ -88,11 +88,12 @@ ContentRowWithXor, f"WHERE token({', '.join(ContentRowWithXor.PARTITION_KEY)}) = ?", ) - def content_get_from_token( - self, token, *, statement + def content_get_from_tokens( + self, tokens, *, statement ) -> Iterable[ContentRowWithXor]: return map( - ContentRowWithXor.from_dict, self._execute_with_retries(statement, [token]) + ContentRowWithXor.from_dict, + self._execute_many_with_retries(statement, [(token,) for token in tokens]), ) # Redecorate content_add_prepare with the new ContentRow class @@ -219,12 +220,12 @@ ContentRowWithXorPK, f"WHERE token({', '.join(ContentRowWithXorPK.PARTITION_KEY)}) = ?", ) - def content_get_from_token( - self, token, *, statement + def content_get_from_tokens( + self, tokens, *, statement ) -> Iterable[ContentRowWithXorPK]: return map( ContentRowWithXorPK.from_dict, - self._execute_with_retries(statement, [token]), + self._execute_many_with_retries(statement, [(token,) for token in tokens]), ) # Redecorate content_add_prepare with the new ContentRow class