def _missing(self, statement: PreparedStatement, table: str, ids: Iterable):
    """Return the elements of ``ids`` for which ``statement`` finds no row.

    The ids are first grouped with the driver's ``group_keys_by_replica``
    so that every query only carries ids that were grouped under the same
    replica; each group is then split into chunks of at most
    ``PARTITION_KEY_RESTRICTION_MAX_SIZE`` ids.

    Args:
        statement: prepared statement taking a single list-of-ids
            parameter and returning rows that expose an ``"id"`` column.
        table: name of the table (in ``self._keyspace``) that ``statement``
            queries; used to look up the partition key for the grouping.
        ids: ids to check. May be any iterable, including a one-shot one.

    Returns:
        A list of the ids that were not found, in their input order.
    """
    # ``ids`` is walked twice (grouping, then final filter); materialise it
    # so that a generator argument is not exhausted after the first pass.
    ids = list(ids)
    found_ids = set()

    # Partition keys are passed as tuples; these tables are keyed by a
    # single column, hence the 1-tuples.
    ids_per_replica = group_keys_by_replica(
        self._session, self._keyspace, table, [(id_,) for id_ in ids]
    )

    for replica_keys in ids_per_replica.values():
        # The driver hands back the keys exactly as they were given,
        # i.e. as 1-tuples: unwrap them before binding.
        replica_ids = [key[0] for key in replica_keys]
        # Chunk this replica's ids only (not the whole input), otherwise
        # every id would be queried once per replica group.
        for id_group in grouper(replica_ids, PARTITION_KEY_RESTRICTION_MAX_SIZE):
            rows = self._execute_with_retries(statement, [list(id_group)])
            found_ids.update(row["id"] for row in rows)

    return [id_ for id_ in ids if id_ not in found_ids]
@_prepared_exists_statement("revision")
def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
    """Delegate to ``_missing`` for the ``revision`` table.

    ``statement`` is presumably supplied by the decorator; confirm against
    ``_prepared_exists_statement``.
    """
    missing_ids = self._missing(statement, "revision", ids)
    return missing_ids


@_prepared_exists_statement("release")
def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
    """Delegate to ``_missing`` for the ``release`` table.

    ``statement`` is presumably supplied by the decorator; confirm against
    ``_prepared_exists_statement``.
    """
    missing_ids = self._missing(statement, "release", ids)
    return missing_ids


@_prepared_exists_statement("directory")
def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
    """Delegate to ``_missing`` for the ``directory`` table.

    ``statement`` is presumably supplied by the decorator; confirm against
    ``_prepared_exists_statement``.
    """
    missing_ids = self._missing(statement, "directory", ids)
    return missing_ids


@_prepared_exists_statement("snapshot")
def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
    """Delegate to ``_missing`` for the ``snapshot`` table.

    ``statement`` is presumably supplied by the decorator; confirm against
    ``_prepared_exists_statement``.
    """
    missing_ids = self._missing(statement, "snapshot", ids)
    return missing_ids