diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -8,14 +8,14 @@
 import logging
 import random
 from typing import (
-    Any, Callable, Dict, Generator, Iterable, List, Optional, TypeVar
+    Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar
 )
 
 from cassandra import CoordinationFailure
 from cassandra.cluster import (
     Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet)
 from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
-from cassandra.query import PreparedStatement
+from cassandra.query import PreparedStatement, BoundStatement
 from tenacity import (
     retry, stop_after_attempt, wait_random_exponential,
     retry_if_exception_type,
@@ -173,9 +173,37 @@
         'sha1', 'sha1_git', 'sha256', 'blake2s256', 'length', 'ctime',
         'status']
 
+    def _content_add_finalize(self, statement: BoundStatement) -> None:
+        """Returned (curried) by content_add_prepare, to be called when the
+        content row should be added to the primary table."""
+        self._execute_with_retries(statement, None)
+        self._increment_counter('content', 1)
+
     @_prepared_insert_statement('content', _content_keys)
-    def content_add_one(self, content, *, statement) -> None:
-        self._add_one(statement, 'content', content, self._content_keys)
+    def content_add_prepare(
+            self, content, *, statement) -> Tuple[int, Callable[[], None]]:
+        """Prepares insertion of a Content to the main 'content' table.
+        Returns a token (to be used in secondary tables), and a function to
+        be called to perform the insertion in the main table."""
+        statement = statement.bind([
+            getattr(content, key) for key in self._content_keys])
+
+        # Type used for hashing keys (usually, it will be
+        # cassandra.metadata.Murmur3Token)
+        token_class = self._cluster.metadata.token_map.token_class
+
+        # Token of the row when it will be inserted. This is equivalent to
+        # "SELECT token({', '.join(self._content_pk)}) FROM content WHERE ..."
+        # after the row is inserted; but we need the token to insert in the
+        # index tables *before* inserting to the main 'content' table
+        token = token_class.from_key(statement.routing_key).value
+        assert TOKEN_BEGIN <= token <= TOKEN_END
+
+        # Function to be called after the indexes contain their respective
+        # row
+        finalizer = functools.partial(self._content_add_finalize, statement)
+
+        return (token, finalizer)
 
     @_prepared_statement('SELECT * FROM content WHERE ' +
                          ' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS)))
@@ -190,6 +218,12 @@
         else:
             return None
 
+    @_prepared_statement('SELECT * FROM content WHERE token(' +
+                         ', '.join(_content_pk) +
+                         ') = ?')
+    def content_get_from_token(self, token, *, statement) -> Iterable[Row]:
+        return self._execute_with_retries(statement, [token])
+
     @_prepared_statement('SELECT * FROM content WHERE token(%s) > ? LIMIT 1'
                          % ', '.join(_content_pk))
     def content_get_random(self, *, statement) -> Optional[Row]:
@@ -213,19 +247,21 @@
             self, ids: List[bytes], *, statement) -> List[bytes]:
         return self._missing(statement, ids)
 
-    def content_index_add_one(self, main_algo: str, content: Content) -> None:
-        query = 'INSERT INTO content_by_{algo} ({cols}) VALUES ({values})' \
-            .format(algo=main_algo, cols=', '.join(self._content_pk),
-                    values=', '.join('%s' for _ in self._content_pk))
+    def content_index_add_one(
+            self, algo: str, content: Content, token: int) -> None:
+        """Adds a row mapping content[algo] to the token of the Content in
+        the main 'content' table."""
+        query = (
+            f'INSERT INTO content_by_{algo} ({algo}, target_token) '
+            f'VALUES (%s, %s)')
         self._execute_with_retries(
-            query, [content.get_hash(algo) for algo in self._content_pk])
+            query, [content.get_hash(algo), token])
 
-    def content_get_pks_from_single_hash(
-            self, algo: str, hash_: bytes) -> List[Row]:
+    def content_get_tokens_from_single_hash(
+            self, algo: str, hash_: bytes) -> Iterable[int]:
         assert algo in HASH_ALGORITHMS
-        query = 'SELECT * FROM content_by_{algo} WHERE {algo} = %s'.format(
-            algo=algo)
-        return list(self._execute_with_retries(query, [hash_]))
+        query = f'SELECT target_token FROM content_by_{algo} WHERE {algo} = %s'
+        return (tok for (tok,) in self._execute_with_retries(query, [hash_]))
 
     ##########################
     # 'skipped_content' table
diff --git a/swh/storage/cassandra/schema.py b/swh/storage/cassandra/schema.py
--- a/swh/storage/cassandra/schema.py
+++ b/swh/storage/cassandra/schema.py
@@ -192,12 +192,11 @@
 '''.split('\n\n')
 
 CONTENT_INDEX_TEMPLATE = '''
+-- Secondary table, used for looking up "content" from a single hash
 CREATE TABLE IF NOT EXISTS content_by_{main_algo} (
-    sha1 blob,
-    sha1_git blob,
-    sha256 blob,
-    blake2s256 blob,
-    PRIMARY KEY (({main_algo}), {other_algos})
+    {main_algo} blob,
+    target_token bigint, -- value of token(pk) on the "primary" table
+    PRIMARY KEY (({main_algo}), target_token)
 );
 
 CREATE TABLE IF NOT EXISTS skipped_content_by_{main_algo} (
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -26,7 +26,6 @@
 from .common import TOKEN_BEGIN, TOKEN_END
 from .converters import (
     revision_to_db, revision_from_db, release_to_db, release_from_db,
-    row_to_content_hashes,
 )
 from .cql import CqlRunner
 from .schema import HASH_ALGORITHMS
@@ -52,6 +51,24 @@
 
         return True
 
+    def _content_get_from_hash(self, algo, hash_) -> Iterable:
+        """From the name of a hash algorithm and a value of that hash,
+        looks up the "hash -> token" secondary table (content_by_{algo})
+        to get tokens.
+        Then, looks up the main table (content) to get all contents with
+        that token, and filters out contents whose hash doesn't match."""
+        found_tokens = self._cql_runner.content_get_tokens_from_single_hash(
+            algo, hash_)
+
+        for token in found_tokens:
+            # Query the main table ('content').
+            res = self._cql_runner.content_get_from_token(token)
+
+            for row in res:
+                # re-check the hash (in case of murmur3 collision)
+                if getattr(row, algo) == hash_:
+                    yield row
+
     def _content_add(self, contents: List[Content], with_data: bool) -> Dict:
         # Filter-out content already in the database.
         contents = [c for c in contents
@@ -74,31 +91,47 @@
 
         for content in contents:
             content_add += 1
 
-            # Then add to index tables
-            for algo in HASH_ALGORITHMS:
-                self._cql_runner.content_index_add_one(algo, content)
-
-            # Then to the main table
-            self._cql_runner.content_add_one(content)
-
-            # Note that we check for collisions *after* inserting. This
-            # differs significantly from the pgsql storage, but checking
-            # before insertion does not provide any guarantee in case
-            # another thread inserts the colliding hash at the same time.
+            # Check for sha1 or sha1_git collisions. This test is not atomic
+            # with the insertion, so it won't detect a collision if both
+            # contents are inserted at the same time, but it's good enough.
             #
             # The proper way to do it would probably be a BATCH, but this
            # would be inefficient because of the number of partitions we
             # need to affect (len(HASH_ALGORITHMS)+1, which is currently 5)
             for algo in {'sha1', 'sha1_git'}:
-                pks = self._cql_runner.content_get_pks_from_single_hash(
+                collisions = []
+                # Get tokens of 'content' rows with the same value for
+                # sha1/sha1_git
+                rows = self._content_get_from_hash(
                     algo, content.get_hash(algo))
-                if len(pks) > 1:
-                    # There are more than the one we just inserted.
-                    colliding_content_hashes = [
-                        row_to_content_hashes(pk) for pk in pks
-                    ]
+                for row in rows:
+                    if getattr(row, algo) != content.get_hash(algo):
+                        # collision of token(partition key), ignore this
+                        # row
+                        continue
+
+                    if any(getattr(row, algo2) != content.get_hash(algo2)
+                           for algo2 in HASH_ALGORITHMS):
+                        # At least one hash of this row does not match;
+                        # record it as a colliding content.
+                        collisions.append({
+                            algo2: getattr(row, algo2)
+                            for algo2 in HASH_ALGORITHMS})
+
+                if collisions:
+                    collisions.append(content.hashes())
                     raise HashCollision(
-                        algo, content.get_hash(algo), colliding_content_hashes)
+                        algo, content.get_hash(algo), collisions)
+
+            (token, insertion_finalizer) = \
+                self._cql_runner.content_add_prepare(content)
+
+            # Then add to index tables
+            for algo in HASH_ALGORITHMS:
+                self._cql_runner.content_index_add_one(algo, content, token)
+
+            # Then to the main table
+            insertion_finalizer()
 
         summary = {
             'content:add': content_add,
@@ -167,22 +199,10 @@
         for sha1 in contents:
             # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1
             # matches the argument, from the index table ('content_by_sha1')
-            pks = self._cql_runner.content_get_pks_from_single_hash(
-                'sha1', sha1)
-
-            if pks:
-                # TODO: what to do if there are more than one?
-                pk = pks[0]
-
-                # Query the main table ('content')
-                res = self._cql_runner.content_get_from_pk(pk._asdict())
-
-                # Rows in 'content' are inserted after corresponding
-                # rows in 'content_by_*', so we might be missing it
-                if res:
-                    content_metadata = res._asdict()
-                    content_metadata.pop('ctime')
-                    result[content_metadata['sha1']].append(content_metadata)
+            for row in self._content_get_from_hash('sha1', sha1):
+                content_metadata = row._asdict()
+                content_metadata.pop('ctime')
+                result[content_metadata['sha1']].append(content_metadata)
         return result
 
     def content_find(self, content):
@@ -195,27 +215,21 @@
                 '%s' % ', '.join(sorted(HASH_ALGORITHMS)))
         common_algo = filter_algos[0]
 
-        # Find all contents whose common_algo matches at least one
-        # of the requests.
-        found_pks = self._cql_runner.content_get_pks_from_single_hash(
-            common_algo, content[common_algo])
-        found_pks = [pk._asdict() for pk in found_pks]
-
-        # Filter with the other hash algorithms.
-        for algo in filter_algos[1:]:
-            found_pks = [pk for pk in found_pks if pk[algo] == content[algo]]
-
         results = []
-        for pk in found_pks:
-            # Query the main table ('content').
-            res = self._cql_runner.content_get_from_pk(pk)
-
-            # Rows in 'content' are inserted after corresponding
-            # rows in 'content_by_*', so we might be missing it
-            if res:
+        rows = self._content_get_from_hash(
+            common_algo, content[common_algo])
+        for row in rows:
+            # Re-check all the hashes, in case of collisions (either of the
+            # hash of the partition key, or the hashes in it)
+            for algo in HASH_ALGORITHMS:
+                if content.get(algo) and getattr(row, algo) != content[algo]:
+                    # This hash didn't match; discard the row.
+                    break
+            else:
+                # All hashes match, keep this row.
                 results.append({
-                    **res._asdict(),
-                    'ctime': res.ctime.replace(tzinfo=datetime.timezone.utc)
+                    **row._asdict(),
+                    'ctime': row.ctime.replace(tzinfo=datetime.timezone.utc)
                 })
         return results
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -3,6 +3,8 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from collections import namedtuple
+import datetime
 import os
 import signal
 import socket
@@ -13,12 +15,14 @@
 
 from swh.storage import get_storage
 from swh.storage.cassandra import create_keyspace
-from swh.storage.cassandra.schema import TABLES
+from swh.storage.cassandra.schema import TABLES, HASH_ALGORITHMS
 from swh.storage.tests.test_storage import TestStorage as _TestStorage
 from swh.storage.tests.test_storage import TestStorageGeneratedData \
     as _TestStorageGeneratedData
 
+from .storage_data import data
+
 
 CONFIG_TEMPLATE = '''
 data_file_directories:
@@ -176,6 +180,135 @@
 
 @pytest.mark.cassandra
 class TestCassandraStorage(_TestStorage):
+    def test_content_add_murmur3_collision(self, swh_storage, mocker):
+        """The Murmur3 token is used as a link from index tables to the main
+        table; and non-matching contents with colliding murmur3-hash
+        are filtered out when reading the main table.
+        This test checks the content methods do filter out these collisions.
+        """
+        called = 0
+
+        # always return a token
+        def mock_cgtfsh(algo, hash_):
+            nonlocal called
+            called += 1
+            assert algo in ('sha1', 'sha1_git')
+            return [123456]
+        mocker.patch.object(
+            swh_storage.storage._cql_runner,
+            'content_get_tokens_from_single_hash',
+            mock_cgtfsh)
+
+        # For all tokens, always return data.cont
+        Row = namedtuple('Row', HASH_ALGORITHMS)
+
+        def mock_cgft(token):
+            nonlocal called
+            called += 1
+            return [Row(**{algo: data.cont[algo] for algo in HASH_ALGORITHMS})]
+        mocker.patch.object(
+            swh_storage.storage._cql_runner,
+            'content_get_from_token',
+            mock_cgft)
+
+        actual_result = swh_storage.content_add([data.cont2])
+
+        assert called == 4
+        assert actual_result == {
+            'content:add': 1,
+            'content:add:bytes': data.cont2['length'],
+        }
+
+    def test_content_get_metadata_murmur3_collision(self, swh_storage, mocker):
+        """The Murmur3 token is used as a link from index tables to the main
+        table; and non-matching contents with colliding murmur3-hash
+        are filtered out when reading the main table.
+        This test checks the content methods do filter out these collisions.
+        """
+        called = 0
+
+        # always return a token
+        def mock_cgtfsh(algo, hash_):
+            nonlocal called
+            called += 1
+            assert algo in ('sha1', 'sha1_git')
+            return [123456]
+        mocker.patch.object(
+            swh_storage.storage._cql_runner,
+            'content_get_tokens_from_single_hash',
+            mock_cgtfsh)
+
+        # For all tokens, always return data.cont and data.cont2
+        cols = list(set(data.cont) - {'data'})
+        Row = namedtuple('Row', cols + ['ctime'])
+
+        def mock_cgft(token):
+            nonlocal called
+            called += 1
+            return [Row(ctime=42, **{col: cont[col] for col in cols})
+                    for cont in [data.cont, data.cont2]]
+        mocker.patch.object(
+            swh_storage.storage._cql_runner,
+            'content_get_from_token',
+            mock_cgft)
+
+        expected_cont = data.cont.copy()
+        del expected_cont['data']
+
+        actual_result = swh_storage.content_get_metadata([data.cont['sha1']])
+
+        assert called == 2
+
+        # but data.cont2 should be filtered out
+        assert actual_result == {
+            data.cont['sha1']: [expected_cont]
+        }
+
+    def test_content_find_murmur3_collision(self, swh_storage, mocker):
+        """The Murmur3 token is used as a link from index tables to the main
+        table; and non-matching contents with colliding murmur3-hash
+        are filtered out when reading the main table.
+        This test checks the content methods do filter out these collisions.
+        """
+        called = 0
+
+        # always return a token
+        def mock_cgtfsh(algo, hash_):
+            nonlocal called
+            called += 1
+            assert algo in ('sha1', 'sha1_git')
+            return [123456]
+        mocker.patch.object(
+            swh_storage.storage._cql_runner,
+            'content_get_tokens_from_single_hash',
+            mock_cgtfsh)
+
+        # For all tokens, always return data.cont and data.cont2
+        cols = list(set(data.cont) - {'data'})
+        Row = namedtuple('Row', cols + ['ctime'])
+
+        def mock_cgft(token):
+            nonlocal called
+            called += 1
+            return [Row(ctime=datetime.datetime.now(),
+                        **{col: cont[col] for col in cols})
+                    for cont in [data.cont, data.cont2]]
+        mocker.patch.object(
+            swh_storage.storage._cql_runner,
+            'content_get_from_token',
+            mock_cgft)
+
+        expected_cont = data.cont.copy()
+        del expected_cont['data']
+
+        actual_result = swh_storage.content_find({'sha1': data.cont['sha1']})
+
+        assert called == 2
+
+        # but data.cont2 should be filtered out
+        del actual_result[0]['ctime']
+        assert actual_result == [expected_cont]
+
     @pytest.mark.skip('content_update is not yet implemented for Cassandra')
     def test_content_update(self):
        pass