diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -299,8 +299,9 @@
     )
     def _execute_many_with_retries(
         self, statement, args_list: List[Tuple]
-    ) -> ResultSet:
-        return execute_concurrent_with_args(self._session, statement, args_list)
+    ) -> Iterable[BaseRow]:
+        for res in execute_concurrent_with_args(self._session, statement, args_list):
+            yield from res.result_or_exc

     def _add_one(self, statement, obj: BaseRow) -> None:
         self._execute_with_retries(statement, dataclasses.astuple(obj))
@@ -308,8 +309,10 @@
     def _add_many(self, statement, objs: Sequence[BaseRow]) -> None:
         tables = {obj.TABLE for obj in objs}
         assert len(tables) == 1, f"Cannot insert to multiple tables: {tables}"
-        (table,) = tables
-        self._execute_many_with_retries(statement, list(map(dataclasses.astuple, objs)))
+        rows = list(map(dataclasses.astuple, objs))
+        for _ in self._execute_many_with_retries(statement, rows):
+            # Need to consume the generator to actually run the INSERTs
+            pass

 _T = TypeVar("_T", bound=BaseRow)
@@ -475,8 +478,8 @@
         """
         self._execute_with_retries(query, [content.get_hash(algo), token])

-    def content_get_tokens_from_single_hash(
-        self, algo: str, hash_: bytes
+    def content_get_tokens_from_single_algo(
+        self, algo: str, hashes: List[bytes]
     ) -> Iterable[int]:
         assert algo in HASH_ALGORITHMS
         query = f"""
@@ -485,7 +488,10 @@
             WHERE {algo} = %s
         """
         return (
-            row["target_token"] for row in self._execute_with_retries(query, [hash_])
+            row["target_token"]  # type: ignore
+            for row in self._execute_many_with_retries(
+                query, [(hash_,) for hash_ in hashes]
+            )
         )

     ##########################
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -157,13 +157,15 @@

         return True

-    def _content_get_from_hash(self, algo, hash_) -> Iterable:
+    def _content_get_from_hashes(self, algo, hashes: List[bytes]) -> Iterable:
         """From the name of a hash algorithm and a value of that hash,
         looks up the "hash -> token" secondary table (content_by_{algo})
         to get tokens.
         Then, looks up the main table (content) to get all contents with
         that token, and filters out contents whose hash doesn't match."""
-        found_tokens = self._cql_runner.content_get_tokens_from_single_hash(algo, hash_)
+        found_tokens = self._cql_runner.content_get_tokens_from_single_algo(
+            algo, hashes
+        )

         for token in found_tokens:
             assert isinstance(token, int), found_tokens
@@ -172,7 +174,7 @@

             for row in res:
                 # re-check the hash (in case of murmur3 collision)
-                if getattr(row, algo) == hash_:
+                if getattr(row, algo) in hashes:
                     yield row

     def _content_add(self, contents: List[Content], with_data: bool) -> Dict[str, int]:
@@ -216,7 +218,8 @@
             collisions = []
             # Get tokens of 'content' rows with the same value for
             # sha1/sha1_git
-            rows = self._content_get_from_hash(algo, content.get_hash(algo))
+            # TODO: batch these requests, instead of sending them one by one
+            rows = self._content_get_from_hashes(algo, [content.get_hash(algo)])
             for row in rows:
                 if getattr(row, algo) != content.get_hash(algo):
                     # collision of token (partition key), ignore this
@@ -334,42 +337,53 @@
         key = operator.attrgetter(algo)

         contents_by_hash: Dict[Sha1, Optional[Content]] = {}
-        for hash_ in contents:
+        for row in self._content_get_from_hashes(algo, contents):
             # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1/sha1_git
             # matches the argument, from the index table ('content_by_*')
-            for row in self._content_get_from_hash(algo, hash_):
-                row_d = row.to_dict()
-                row_d.pop("ctime")
-                content = Content(**row_d)
-                contents_by_hash[key(content)] = content
+            row_d = row.to_dict()
+            row_d.pop("ctime")
+            content = Content(**row_d)
+            contents_by_hash[key(content)] = content

         return [contents_by_hash.get(hash_) for hash_ in contents]

     @timed
     def content_find(self, content: Dict[str, Any]) -> List[Content]:
+        return self._content_find_many([content])
+
+    def _content_find_many(self, contents: List[Dict[str, Any]]) -> List[Content]:
         # Find an algorithm that is common to all the requested contents.
         # It will be used to do an initial filtering efficiently.
-        filter_algos = list(set(content).intersection(HASH_ALGORITHMS))
+        filter_algos = set(HASH_ALGORITHMS)
+        for content in contents:
+            filter_algos &= set(content)
         if not filter_algos:
             raise StorageArgumentException(
                 "content keys must contain at least one "
                 f"of: {', '.join(sorted(HASH_ALGORITHMS))}"
             )
-        common_algo = filter_algos[0]
+        common_algo = list(filter_algos)[0]

         results = []
-        rows = self._content_get_from_hash(common_algo, content[common_algo])
+        rows = self._content_get_from_hashes(
+            common_algo, [content[common_algo] for content in contents]
+        )
         for row in rows:
             # Re-check all the hashes, in case of collisions (either of the
             # hash of the partition key, or the hashes in it)
-            for algo in HASH_ALGORITHMS:
-                if content.get(algo) and getattr(row, algo) != content[algo]:
-                    # This hash didn't match; discard the row.
+            for content in contents:
+                for algo in HASH_ALGORITHMS:
+                    if content.get(algo) and getattr(row, algo) != content[algo]:
+                        # This hash didn't match; discard the row.
+                        break
+                else:
+                    # All hashes match, keep this row.
+                    row_d = row.to_dict()
+                    row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc)
+                    results.append(Content(**row_d))
                     break
             else:
-                # All hashes match, keep this row.
-                row_d = row.to_dict()
-                row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc)
-                results.append(Content(**row_d))
+                # No content matched; skip it
+                pass

         return results

     @timed
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -237,10 +237,11 @@
     def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
         self._content_indexes[algo][content.get_hash(algo)].add(token)

-    def content_get_tokens_from_single_hash(
-        self, algo: str, hash_: bytes
+    def content_get_tokens_from_single_algo(
+        self, algo: str, hashes: List[bytes]
     ) -> Iterable[int]:
-        return self._content_indexes[algo][hash_]
+        for hash_ in hashes:
+            yield from self._content_indexes[algo][hash_]

     ##########################
     # 'skipped_content' table
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -276,14 +276,14 @@
     cont, cont2 = sample_data.contents[:2]

     # always return a token
-    def mock_cgtfsh(algo, hash_):
+    def mock_cgtfsa(algo, hashes):
         nonlocal called
         called += 1
         assert algo in ("sha1", "sha1_git")
         return [123456]

     mocker.patch.object(
-        swh_storage._cql_runner, "content_get_tokens_from_single_hash", mock_cgtfsh,
+        swh_storage._cql_runner, "content_get_tokens_from_single_algo", mock_cgtfsa,
     )

     # For all tokens, always return cont
@@ -324,14 +324,14 @@
     cont, cont2 = [attr.evolve(c, ctime=now()) for c in sample_data.contents[:2]]

     # always return a token
-    def mock_cgtfsh(algo, hash_):
+    def mock_cgtfsa(algo, hashes):
         nonlocal called
         called += 1
         assert algo in ("sha1", "sha1_git")
         return [123456]

     mocker.patch.object(
-        swh_storage._cql_runner, "content_get_tokens_from_single_hash", mock_cgtfsh,
+        swh_storage._cql_runner, "content_get_tokens_from_single_algo", mock_cgtfsa,
     )

     # For all tokens, always return cont and cont2
@@ -369,14 +369,14 @@
     cont, cont2 = [attr.evolve(c, ctime=now()) for c in sample_data.contents[:2]]

     # always return a token
-    def mock_cgtfsh(algo, hash_):
+    def mock_cgtfsa(algo, hashes):
         nonlocal called
         called += 1
         assert algo in ("sha1", "sha1_git")
         return [123456]

     mocker.patch.object(
-        swh_storage._cql_runner, "content_get_tokens_from_single_hash", mock_cgtfsh,
+        swh_storage._cql_runner, "content_get_tokens_from_single_algo", mock_cgtfsa,
     )

     # For all tokens, always return cont and cont2
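
Note for reviewers: the new _execute_many_with_retries builds on cassandra-driver's
execute_concurrent_with_args, which returns one ExecutionResult per parameter tuple;
its result_or_exc field holds the rows on success (with the driver's default
raise_on_first_error=True, a failed query raises instead of being returned). Below is
a minimal sketch of that consumption pattern against a hypothetical session and table
(the keyspace, table, and addresses are illustrative, not part of this patch). It also
shows why _add_many must iterate over the result: since _execute_many_with_retries is
now a generator, execute_concurrent_with_args is not even called until iteration starts.

    from cassandra.cluster import Cluster
    from cassandra.concurrent import execute_concurrent_with_args

    # Hypothetical setup; any table with a queryable key column works.
    session = Cluster(["127.0.0.1"]).connect("swh_test")
    statement = session.prepare(
        "SELECT target_token FROM content_by_sha1 WHERE sha1 = ?"
    )
    hashes = [b"\x00" * 20, b"\x01" * 20]

    # One ExecutionResult per parameter tuple, in the same order as the input.
    for res in execute_concurrent_with_args(
        session, statement, [(h,) for h in hashes]
    ):
        # result_or_exc is the ResultSet of one per-hash query.
        for row in res.result_or_exc:
            print(row.target_token)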
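
The nested for/else in _content_find_many is the subtle part of this change: an else
attached to a for only runs when the loop finished without hitting break. A standalone
toy mirroring that control flow (data and names are illustrative only):

    rows = [
        {"sha1": b"a", "sha256": b"x"},
        {"sha1": b"b", "sha256": b"y"},
    ]
    queries = [{"sha1": b"b"}]  # one requested content, filtered on sha1 only

    results = []
    for row in rows:
        for query in queries:
            for algo in ("sha1", "sha256"):
                if query.get(algo) and row[algo] != query[algo]:
                    break  # a requested hash mismatches; try the next query
            else:
                # no break: every hash present in the query matched this row
                results.append(row)
                break  # row kept; no need to check the remaining queries
        else:
            pass  # no query matched this row; drop it

    assert results == [{"sha1": b"b", "sha256": b"y"}]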