diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py index dc6f8cac..88019d24 100644 --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -1,1491 +1,1494 @@ # Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import Counter import dataclasses import datetime import functools import itertools import logging import random from typing import ( Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union, ) from cassandra import ConsistencyLevel, CoordinationFailure from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, ResultSet from cassandra.concurrent import execute_concurrent_with_args from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy from cassandra.query import BoundStatement, PreparedStatement, dict_factory from mypy_extensions import NamedArg from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential, ) from swh.core.utils import grouper from swh.model.model import ( Content, Person, Sha1Git, SkippedContent, Timestamp, TimestampWithTimezone, ) from swh.model.swhids import CoreSWHID from swh.storage.interface import ListOrder from ..utils import remove_keys from .common import TOKEN_BEGIN, TOKEN_END, hash_url from .model import ( MAGIC_NULL_PK, BaseRow, ContentRow, DirectoryEntryRow, DirectoryRow, ExtIDByTargetRow, ExtIDRow, MetadataAuthorityRow, MetadataFetcherRow, ObjectCountRow, OriginRow, OriginVisitRow, OriginVisitStatusRow, RawExtrinsicMetadataByIdRow, RawExtrinsicMetadataRow, ReleaseRow, RevisionParentRow, RevisionRow, SkippedContentRow, SnapshotBranchRow, SnapshotRow, content_index_table_name, ) from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS PARTITION_KEY_RESTRICTION_MAX_SIZE = 100 """Maximum number of restrictions in a single query. Usually this is a very low number (eg. SELECT ... FROM ... WHERE x=?), but some queries can request arbitrarily many (eg. SELECT ... FROM ... WHERE x IN ?). This can cause performance issues, as the node getting the query need to coordinate with other nodes to get the complete results. See for details and rationale. 
""" BATCH_INSERT_MAX_SIZE = 1000 logger = logging.getLogger(__name__) def get_execution_profiles( consistency_level: str = "ONE", ) -> Dict[object, ExecutionProfile]: if consistency_level not in ConsistencyLevel.name_to_value: raise ValueError( f"Configuration error: Unknown consistency level '{consistency_level}'" ) return { EXEC_PROFILE_DEFAULT: ExecutionProfile( load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()), row_factory=dict_factory, consistency_level=ConsistencyLevel.name_to_value[consistency_level], ) } # Configuration for cassandra-driver's access to servers: # * hit the right server directly when sending a query (TokenAwarePolicy), # * if there's more than one, then pick one at random that's in the same # datacenter as the client (DCAwareRoundRobinPolicy) def create_keyspace( hosts: List[str], keyspace: str, port: int = 9042, *, durable_writes=True ): cluster = Cluster(hosts, port=port, execution_profiles=get_execution_profiles()) session = cluster.connect() extra_params = "" if not durable_writes: extra_params = "AND durable_writes = false" session.execute( """CREATE KEYSPACE IF NOT EXISTS "%s" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 } %s; """ % (keyspace, extra_params) ) session.execute('USE "%s"' % keyspace) for query in CREATE_TABLES_QUERIES: session.execute(query) TRet = TypeVar("TRet") def _prepared_statement( query: str, ) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]: """Returns a decorator usable on methods of CqlRunner, to inject them with a 'statement' argument, that is a prepared statement corresponding to the query. This only works on methods of CqlRunner, as preparing a statement requires a connection to a Cassandra server.""" def decorator(f): @functools.wraps(f) - def newf(self, *args, **kwargs) -> TRet: + def newf(self: "CqlRunner", *args, **kwargs) -> TRet: if f.__name__ not in self._prepared_statements: - statement: PreparedStatement = self._session.prepare(query) + statement: PreparedStatement = self._session.prepare( + query.format(keyspace=self.keyspace) + ) self._prepared_statements[f.__name__] = statement return f( self, *args, **kwargs, statement=self._prepared_statements[f.__name__] ) return newf return decorator TArg = TypeVar("TArg") TSelf = TypeVar("TSelf") -def _insert_query(row_class): +def _insert_query(row_class: Type[BaseRow]) -> str: columns = row_class.cols() return ( - f"INSERT INTO {row_class.TABLE} ({', '.join(columns)}) " + f"INSERT INTO {{keyspace}}.{row_class.TABLE} ({', '.join(columns)}) " f"VALUES ({', '.join('?' 
for _ in columns)})" ) def _prepared_insert_statement( row_class: Type[BaseRow], ) -> Callable[ [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa Callable[[TSelf, TArg], TRet], ]: """Shorthand for using `_prepared_statement` for `INSERT INTO` statements.""" return _prepared_statement(_insert_query(row_class)) def _prepared_exists_statement( table_name: str, ) -> Callable[ [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa Callable[[TSelf, TArg], TRet], ]: """Shorthand for using `_prepared_statement` for queries that only check which ids in a list exist in the table.""" - return _prepared_statement(f"SELECT id FROM {table_name} WHERE id = ?") + return _prepared_statement(f"SELECT id FROM {{keyspace}}.{table_name} WHERE id = ?") def _prepared_select_statement( row_class: Type[BaseRow], clauses: str = "", cols: Optional[List[str]] = None, ) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]: if cols is None: cols = row_class.cols() return _prepared_statement( - f"SELECT {', '.join(cols)} FROM {row_class.TABLE} {clauses}" + f"SELECT {', '.join(cols)} FROM {{keyspace}}.{row_class.TABLE} {clauses}" ) def _prepared_select_statements( row_class: Type[BaseRow], queries: Dict[Any, str], ) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]: """Like _prepared_statement, but supports multiple statements, passed a dict, and passes a dict of prepared statements to the decorated method""" cols = row_class.cols() - statement_start = f"SELECT {', '.join(cols)} FROM {row_class.TABLE} " + statement_template = "SELECT {cols} FROM {keyspace}.{table} {rest}" def decorator(f): @functools.wraps(f) - def newf(self, *args, **kwargs) -> TRet: + def newf(self: "CqlRunner", *args, **kwargs) -> TRet: if f.__name__ not in self._prepared_statements: self._prepared_statements[f.__name__] = { - key: self._session.prepare(statement_start + query) + key: self._session.prepare( + statement_template.format( + cols=", ".join(cols), + keyspace=self.keyspace, + table=row_class.TABLE, + rest=query, + ) + ) for (key, query) in queries.items() } return f( self, *args, **kwargs, statements=self._prepared_statements[f.__name__] ) return newf return decorator def _next_bytes_value(value: bytes) -> bytes: """Returns the next bytes value by incrementing the integer representation of the provided value and converting it back to bytes. For instance when prefix is b"abcd", it returns b"abce". 
""" next_value_int = int.from_bytes(value, byteorder="big") + 1 return next_value_int.to_bytes( (next_value_int.bit_length() + 7) // 8, byteorder="big" ) class CqlRunner: """Class managing prepared statements and building queries to be sent to Cassandra.""" def __init__( self, hosts: List[str], keyspace: str, port: int, consistency_level: str ): self._cluster = Cluster( hosts, port=port, execution_profiles=get_execution_profiles(consistency_level), ) - self._session = self._cluster.connect(keyspace) + self.keyspace = keyspace + self._session = self._cluster.connect() self._cluster.register_user_type( keyspace, "microtimestamp_with_timezone", TimestampWithTimezone ) self._cluster.register_user_type(keyspace, "microtimestamp", Timestamp) self._cluster.register_user_type(keyspace, "person", Person) # directly a PreparedStatement for methods decorated with # @_prepared_statements (and its wrappers, _prepared_insert_statement, # _prepared_exists_statement, and _prepared_select_statement); # and a dict of PreparedStatements with @_prepared_select_statements self._prepared_statements: Dict[ str, Union[PreparedStatement, Dict[Any, PreparedStatement]] ] = {} ########################## # Common utility functions ########################## MAX_RETRIES = 3 @retry( wait=wait_random_exponential(multiplier=1, max=10), stop=stop_after_attempt(MAX_RETRIES), retry=retry_if_exception_type(CoordinationFailure), ) def _execute_with_retries(self, statement, args: Optional[Sequence]) -> ResultSet: return self._session.execute(statement, args, timeout=1000.0) @retry( wait=wait_random_exponential(multiplier=1, max=10), stop=stop_after_attempt(MAX_RETRIES), retry=retry_if_exception_type(CoordinationFailure), ) def _execute_many_with_retries( self, statement, args_list: Sequence[Tuple] ) -> Iterable[Dict[str, Any]]: for res in execute_concurrent_with_args(self._session, statement, args_list): yield from res.result_or_exc def _add_one(self, statement, obj: BaseRow) -> None: self._execute_with_retries(statement, dataclasses.astuple(obj)) def _add_many(self, statement, objs: Sequence[BaseRow]) -> None: tables = {obj.TABLE for obj in objs} assert len(tables) == 1, f"Cannot insert to multiple tables: {tables}" rows = list(map(dataclasses.astuple, objs)) for _ in self._execute_many_with_retries(statement, rows): # Need to consume the generator to actually run the INSERTs pass _T = TypeVar("_T", bound=BaseRow) def _get_random_row(self, row_class: Type[_T], statement) -> Optional[_T]: # noqa """Takes a prepared statement of the form "SELECT * FROM WHERE token() > ? 
LIMIT 1" and uses it to return a random row""" token = random.randint(TOKEN_BEGIN, TOKEN_END) rows = self._execute_with_retries(statement, [token]) if not rows: # There are no row with a greater token; wrap around to get # the row with the smallest token rows = self._execute_with_retries(statement, [TOKEN_BEGIN]) if rows: return row_class.from_dict(rows.one()) # type: ignore else: return None def _missing(self, statement: PreparedStatement, ids): found_ids = set() if not ids: return [] for row in self._execute_many_with_retries(statement, [(id_,) for id_ in ids]): found_ids.add(row["id"]) return [id_ for id_ in ids if id_ not in found_ids] ########################## # 'content' table ########################## def _content_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by content_add_prepare, to be called when the content row should be added to the primary table.""" self._execute_with_retries(statement, None) @_prepared_insert_statement(ContentRow) def content_add_prepare( self, content: ContentRow, *, statement ) -> Tuple[int, Callable[[], None]]: """Prepares insertion of a Content to the main 'content' table. Returns a token (to be used in secondary tables), and a function to be called to perform the insertion in the main table.""" statement = statement.bind(dataclasses.astuple(content)) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) token_class = self._cluster.metadata.token_map.token_class # Token of the row when it will be inserted. This is equivalent to # "SELECT token({', '.join(ContentRow.PARTITION_KEY)}) FROM content WHERE ..." # after the row is inserted; but we need the token to insert in the # index tables *before* inserting to the main 'content' table token = token_class.from_key(statement.routing_key).value assert TOKEN_BEGIN <= token <= TOKEN_END # Function to be called after the indexes contain their respective # row finalizer = functools.partial(self._content_add_finalize, statement) return (token, finalizer) @_prepared_select_statement( ContentRow, f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}" ) def content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement ) -> Optional[ContentRow]: rows = list( self._execute_with_retries( statement, [content_hashes[algo] for algo in HASH_ALGORITHMS] ) ) assert len(rows) <= 1 if rows: return ContentRow(**rows[0]) else: return None def content_missing_from_all_hashes( self, contents_hashes: List[Dict[str, bytes]] ) -> Iterator[Dict[str, bytes]]: for group in grouper(contents_hashes, PARTITION_KEY_RESTRICTION_MAX_SIZE): group = list(group) # Get all contents that share a sha256 with one of the contents in the group present = set( self._content_get_hashes_from_sha256( [content["sha256"] for content in group] ) ) for content in group: for algo in HASH_ALGORITHMS: assert content.get(algo) is not None, ( "content_missing_from_all_hashes must not be called with " "partial hashes." ) if tuple(content[algo] for algo in HASH_ALGORITHMS) not in present: yield content @_prepared_select_statement(ContentRow, "WHERE sha256 IN ?", HASH_ALGORITHMS) def _content_get_hashes_from_sha256( self, ids: List[bytes], *, statement ) -> Iterator[Tuple[bytes, bytes, bytes, bytes]]: for row in self._execute_with_retries(statement, [ids]): yield tuple(row[algo] for algo in HASH_ALGORITHMS) # type: ignore @_prepared_select_statement( ContentRow, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) = ?" 
) def content_get_from_tokens(self, tokens, *, statement) -> Iterable[ContentRow]: return map( ContentRow.from_dict, self._execute_many_with_retries(statement, [(token,) for token in tokens]), ) @_prepared_select_statement( ContentRow, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) > ? LIMIT 1" ) def content_get_random(self, *, statement) -> Optional[ContentRow]: return self._get_random_row(ContentRow, statement) @_prepared_statement( """ - SELECT token({pk}) AS tok, {cols} FROM {table} + SELECT token({pk}) AS tok, {cols} FROM {{keyspace}}.{table} WHERE token({pk}) >= ? AND token({pk}) <= ? LIMIT ? """.format( pk=", ".join(ContentRow.PARTITION_KEY), cols=", ".join(ContentRow.cols()), table=ContentRow.TABLE, ) ) def content_get_token_range( self, start: int, end: int, limit: int, *, statement ) -> Iterable[Tuple[int, ContentRow]]: """Returns an iterable of (token, row)""" return ( (row["tok"], ContentRow.from_dict(remove_keys(row, ("tok",)))) for row in self._execute_with_retries(statement, [start, end, limit]) ) ########################## # 'content_by_*' tables ########################## def content_index_add_one(self, algo: str, content: Content, token: int) -> None: """Adds a row mapping content[algo] to the token of the Content in the main 'content' table.""" + table = content_index_table_name(algo, skipped_content=False) query = f""" - INSERT INTO {content_index_table_name(algo, skipped_content=False)} - ({algo}, target_token) - VALUES (%s, %s) + INSERT INTO {self.keyspace}.{table} ({algo}, target_token) VALUES (%s, %s) """ self._execute_with_retries(query, [content.get_hash(algo), token]) def content_get_tokens_from_single_algo( self, algo: str, hashes: List[bytes] ) -> Iterable[int]: assert algo in HASH_ALGORITHMS - query = f""" - SELECT target_token - FROM {content_index_table_name(algo, skipped_content=False)} - WHERE {algo} = %s - """ + table = content_index_table_name(algo, skipped_content=False) + query = f"SELECT target_token FROM {self.keyspace}.{table} WHERE {algo} = %s" return ( row["target_token"] # type: ignore for row in self._execute_many_with_retries( query, [(hash_,) for hash_ in hashes] ) ) ########################## # 'skipped_content' table ########################## def _skipped_content_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by skipped_content_add_prepare, to be called when the content row should be added to the primary table.""" self._execute_with_retries(statement, None) @_prepared_insert_statement(SkippedContentRow) def skipped_content_add_prepare( self, content, *, statement ) -> Tuple[int, Callable[[], None]]: """Prepares insertion of a Content to the main 'skipped_content' table. Returns a token (to be used in secondary tables), and a function to be called to perform the insertion in the main table.""" # Replace NULLs (which are not allowed in the partition key) with # an empty byte string for key in SkippedContentRow.PARTITION_KEY: if getattr(content, key) is None: setattr(content, key, MAGIC_NULL_PK) statement = statement.bind(dataclasses.astuple(content)) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) token_class = self._cluster.metadata.token_map.token_class # Token of the row when it will be inserted. This is equivalent to # "SELECT token({', '.join(SkippedContentRow.PARTITION_KEY)}) # FROM skipped_content WHERE ..." 
# after the row is inserted; but we need the token to insert in the # index tables *before* inserting to the main 'skipped_content' table token = token_class.from_key(statement.routing_key).value assert TOKEN_BEGIN <= token <= TOKEN_END # Function to be called after the indexes contain their respective # row finalizer = functools.partial(self._skipped_content_add_finalize, statement) return (token, finalizer) @_prepared_select_statement( SkippedContentRow, f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}", ) def skipped_content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement ) -> Optional[SkippedContentRow]: rows = list( self._execute_with_retries( statement, [content_hashes[algo] or MAGIC_NULL_PK for algo in HASH_ALGORITHMS], ) ) assert len(rows) <= 1 if rows: return SkippedContentRow.from_dict(rows[0]) else: return None @_prepared_select_statement( SkippedContentRow, f"WHERE token({', '.join(SkippedContentRow.PARTITION_KEY)}) = ?", ) def skipped_content_get_from_token( self, token, *, statement ) -> Iterable[SkippedContentRow]: return map( SkippedContentRow.from_dict, self._execute_with_retries(statement, [token]) ) ########################## # 'skipped_content_by_*' tables ########################## def skipped_content_index_add_one( self, algo: str, content: SkippedContent, token: int ) -> None: """Adds a row mapping content[algo] to the token of the SkippedContent in the main 'skipped_content' table.""" query = ( - f"INSERT INTO skipped_content_by_{algo} ({algo}, target_token) " + f"INSERT INTO {self.keyspace}.skipped_content_by_{algo} ({algo}, target_token) " f"VALUES (%s, %s)" ) self._execute_with_retries( query, [content.get_hash(algo) or MAGIC_NULL_PK, token] ) def skipped_content_get_tokens_from_single_hash( self, algo: str, hash_: bytes ) -> Iterable[int]: assert algo in HASH_ALGORITHMS - query = f""" - SELECT target_token - FROM {content_index_table_name(algo, skipped_content=True)} - WHERE {algo} = %s - """ + table = content_index_table_name(algo, skipped_content=True) + query = f"SELECT target_token FROM {self.keyspace}.{table} WHERE {algo} = %s" return ( row["target_token"] for row in self._execute_with_retries(query, [hash_]) ) ########################## # 'revision' table ########################## @_prepared_exists_statement("revision") def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement(RevisionRow) def revision_add_one(self, revision: RevisionRow, *, statement) -> None: self._add_one(statement, revision) @_prepared_select_statement(RevisionRow, "WHERE id IN ?", ["id"]) def revision_get_ids(self, revision_ids, *, statement) -> Iterable[int]: return ( row["id"] for row in self._execute_with_retries(statement, [revision_ids]) ) @_prepared_select_statement(RevisionRow, "WHERE id IN ?") def revision_get( self, revision_ids: List[Sha1Git], *, statement ) -> Iterable[RevisionRow]: return map( RevisionRow.from_dict, self._execute_with_retries(statement, [revision_ids]) ) @_prepared_select_statement(RevisionRow, "WHERE token(id) > ? 
LIMIT 1") def revision_get_random(self, *, statement) -> Optional[RevisionRow]: return self._get_random_row(RevisionRow, statement) ########################## # 'revision_parent' table ########################## @_prepared_insert_statement(RevisionParentRow) def revision_parent_add_one( self, revision_parent: RevisionParentRow, *, statement ) -> None: self._add_one(statement, revision_parent) @_prepared_select_statement(RevisionParentRow, "WHERE id = ?", ["parent_id"]) def revision_parent_get( self, revision_id: Sha1Git, *, statement ) -> Iterable[bytes]: return ( row["parent_id"] for row in self._execute_with_retries(statement, [revision_id]) ) ########################## # 'release' table ########################## @_prepared_exists_statement("release") def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement(ReleaseRow) def release_add_one(self, release: ReleaseRow, *, statement) -> None: self._add_one(statement, release) @_prepared_select_statement(ReleaseRow, "WHERE id in ?") def release_get( self, release_ids: List[Sha1Git], *, statement ) -> Iterable[ReleaseRow]: return map( ReleaseRow.from_dict, self._execute_with_retries(statement, [release_ids]) ) @_prepared_select_statement(ReleaseRow, "WHERE token(id) > ? LIMIT 1") def release_get_random(self, *, statement) -> Optional[ReleaseRow]: return self._get_random_row(ReleaseRow, statement) ########################## # 'directory' table ########################## @_prepared_exists_statement("directory") def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement(DirectoryRow) def directory_add_one(self, directory: DirectoryRow, *, statement) -> None: """Called after all calls to directory_entry_add_one, to commit/finalize the directory.""" self._add_one(statement, directory) @_prepared_select_statement(DirectoryRow, "WHERE token(id) > ? LIMIT 1") def directory_get_random(self, *, statement) -> Optional[DirectoryRow]: return self._get_random_row(DirectoryRow, statement) @_prepared_select_statement(DirectoryRow, "WHERE id in ?") def directory_get( self, directory_ids: List[Sha1Git], *, statement ) -> Iterable[DirectoryRow]: """Return fields from the main directory table (e.g. 
raw_manifest, but not entries)""" return map( DirectoryRow.from_dict, self._execute_with_retries(statement, [directory_ids]), ) ########################## # 'directory_entry' table ########################## @_prepared_insert_statement(DirectoryEntryRow) def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None: self._add_one(statement, entry) @_prepared_insert_statement(DirectoryEntryRow) def directory_entry_add_concurrent( self, entries: List[DirectoryEntryRow], *, statement ) -> None: if len(entries) == 0: # nothing to do return assert ( len({entry.directory_id for entry in entries}) == 1 ), "directory_entry_add_many must be called with entries for a single dir" self._add_many(statement, entries) @_prepared_statement( "BEGIN UNLOGGED BATCH\n" + (_insert_query(DirectoryEntryRow) + ";\n") * BATCH_INSERT_MAX_SIZE + "APPLY BATCH" ) def directory_entry_add_batch( self, entries: List[DirectoryEntryRow], *, statement ) -> None: if len(entries) == 0: # nothing to do return assert ( len({entry.directory_id for entry in entries}) == 1 ), "directory_entry_add_many must be called with entries for a single dir" for entry_group in grouper(entries, BATCH_INSERT_MAX_SIZE): entry_group = list(entry_group) if len(entry_group) == BATCH_INSERT_MAX_SIZE: entry_group = list(map(dataclasses.astuple, entry_group)) self._execute_with_retries( statement, list(itertools.chain.from_iterable(entry_group)) ) else: # Last group, with a smaller size than the BATCH we prepared. # Creating a prepared BATCH just for this then discarding it would # create too much churn on the server side; and using unprepared # statements is annoying (we can't use _insert_query() as they have # a different format) # Fall back to inserting concurrently. self.directory_entry_add_concurrent(entry_group) @_prepared_select_statement(DirectoryEntryRow, "WHERE directory_id IN ?") def directory_entry_get( self, directory_ids, *, statement ) -> Iterable[DirectoryEntryRow]: return map( DirectoryEntryRow.from_dict, self._execute_with_retries(statement, [directory_ids]), ) @_prepared_select_statement( DirectoryEntryRow, "WHERE directory_id = ? AND name >= ? LIMIT ?" ) def directory_entry_get_from_name( self, directory_id: Sha1Git, from_: bytes, limit: int, *, statement ) -> Iterable[DirectoryEntryRow]: return map( DirectoryEntryRow.from_dict, self._execute_with_retries(statement, [directory_id, from_, limit]), ) ########################## # 'snapshot' table ########################## @_prepared_exists_statement("snapshot") def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement(SnapshotRow) def snapshot_add_one(self, snapshot: SnapshotRow, *, statement) -> None: self._add_one(statement, snapshot) @_prepared_select_statement(SnapshotRow, "WHERE token(id) > ? LIMIT 1") def snapshot_get_random(self, *, statement) -> Optional[SnapshotRow]: return self._get_random_row(SnapshotRow, statement) ########################## # 'snapshot_branch' table ########################## @_prepared_insert_statement(SnapshotBranchRow) def snapshot_branch_add_one(self, branch: SnapshotBranchRow, *, statement) -> None: self._add_one(statement, branch) @_prepared_statement( f""" SELECT ascii_bins_count(target_type) AS counts - FROM {SnapshotBranchRow.TABLE} + FROM {{keyspace}}.{SnapshotBranchRow.TABLE} WHERE snapshot_id = ? AND name >= ? 
""" ) def snapshot_count_branches_from_name( self, snapshot_id: Sha1Git, from_: bytes, *, statement ) -> Dict[Optional[str], int]: row = self._execute_with_retries(statement, [snapshot_id, from_]).one() (nb_none, counts) = row["counts"] return {None: nb_none, **counts} @_prepared_statement( f""" SELECT ascii_bins_count(target_type) AS counts - FROM {SnapshotBranchRow.TABLE} + FROM {{keyspace}}.{SnapshotBranchRow.TABLE} WHERE snapshot_id = ? AND name < ? """ ) def snapshot_count_branches_before_name( self, snapshot_id: Sha1Git, before: bytes, *, statement, ) -> Dict[Optional[str], int]: row = self._execute_with_retries(statement, [snapshot_id, before]).one() (nb_none, counts) = row["counts"] return {None: nb_none, **counts} def snapshot_count_branches( self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[bytes] = None, ) -> Dict[Optional[str], int]: """Returns a dictionary from type names to the number of branches of that type.""" prefix = branch_name_exclude_prefix if prefix is None: return self.snapshot_count_branches_from_name(snapshot_id, b"") else: # counts branches before exclude prefix counts = Counter( self.snapshot_count_branches_before_name(snapshot_id, prefix) ) # no need to execute that part if each bit of the prefix equals 1 if prefix.replace(b"\xff", b"") != b"": # counts branches after exclude prefix and update counters counts.update( self.snapshot_count_branches_from_name( snapshot_id, _next_bytes_value(prefix) ) ) return counts @_prepared_select_statement( SnapshotBranchRow, "WHERE snapshot_id = ? AND name >= ? LIMIT ?" ) def snapshot_branch_get_from_name( self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement ) -> Iterable[SnapshotBranchRow]: return map( SnapshotBranchRow.from_dict, self._execute_with_retries(statement, [snapshot_id, from_, limit]), ) @_prepared_select_statement( SnapshotBranchRow, "WHERE snapshot_id = ? AND name >= ? AND name < ? LIMIT ?" 
) def snapshot_branch_get_range( self, snapshot_id: Sha1Git, from_: bytes, before: bytes, limit: int, *, statement, ) -> Iterable[SnapshotBranchRow]: return map( SnapshotBranchRow.from_dict, self._execute_with_retries(statement, [snapshot_id, from_, before, limit]), ) def snapshot_branch_get( self, snapshot_id: Sha1Git, from_: bytes, limit: int, branch_name_exclude_prefix: Optional[bytes] = None, ) -> Iterable[SnapshotBranchRow]: prefix = branch_name_exclude_prefix if prefix is None: return self.snapshot_branch_get_from_name(snapshot_id, from_, limit) else: # get branches before the exclude prefix branches = list( self.snapshot_branch_get_range(snapshot_id, from_, prefix, limit) ) nb_branches = len(branches) # no need to execute that part if limit is reached # or if each bit of the prefix equals 1 if nb_branches < limit and prefix.replace(b"\xff", b"") != b"": # get branches after the exclude prefix and update list to return branches.extend( self.snapshot_branch_get_from_name( snapshot_id, _next_bytes_value(prefix), limit - nb_branches ) ) return branches ########################## # 'origin' table ########################## @_prepared_insert_statement(OriginRow) def origin_add_one(self, origin: OriginRow, *, statement) -> None: self._add_one(statement, origin) @_prepared_select_statement(OriginRow, "WHERE sha1 = ?") def origin_get_by_sha1(self, sha1: bytes, *, statement) -> Iterable[OriginRow]: return map(OriginRow.from_dict, self._execute_with_retries(statement, [sha1])) def origin_get_by_url(self, url: str) -> Iterable[OriginRow]: return self.origin_get_by_sha1(hash_url(url)) @_prepared_statement( f""" SELECT token(sha1) AS tok, {", ".join(OriginRow.cols())} - FROM {OriginRow.TABLE} + FROM {{keyspace}}.{OriginRow.TABLE} WHERE token(sha1) >= ? LIMIT ? """ ) def origin_list( self, start_token: int, limit: int, *, statement ) -> Iterable[Tuple[int, OriginRow]]: """Returns an iterable of (token, origin)""" return ( (row["tok"], OriginRow.from_dict(remove_keys(row, ("tok",)))) for row in self._execute_with_retries(statement, [start_token, limit]) ) @_prepared_select_statement(OriginRow) def origin_iter_all(self, *, statement) -> Iterable[OriginRow]: return map(OriginRow.from_dict, self._execute_with_retries(statement, [])) @_prepared_statement( f""" - UPDATE {OriginRow.TABLE} + UPDATE {{keyspace}}.{OriginRow.TABLE} SET next_visit_id=? WHERE sha1 = ? IF next_visit_id<? """ ) def origin_bump_next_visit_id( self, origin_url: str, visit_id: int, *, statement ) -> None: origin_sha1 = hash_url(origin_url) next_id = visit_id + 1 self._execute_with_retries(statement, [next_id, origin_sha1, next_id]) @_prepared_select_statement(OriginRow, "WHERE sha1 = ?", ["next_visit_id"]) def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int: rows = list(self._execute_with_retries(statement, [origin_sha1])) assert len(rows) == 1 # TODO: error handling return rows[0]["next_visit_id"] @_prepared_statement( f""" - UPDATE {OriginRow.TABLE} + UPDATE {{keyspace}}.{OriginRow.TABLE} SET next_visit_id=? WHERE sha1 = ? IF next_visit_id=?
""" ) def origin_generate_unique_visit_id(self, origin_url: str, *, statement) -> int: origin_sha1 = hash_url(origin_url) next_id = self._origin_get_next_visit_id(origin_sha1) while True: res = list( self._execute_with_retries( statement, [next_id + 1, origin_sha1, next_id] ) ) assert len(res) == 1 if res[0]["[applied]"]: # No data race return next_id else: # Someone else updated it before we did, let's try again next_id = res[0]["next_visit_id"] # TODO: abort after too many attempts return next_id ########################## # 'origin_visit' table ########################## @_prepared_select_statements( OriginVisitRow, { (True, ListOrder.ASC): ( "WHERE origin = ? AND visit > ? ORDER BY visit ASC LIMIT ?" ), (True, ListOrder.DESC): ( "WHERE origin = ? AND visit < ? ORDER BY visit DESC LIMIT ?" ), (False, ListOrder.ASC): "WHERE origin = ? ORDER BY visit ASC LIMIT ?", (False, ListOrder.DESC): "WHERE origin = ? ORDER BY visit DESC LIMIT ?", }, ) def origin_visit_get( self, origin_url: str, last_visit: Optional[int], limit: int, order: ListOrder, *, statements, ) -> Iterable[OriginVisitRow]: args: List[Any] = [origin_url] if last_visit is not None: args.append(last_visit) args.append(limit) statement = statements[(last_visit is not None, order)] return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, args) ) @_prepared_insert_statement(OriginVisitRow) def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None: self._add_one(statement, visit) @_prepared_select_statement(OriginVisitRow, "WHERE origin = ? AND visit = ?") def origin_visit_get_one( self, origin_url: str, visit_id: int, *, statement ) -> Optional[OriginVisitRow]: # TODO: error handling rows = list(self._execute_with_retries(statement, [origin_url, visit_id])) if rows: return OriginVisitRow.from_dict(rows[0]) else: return None @_prepared_select_statement(OriginVisitRow, "WHERE origin = ? ORDER BY visit DESC") def origin_visit_iter_all( self, origin_url: str, *, statement ) -> Iterable[OriginVisitRow]: """Returns an iterator on visits for a given origin, ordered by descending visit id.""" return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, [origin_url]), ) @_prepared_select_statement(OriginVisitRow, "WHERE token(origin) >= ?") def _origin_visit_iter_from( self, min_token: int, *, statement ) -> Iterable[OriginVisitRow]: return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, [min_token]) ) @_prepared_select_statement(OriginVisitRow, "WHERE token(origin) < ?") def _origin_visit_iter_to( self, max_token: int, *, statement ) -> Iterable[OriginVisitRow]: return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, [max_token]) ) def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]: """Returns all origin visits in order from this token, and wraps around the token space.""" yield from self._origin_visit_iter_from(start_token) yield from self._origin_visit_iter_to(start_token) ########################## # 'origin_visit_status' table ########################## @_prepared_select_statements( OriginVisitStatusRow, { (True, ListOrder.ASC): ( "WHERE origin = ? AND visit = ? AND date >= ? " "ORDER BY visit ASC LIMIT ?" ), (True, ListOrder.DESC): ( "WHERE origin = ? AND visit = ? AND date <= ? " "ORDER BY visit DESC LIMIT ?" ), (False, ListOrder.ASC): ( "WHERE origin = ? AND visit = ? ORDER BY visit ASC LIMIT ?" ), (False, ListOrder.DESC): ( "WHERE origin = ? AND visit = ? ORDER BY visit DESC LIMIT ?" 
), }, ) def origin_visit_status_get_range( self, origin: str, visit: int, date_from: Optional[datetime.datetime], limit: int, order: ListOrder, *, statements, ) -> Iterable[OriginVisitStatusRow]: args: List[Any] = [origin, visit] if date_from is not None: args.append(date_from) args.append(limit) statement = statements[(date_from is not None, order)] return map( OriginVisitStatusRow.from_dict, self._execute_with_retries(statement, args) ) @_prepared_select_statement( OriginVisitStatusRow, "WHERE origin = ? AND visit >= ? AND visit <= ? ORDER BY visit ASC, date ASC", ) def origin_visit_status_get_all_range( self, origin_url: str, visit_from: int, visit_to: int, *, statement, ) -> Iterable[OriginVisitStatusRow]: args = (origin_url, visit_from, visit_to) return map( OriginVisitStatusRow.from_dict, self._execute_with_retries(statement, args) ) @_prepared_insert_statement(OriginVisitStatusRow) def origin_visit_status_add_one( self, visit_update: OriginVisitStatusRow, *, statement ) -> None: self._add_one(statement, visit_update) def origin_visit_status_get_latest( self, origin: str, visit: int, ) -> Optional[OriginVisitStatusRow]: """Given an origin visit id, return its latest origin_visit_status""" return next(self.origin_visit_status_get(origin, visit), None) @_prepared_select_statement( OriginVisitStatusRow, # 'visit DESC,' is optional with Cassandra 4, but ScyllaDB needs it "WHERE origin = ? AND visit = ? ORDER BY visit DESC, date DESC", ) def origin_visit_status_get( self, origin: str, visit: int, *, statement, ) -> Iterator[OriginVisitStatusRow]: """Return all origin visit statuses for a given visit""" return map( OriginVisitStatusRow.from_dict, self._execute_with_retries(statement, [origin, visit]), ) @_prepared_select_statement(OriginVisitStatusRow, "WHERE origin = ?", ["snapshot"]) def origin_snapshot_get_all(self, origin: str, *, statement) -> Iterable[Sha1Git]: yield from { d["snapshot"] for d in self._execute_with_retries(statement, [origin]) if d["snapshot"] is not None } ########################## # 'metadata_authority' table ########################## @_prepared_insert_statement(MetadataAuthorityRow) def metadata_authority_add(self, authority: MetadataAuthorityRow, *, statement): self._add_one(statement, authority) @_prepared_select_statement(MetadataAuthorityRow, "WHERE type = ? AND url = ?") def metadata_authority_get( self, type, url, *, statement ) -> Optional[MetadataAuthorityRow]: rows = list(self._execute_with_retries(statement, [type, url])) if rows: return MetadataAuthorityRow.from_dict(rows[0]) else: return None ########################## # 'metadata_fetcher' table ########################## @_prepared_insert_statement(MetadataFetcherRow) def metadata_fetcher_add(self, fetcher, *, statement): self._add_one(statement, fetcher) @_prepared_select_statement(MetadataFetcherRow, "WHERE name = ? 
AND version = ?") def metadata_fetcher_get( self, name, version, *, statement ) -> Optional[MetadataFetcherRow]: rows = list(self._execute_with_retries(statement, [name, version])) if rows: return MetadataFetcherRow.from_dict(rows[0]) else: return None ######################### # 'raw_extrinsic_metadata_by_id' table ######################### @_prepared_insert_statement(RawExtrinsicMetadataByIdRow) def raw_extrinsic_metadata_by_id_add(self, row, *, statement): self._add_one(statement, row) @_prepared_select_statement(RawExtrinsicMetadataByIdRow, "WHERE id IN ?") def raw_extrinsic_metadata_get_by_ids( self, ids: List[Sha1Git], *, statement ) -> Iterable[RawExtrinsicMetadataByIdRow]: return map( RawExtrinsicMetadataByIdRow.from_dict, self._execute_with_retries(statement, [ids]), ) ######################### # 'raw_extrinsic_metadata' table ######################### @_prepared_insert_statement(RawExtrinsicMetadataRow) def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata, *, statement): self._add_one(statement, raw_extrinsic_metadata) @_prepared_select_statement( RawExtrinsicMetadataRow, "WHERE target=? AND authority_url=? AND discovery_date>? AND authority_type=?", ) def raw_extrinsic_metadata_get_after_date( self, target: str, authority_type: str, authority_url: str, after: datetime.datetime, *, statement, ) -> Iterable[RawExtrinsicMetadataRow]: return map( RawExtrinsicMetadataRow.from_dict, self._execute_with_retries( statement, [target, authority_url, after, authority_type] ), ) @_prepared_select_statement( RawExtrinsicMetadataRow, # This is equivalent to: # WHERE target=? AND authority_type = ? AND authority_url = ? " # AND (discovery_date, id) > (?, ?)" # but it needs to be written this way to work with ScyllaDB. "WHERE target=? AND (authority_type, authority_url) <= (?, ?) " "AND (authority_type, authority_url, discovery_date, id) > (?, ?, ?, ?)", ) def raw_extrinsic_metadata_get_after_date_and_id( self, target: str, authority_type: str, authority_url: str, after_date: datetime.datetime, after_id: bytes, *, statement, ) -> Iterable[RawExtrinsicMetadataRow]: return map( RawExtrinsicMetadataRow.from_dict, self._execute_with_retries( statement, [ target, authority_type, authority_url, authority_type, authority_url, after_date, after_id, ], ), ) @_prepared_select_statement( RawExtrinsicMetadataRow, "WHERE target=? AND authority_url=? 
AND authority_type=?", ) def raw_extrinsic_metadata_get( self, target: str, authority_type: str, authority_url: str, *, statement ) -> Iterable[RawExtrinsicMetadataRow]: return map( RawExtrinsicMetadataRow.from_dict, self._execute_with_retries( statement, [target, authority_url, authority_type] ), ) @_prepared_select_statement(RawExtrinsicMetadataRow, "WHERE target = ?") def raw_extrinsic_metadata_get_authorities( self, target: str, *, statement ) -> Iterable[Tuple[str, str]]: return ( (entry["authority_type"], entry["authority_url"]) for entry in self._execute_with_retries(statement, [target]) ) ########################## # 'extid' table ########################## def _extid_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by extid_add_prepare, to be called when the extid row should be added to the primary table.""" self._execute_with_retries(statement, None) @_prepared_insert_statement(ExtIDRow) def extid_add_prepare( self, extid: ExtIDRow, *, statement ) -> Tuple[int, Callable[[], None]]: statement = statement.bind(dataclasses.astuple(extid)) token_class = self._cluster.metadata.token_map.token_class token = token_class.from_key(statement.routing_key).value assert TOKEN_BEGIN <= token <= TOKEN_END # Function to be called after the indexes contain their respective # row finalizer = functools.partial(self._extid_add_finalize, statement) return (token, finalizer) @_prepared_select_statement( ExtIDRow, "WHERE extid_type=? AND extid=? AND extid_version=? " "AND target_type=? AND target=?", ) def extid_get_from_pk( self, extid_type: str, extid: bytes, extid_version: int, target: CoreSWHID, *, statement, ) -> Optional[ExtIDRow]: rows = list( self._execute_with_retries( statement, [ extid_type, extid, extid_version, target.object_type.value, target.object_id, ], ), ) assert len(rows) <= 1 if rows: return ExtIDRow(**rows[0]) else: return None @_prepared_select_statement( ExtIDRow, "WHERE token(extid_type, extid) = ?", ) def extid_get_from_token(self, token: int, *, statement) -> Iterable[ExtIDRow]: return map( ExtIDRow.from_dict, self._execute_with_retries(statement, [token]), ) # Rows are partitioned by token(extid_type, extid), then ordered (aka. "clustered") # by (extid_type, extid, extid_version, ...). This means that, without knowing the # exact extid_type and extid, we need to scan the whole partition; which should be # reasonably small. We can change the schema later if this becomes an issue @_prepared_select_statement( ExtIDRow, "WHERE token(extid_type, extid) = ? AND extid_version = ? ALLOW FILTERING", ) def extid_get_from_token_and_extid_version( self, token: int, extid_version: int, *, statement ) -> Iterable[ExtIDRow]: return map( ExtIDRow.from_dict, self._execute_with_retries(statement, [token, extid_version]), ) @_prepared_select_statement( ExtIDRow, "WHERE extid_type=? AND extid=?", ) def extid_get_from_extid( self, extid_type: str, extid: bytes, *, statement ) -> Iterable[ExtIDRow]: return map( ExtIDRow.from_dict, self._execute_with_retries(statement, [extid_type, extid]), ) @_prepared_select_statement( ExtIDRow, "WHERE extid_type=? AND extid=? 
AND extid_version = ?", ) def extid_get_from_extid_and_version( self, extid_type: str, extid: bytes, extid_version: int, *, statement ) -> Iterable[ExtIDRow]: return map( ExtIDRow.from_dict, self._execute_with_retries(statement, [extid_type, extid, extid_version]), ) def extid_get_from_target( self, target_type: str, target: bytes, extid_type: Optional[str] = None, extid_version: Optional[int] = None, ) -> Iterable[ExtIDRow]: for token in self._extid_get_tokens_from_target(target_type, target): if token is not None: if extid_type is not None and extid_version is not None: extids = self.extid_get_from_token_and_extid_version( token, extid_version ) else: extids = self.extid_get_from_token(token) for extid in extids: # re-check the extid against target (in case of murmur3 collision) if ( extid is not None and extid.target_type == target_type and extid.target == target and ( (extid_version is None and extid_type is None) or ( ( extid_version is not None and extid.extid_version == extid_version and extid_type is not None and extid.extid_type == extid_type ) ) ) ): yield extid ########################## # 'extid_by_target' table ########################## @_prepared_insert_statement(ExtIDByTargetRow) def extid_index_add_one(self, row: ExtIDByTargetRow, *, statement) -> None: """Adds a row mapping extid[target_type, target] to the token of the ExtID in the main 'extid' table.""" self._add_one(statement, row) @_prepared_select_statement( ExtIDByTargetRow, "WHERE target_type = ? AND target = ?" ) def _extid_get_tokens_from_target( self, target_type: str, target: bytes, *, statement ) -> Iterable[int]: return ( row["target_token"] for row in self._execute_with_retries(statement, [target_type, target]) ) ########################## # Miscellaneous ########################## def stat_counters(self) -> Iterable[ObjectCountRow]: raise NotImplementedError( "stat_counters is not implemented by the Cassandra backend" ) - @_prepared_statement("SELECT uuid() FROM revision LIMIT 1;") + @_prepared_statement("SELECT uuid() FROM {keyspace}.revision LIMIT 1;") def check_read(self, *, statement): self._execute_with_retries(statement, []) diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py index 4a6b470c..8f5d584e 100644 --- a/swh/storage/tests/test_cassandra.py +++ b/swh/storage/tests/test_cassandra.py @@ -1,804 +1,804 @@ # Copyright (C) 2018-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import itertools import os import resource import signal import socket import subprocess import time from typing import Any, Dict import attr from cassandra.cluster import NoHostAvailable import pytest from swh.core.api.classes import stream_results from swh.model import from_disk from swh.model.model import Directory, DirectoryEntry, Snapshot, SnapshotBranch from swh.storage import get_storage from swh.storage.cassandra import create_keyspace from swh.storage.cassandra.cql import BATCH_INSERT_MAX_SIZE from swh.storage.cassandra.model import ContentRow, ExtIDRow from swh.storage.cassandra.schema import HASH_ALGORITHMS, TABLES from swh.storage.cassandra.storage import DIRECTORY_ENTRIES_INSERT_ALGOS from swh.storage.tests.storage_data import StorageData from swh.storage.tests.storage_tests import ( TestStorageGeneratedData as _TestStorageGeneratedData, ) from swh.storage.tests.storage_tests import 
TestStorage as _TestStorage from swh.storage.utils import now, remove_keys CONFIG_TEMPLATE = """ data_file_directories: - {data_dir}/data commitlog_directory: {data_dir}/commitlog hints_directory: {data_dir}/hints saved_caches_directory: {data_dir}/saved_caches commitlog_sync: periodic commitlog_sync_period_in_ms: 1000000 partitioner: org.apache.cassandra.dht.Murmur3Partitioner endpoint_snitch: SimpleSnitch seed_provider: - class_name: org.apache.cassandra.locator.SimpleSeedProvider parameters: - seeds: "127.0.0.1" storage_port: {storage_port} native_transport_port: {native_transport_port} start_native_transport: true listen_address: 127.0.0.1 enable_user_defined_functions: true # speed-up by disabling period saving to disk key_cache_save_period: 0 row_cache_save_period: 0 trickle_fsync: false commitlog_sync_period_in_ms: 100000 """ SCYLLA_EXTRA_CONFIG_TEMPLATE = """ experimental_features: - udf view_hints_directory: {data_dir}/view_hints prometheus_port: 0 # disable prometheus server start_rpc: false # disable thrift server api_port: {api_port} """ def free_port(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(("127.0.0.1", 0)) port = sock.getsockname()[1] sock.close() return port def wait_for_peer(addr, port): wait_until = time.time() + 60 while time.time() < wait_until: try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect((addr, port)) except ConnectionRefusedError: time.sleep(0.1) else: sock.close() return True return False @pytest.fixture(scope="session") def cassandra_cluster(tmpdir_factory): cassandra_conf = tmpdir_factory.mktemp("cassandra_conf") cassandra_data = tmpdir_factory.mktemp("cassandra_data") cassandra_log = tmpdir_factory.mktemp("cassandra_log") native_transport_port = free_port() storage_port = free_port() jmx_port = free_port() api_port = free_port() use_scylla = bool(os.environ.get("SWH_USE_SCYLLADB", "")) cassandra_bin = os.environ.get( "SWH_CASSANDRA_BIN", "/usr/bin/scylla" if use_scylla else "/usr/sbin/cassandra" ) if use_scylla: os.makedirs(cassandra_conf.join("conf")) config_path = cassandra_conf.join("conf/scylla.yaml") config_template = CONFIG_TEMPLATE + SCYLLA_EXTRA_CONFIG_TEMPLATE else: config_path = cassandra_conf.join("cassandra.yaml") config_template = CONFIG_TEMPLATE with open(str(config_path), "w") as fd: fd.write( config_template.format( data_dir=str(cassandra_data), storage_port=storage_port, native_transport_port=native_transport_port, api_port=api_port, ) ) if os.environ.get("SWH_CASSANDRA_LOG"): stdout = stderr = None else: stdout = stderr = subprocess.DEVNULL env = { "MAX_HEAP_SIZE": "300M", "HEAP_NEWSIZE": "50M", "JVM_OPTS": "-Xlog:gc=error:file=%s/gc.log" % cassandra_log, } if "JAVA_HOME" in os.environ: env["JAVA_HOME"] = os.environ["JAVA_HOME"] if use_scylla: env = { **env, "SCYLLA_HOME": cassandra_conf, } # prevent "NOFILE rlimit too low (recommended setting 200000, # minimum setting 10000; refusing to start." 
resource.setrlimit(resource.RLIMIT_NOFILE, (200000, 200000)) proc = subprocess.Popen( [ cassandra_bin, "--developer-mode=1", ], start_new_session=True, env=env, stdout=stdout, stderr=stderr, ) else: proc = subprocess.Popen( [ cassandra_bin, "-Dcassandra.config=file://%s/cassandra.yaml" % cassandra_conf, "-Dcassandra.logdir=%s" % cassandra_log, "-Dcassandra.jmx.local.port=%d" % jmx_port, "-Dcassandra-foreground=yes", ], start_new_session=True, env=env, stdout=stdout, stderr=stderr, ) listening = wait_for_peer("127.0.0.1", native_transport_port) if listening: yield (["127.0.0.1"], native_transport_port) if not listening or os.environ.get("SWH_CASSANDRA_LOG"): debug_log_path = str(cassandra_log.join("debug.log")) if os.path.exists(debug_log_path): with open(debug_log_path) as fd: print(fd.read()) if not listening: if proc.poll() is None: raise Exception("cassandra process unexpectedly not listening.") else: raise Exception("cassandra process unexpectedly stopped.") pgrp = os.getpgid(proc.pid) os.killpg(pgrp, signal.SIGKILL) class RequestHandler: def on_request(self, rf): if hasattr(rf.message, "query"): print() print(rf.message.query) @pytest.fixture(scope="session") def keyspace(cassandra_cluster): (hosts, port) = cassandra_cluster - keyspace = os.urandom(10).hex() + keyspace = "test" + os.urandom(10).hex() create_keyspace(hosts, keyspace, port) return keyspace # tests are executed using imported classes (TestStorage and # TestStorageGeneratedData) using overloaded swh_storage fixture # below @pytest.fixture def swh_storage_backend_config(cassandra_cluster, keyspace): (hosts, port) = cassandra_cluster storage_config = dict( cls="cassandra", hosts=hosts, port=port, keyspace=keyspace, journal_writer={"cls": "memory"}, objstorage={"cls": "memory"}, ) yield storage_config storage = get_storage(**storage_config) for table in TABLES: - storage._cql_runner._session.execute('TRUNCATE TABLE "%s"' % table) + storage._cql_runner._session.execute(f"TRUNCATE TABLE {keyspace}.{table}") storage._cql_runner._cluster.shutdown() @pytest.mark.cassandra class TestCassandraStorage(_TestStorage): def test_config_wrong_consistency_should_raise(self): storage_config = dict( cls="cassandra", hosts=["first"], port=9999, keyspace="any", consistency_level="fake", journal_writer={"cls": "memory"}, objstorage={"cls": "memory"}, ) with pytest.raises(ValueError, match="Unknown consistency"): get_storage(**storage_config) def test_config_consistency_used(self, swh_storage_backend_config): config_with_consistency = dict( swh_storage_backend_config, **{"consistency_level": "THREE"} ) storage = get_storage(**config_with_consistency) with pytest.raises(NoHostAvailable): storage.content_get_random() def test_content_add_murmur3_collision(self, swh_storage, mocker, sample_data): """The Murmur3 token is used as link from index tables to the main table; and non-matching contents with colliding murmur3-hash are filtered-out when reading the main table. This test checks the content methods do filter out these collision. 
""" called = 0 cont, cont2 = sample_data.contents[:2] # always return a token def mock_cgtfsa(algo, hashes): nonlocal called called += 1 assert algo in ("sha1", "sha1_git") return [123456] mocker.patch.object( swh_storage._cql_runner, "content_get_tokens_from_single_algo", mock_cgtfsa, ) # For all tokens, always return cont def mock_cgft(tokens): nonlocal called called += 1 return [ ContentRow( length=10, ctime=datetime.datetime.now(), status="present", **{algo: getattr(cont, algo) for algo in HASH_ALGORITHMS}, ) ] mocker.patch.object( swh_storage._cql_runner, "content_get_from_tokens", mock_cgft ) actual_result = swh_storage.content_add([cont2]) assert called == 4 assert actual_result == { "content:add": 1, "content:add:bytes": cont2.length, } def test_content_get_metadata_murmur3_collision( self, swh_storage, mocker, sample_data ): """The Murmur3 token is used as link from index tables to the main table; and non-matching contents with colliding murmur3-hash are filtered-out when reading the main table. This test checks the content methods do filter out these collisions. """ called = 0 cont, cont2 = [attr.evolve(c, ctime=now()) for c in sample_data.contents[:2]] # always return a token def mock_cgtfsa(algo, hashes): nonlocal called called += 1 assert algo in ("sha1", "sha1_git") return [123456] mocker.patch.object( swh_storage._cql_runner, "content_get_tokens_from_single_algo", mock_cgtfsa, ) # For all tokens, always return cont and cont2 cols = list(set(cont.to_dict()) - {"data"}) def mock_cgft(tokens): nonlocal called called += 1 return [ ContentRow( **{col: getattr(cont, col) for col in cols}, ) for cont in [cont, cont2] ] mocker.patch.object( swh_storage._cql_runner, "content_get_from_tokens", mock_cgft ) actual_result = swh_storage.content_get([cont.sha1]) assert called == 2 # dropping extra column not returned expected_cont = attr.evolve(cont, data=None) # but cont2 should be filtered out assert actual_result == [expected_cont] def test_content_find_murmur3_collision(self, swh_storage, mocker, sample_data): """The Murmur3 token is used as link from index tables to the main table; and non-matching contents with colliding murmur3-hash are filtered-out when reading the main table. This test checks the content methods do filter out these collisions. """ called = 0 cont, cont2 = [attr.evolve(c, ctime=now()) for c in sample_data.contents[:2]] # always return a token def mock_cgtfsa(algo, hashes): nonlocal called called += 1 assert algo in ("sha1", "sha1_git") return [123456] mocker.patch.object( swh_storage._cql_runner, "content_get_tokens_from_single_algo", mock_cgtfsa, ) # For all tokens, always return cont and cont2 cols = list(set(cont.to_dict()) - {"data"}) def mock_cgft(tokens): nonlocal called called += 1 return [ ContentRow(**{col: getattr(cont, col) for col in cols}) for cont in [cont, cont2] ] mocker.patch.object( swh_storage._cql_runner, "content_get_from_tokens", mock_cgft ) expected_content = attr.evolve(cont, data=None) actual_result = swh_storage.content_find({"sha1": cont.sha1}) assert called == 2 # but cont2 should be filtered out assert actual_result == [expected_content] def test_content_get_partition_murmur3_collision( self, swh_storage, mocker, sample_data ): """The Murmur3 token is used as link from index tables to the main table; and non-matching contents with colliding murmur3-hash are filtered-out when reading the main table. This test checks the content_get_partition endpoints return all contents, even the collisions. 
""" called = 0 rows: Dict[int, Dict] = {} for tok, content in enumerate(sample_data.contents): cont = attr.evolve(content, data=None, ctime=now()) row_d = {**cont.to_dict(), "tok": tok} rows[tok] = row_d # For all tokens, always return cont def mock_content_get_token_range(range_start, range_end, limit): nonlocal called called += 1 for tok in list(rows.keys()) * 3: # yield multiple times the same tok row_d = dict(rows[tok].items()) row_d.pop("tok") yield (tok, ContentRow(**row_d)) mocker.patch.object( swh_storage._cql_runner, "content_get_token_range", mock_content_get_token_range, ) actual_results = list( stream_results( swh_storage.content_get_partition, partition_id=0, nb_partitions=1 ) ) assert called > 0 # everything is listed, even collisions assert len(actual_results) == 3 * len(sample_data.contents) # as we duplicated the returned results, dropping duplicate should yield # the original length assert len(set(actual_results)) == len(sample_data.contents) @pytest.mark.skip("content_update is not yet implemented for Cassandra") def test_content_update(self): pass def test_extid_murmur3_collision(self, swh_storage, mocker, sample_data): """The Murmur3 token is used as link from index table to the main table; and non-matching extid with colliding murmur3-hash are filtered-out when reading the main table. This test checks the extid methods do filter out these collision. """ swh_storage.extid_add(sample_data.extids) # For any token, always return all extids, i.e. make as if all tokens # for all extid entries collide def mock_egft(token): return [ ExtIDRow( extid_type=extid.extid_type, extid=extid.extid, extid_version=extid.extid_version, target_type=extid.target.object_type.value, target=extid.target.object_id, ) for extid in sample_data.extids ] mocker.patch.object( swh_storage._cql_runner, "extid_get_from_token", mock_egft, ) for extid in sample_data.extids: extids = swh_storage.extid_get_from_target( target_type=extid.target.object_type, ids=[extid.target.object_id] ) assert extids == [extid] def _directory_with_entries(self, sample_data, nb_entries): """Returns a dir with ``nb_entries``, all pointing to the same content""" return Directory( entries=tuple( DirectoryEntry( name=f"file{i:10}".encode(), type="file", target=sample_data.content.sha1_git, perms=from_disk.DentryPerms.directory, ) for i in range(nb_entries) ) ) @pytest.mark.parametrize( "insert_algo,nb_entries", [ ("one-by-one", 10), ("concurrent", 10), ("batch", 1), ("batch", 2), ("batch", BATCH_INSERT_MAX_SIZE - 1), ("batch", BATCH_INSERT_MAX_SIZE), ("batch", BATCH_INSERT_MAX_SIZE + 1), ("batch", BATCH_INSERT_MAX_SIZE * 2), ], ) def test_directory_add_algos( self, swh_storage, sample_data, mocker, insert_algo, nb_entries, ): mocker.patch.object(swh_storage, "_directory_entries_insert_algo", insert_algo) class new_sample_data: content = sample_data.content directory = self._directory_with_entries(sample_data, nb_entries) self.test_directory_add(swh_storage, new_sample_data) @pytest.mark.parametrize("insert_algo", DIRECTORY_ENTRIES_INSERT_ALGOS) def test_directory_add_atomic(self, swh_storage, sample_data, mocker, insert_algo): """Checks that a crash occurring after some directory entries were written does not cause the directory to be (partially) visible. ie. 
checks directories are added somewhat atomically.""" # Disable the journal writer, it would detect the CrashyEntry exception too # early for this test to be relevant swh_storage.journal_writer.journal = None mocker.patch.object(swh_storage, "_directory_entries_insert_algo", insert_algo) class CrashyEntry(DirectoryEntry): def __init__(self): super().__init__(**{**directory.entries[0].to_dict(), "name": b"crash"}) def to_dict(self): return {**super().to_dict(), "perms": "abcde"} directory = self._directory_with_entries(sample_data, BATCH_INSERT_MAX_SIZE) entries = directory.entries directory = attr.evolve(directory, entries=entries + (CrashyEntry(),)) with pytest.raises(TypeError): swh_storage.directory_add([directory]) # This should have written some of the entries to the database: entry_rows = swh_storage._cql_runner.directory_entry_get([directory.id]) assert {row.name for row in entry_rows} == {entry.name for entry in entries} # BUT, because not all the entries were written, the directory should # be considered not written. assert swh_storage.directory_missing([directory.id]) == [directory.id] assert list(swh_storage.directory_ls(directory.id)) == [] assert swh_storage.directory_get_entries(directory.id) is None def test_directory_add_raw_manifest__different_entries__allow_overwrite( self, swh_storage ): """This test demonstrates a shortcoming of the Cassandra storage backend's design: 1. add a directory with an entry named "name1" and raw_manifest="abc" 2. add a directory with an entry named "name2" and the same raw_manifest 3. the directories' id is computed only from the raw_manifest, so both directories have the same id, which causes their entries to be "additive" in the database; so directory_ls returns both entries However, by default, the Cassandra storage has allow_overwrite=False, which "accidentally" avoids this issue most of the time, by skipping insertion if an object with the same id is already in the database. This can still be an issue when either allow_overwrite=True or when inserting both directories at about the same time (because of the lack of transactionality); but the likelihood of two clients inserting two different objects with the same manifest at the same time is very low, it could only happen if loaders running in parallel used different (or nondeterministic) parsers on corrupt objects. """ assert ( swh_storage._allow_overwrite is False ), "Unexpected default _allow_overwrite value" swh_storage._allow_overwrite = True # Run the other test, but skip its last assertion dir_id = self.test_directory_add_raw_manifest__different_entries( swh_storage, check_ls=False ) assert [entry["name"] for entry in swh_storage.directory_ls(dir_id)] == [ b"name1", b"name2", ] def test_snapshot_add_atomic(self, swh_storage, sample_data, mocker): """Checks that a crash occurring after some snapshot branches were written does not cause the snapshot to be (partially) visible. ie. 
checks snapshots are added somewhat atomically.""" # Disable the journal writer, it would detect the CrashyBranch exception too # early for this test to be relevant swh_storage.journal_writer.journal = None class MyException(Exception): pass class CrashyBranch(SnapshotBranch): def __getattribute__(self, name): if name == "target" and should_raise: raise MyException() else: return super().__getattribute__(name) snapshot = sample_data.complete_snapshot branches = snapshot.branches should_raise = False # just so that we can construct the object crashy_branch = CrashyBranch.from_dict(branches[b"directory"].to_dict()) should_raise = True snapshot = attr.evolve( snapshot, branches={ **branches, b"crashy": crashy_branch, }, ) with pytest.raises(MyException): swh_storage.snapshot_add([snapshot]) # This should have written some of the branches to the database: branch_rows = swh_storage._cql_runner.snapshot_branch_get(snapshot.id, b"", 10) assert {row.name for row in branch_rows} == set(branches) # BUT, because not all the branches were written, the snapshot should # be considered not written. assert swh_storage.snapshot_missing([snapshot.id]) == [snapshot.id] assert swh_storage.snapshot_get(snapshot.id) is None assert swh_storage.snapshot_count_branches(snapshot.id) is None assert swh_storage.snapshot_get_branches(snapshot.id) is None @pytest.mark.skip( 'The "person" table of the pgsql is a legacy thing, and not ' "supported by the cassandra backend." ) def test_person_fullname_unicity(self): pass @pytest.mark.skip( 'The "person" table of the pgsql is a legacy thing, and not ' "supported by the cassandra backend." ) def test_person_get(self): pass @pytest.mark.skip("Not supported by Cassandra") def test_origin_count(self): pass @pytest.mark.cassandra class TestCassandraStorageGeneratedData(_TestStorageGeneratedData): @pytest.mark.skip("Not supported by Cassandra") def test_origin_count(self): pass @pytest.mark.skip("Not supported by Cassandra") def test_origin_count_with_visit_no_visits(self): pass @pytest.mark.skip("Not supported by Cassandra") def test_origin_count_with_visit_with_visits_and_snapshot(self): pass @pytest.mark.skip("Not supported by Cassandra") def test_origin_count_with_visit_with_visits_no_snapshot(self): pass @pytest.mark.parametrize( "allow_overwrite,object_type", itertools.product( [False, True], # Note the absence of "content", it's tested above. ["directory", "revision", "release", "snapshot", "origin", "extid"], ), ) def test_allow_overwrite( allow_overwrite: bool, object_type: str, swh_storage_backend_config ): if object_type in ("origin", "extid"): pytest.skip( f"test_disallow_overwrite not implemented for {object_type} objects, " f"because all their columns are in the primary key." 
) swh_storage = get_storage( allow_overwrite=allow_overwrite, **swh_storage_backend_config ) # directory_ls joins with content and directory table, and needs those to return # non-None entries: if object_type == "directory": swh_storage.directory_add([StorageData.directory5]) swh_storage.content_add([StorageData.content, StorageData.content2]) obj1: Any obj2: Any # Get two test objects if object_type == "directory": (obj1, obj2, *_) = StorageData.directories elif object_type == "snapshot": # StorageData.snapshots[1] is the empty snapshot, which is the corner case # that makes this test succeed for the wrong reasons obj1 = StorageData.snapshot obj2 = StorageData.complete_snapshot else: (obj1, obj2, *_) = getattr(StorageData, (object_type + "s")) # Let's make both objects have the same hash, but different content obj1 = attr.evolve(obj1, id=obj2.id) # Get the methods used to add and get these objects add = getattr(swh_storage, object_type + "_add") if object_type == "directory": def get(ids): return [ Directory( id=ids[0], entries=tuple( map( lambda entry: DirectoryEntry( name=entry["name"], type=entry["type"], target=entry["sha1_git"], perms=entry["perms"], ), swh_storage.directory_ls(ids[0]), ) ), ) ] elif object_type == "snapshot": def get(ids): return [ Snapshot.from_dict( remove_keys(swh_storage.snapshot_get(ids[0]), ("next_branch",)) ) ] else: get = getattr(swh_storage, object_type + "_get") # Add the first object add([obj1]) # It should be returned as-is assert get([obj1.id]) == [obj1] # Add the second object add([obj2]) if allow_overwrite: # obj1 was overwritten by obj2 expected = obj2 else: # obj2 was not written, because obj1 already exists and has the same hash expected = obj1 if allow_overwrite and object_type in ("directory", "snapshot"): # TODO pytest.xfail( "directory entries and snapshot branches are concatenated " "instead of being replaced" ) assert get([obj1.id]) == [expected] diff --git a/swh/storage/tests/test_cassandra_migration.py b/swh/storage/tests/test_cassandra_migration.py index bee854fa..8227150b 100644 --- a/swh/storage/tests/test_cassandra_migration.py +++ b/swh/storage/tests/test_cassandra_migration.py @@ -1,342 +1,342 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """This module tests the migration capabilities of the Cassandra backend, by sending CQL commands (eg. 
'ALTER TABLE'), and by monkey-patching large parts of the implementations to
simulate code updates."""

import dataclasses
import functools
import operator
from typing import Dict, Iterable, Optional

import attr
import pytest

from swh.model.model import Content
from swh.storage import get_storage
from swh.storage.cassandra.cql import (
    CqlRunner,
    _prepared_insert_statement,
    _prepared_select_statement,
)
from swh.storage.cassandra.model import ContentRow
from swh.storage.cassandra.schema import CONTENT_INDEX_TEMPLATE, HASH_ALGORITHMS
from swh.storage.cassandra.storage import CassandraStorage
from swh.storage.exc import StorageArgumentException

from .storage_data import StorageData
from .test_cassandra import (  # noqa, needed for swh_storage fixture
    cassandra_cluster,
    keyspace,
    swh_storage_backend_config,
)


##############################
# Common structures


def byte_xor_hash(data):
    # Behold, a one-line hash function:
    return bytes([functools.reduce(operator.xor, data)])


@attr.s
class ContentWithXor(Content):
    """A hypothetical upgrade of Content with an extra "hash"."""

    byte_xor = attr.ib(type=bytes, default=None)


##############################
# Test simple migrations


@dataclasses.dataclass
class ContentRowWithXor(ContentRow):
    """A hypothetical upgrade of ContentRow with an extra "hash",
    but not in the primary key."""

    byte_xor: bytes


class CqlRunnerWithXor(CqlRunner):
    """A CqlRunner for the hypothetical upgrade of ContentRow with an extra
    "hash", which is not in the primary key."""

    @_prepared_select_statement(
        ContentRowWithXor,
        f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}",
    )
    def content_get_from_pk(
        self, content_hashes: Dict[str, bytes], *, statement
    ) -> Optional[ContentRow]:
        rows = list(
            self._execute_with_retries(
                statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]
            )
        )
        assert len(rows) <= 1
        if rows:
            return ContentRowWithXor(**rows[0])
        else:
            return None

    @_prepared_select_statement(
        ContentRowWithXor,
        f"WHERE token({', '.join(ContentRowWithXor.PARTITION_KEY)}) = ?",
    )
    def content_get_from_tokens(
        self, tokens, *, statement
    ) -> Iterable[ContentRowWithXor]:
        return map(
            ContentRowWithXor.from_dict,
            self._execute_many_with_retries(statement, [(token,) for token in tokens]),
        )

    # Redecorate content_add_prepare with the new ContentRow class
    content_add_prepare = _prepared_insert_statement(ContentRowWithXor)(  # type: ignore
        CqlRunner.content_add_prepare.__wrapped__  # type: ignore
    )


def test_add_content_column(
    swh_storage: CassandraStorage, swh_storage_backend_config, mocker  # noqa
) -> None:
    """Adds a column to the 'content' table and a new matching index.

    This is a simple migration, as it does not require an update to the
    primary key.
    """
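    # Illustrative check, not part of the original patch: byte_xor_hash()
    # (defined above) XORs all the input bytes together, so it always yields a
    # single, very collision-prone byte -- e.g. 0x01 ^ 0x02 ^ 0x04 == 0x07:
    assert byte_xor_hash(b"\x01\x02\x04") == b"\x07"
    assert len(byte_xor_hash(StorageData.content.data)) == 1
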
""" content_xor_hash = byte_xor_hash(StorageData.content.data) # First insert some existing data swh_storage.content_add([StorageData.content, StorageData.content2]) # Then update the schema - swh_storage._cql_runner._session.execute("ALTER TABLE content ADD byte_xor blob") + session = swh_storage._cql_runner._cluster.connect(swh_storage._cql_runner.keyspace) + session.execute("ALTER TABLE content ADD byte_xor blob") for statement in CONTENT_INDEX_TEMPLATE.split("\n\n"): - swh_storage._cql_runner._session.execute(statement.format(main_algo="byte_xor")) + session.execute(statement.format(main_algo="byte_xor")) # Should not affect the running code at all: assert swh_storage.content_get([StorageData.content.sha1]) == [ attr.evolve(StorageData.content, data=None) ] with pytest.raises(StorageArgumentException): swh_storage.content_find({"byte_xor": content_xor_hash}) # Then update the running code: new_hash_algos = HASH_ALGORITHMS + ["byte_xor"] mocker.patch("swh.storage.cassandra.storage.HASH_ALGORITHMS", new_hash_algos) mocker.patch("swh.storage.cassandra.cql.HASH_ALGORITHMS", new_hash_algos) mocker.patch("swh.model.model.DEFAULT_ALGORITHMS", new_hash_algos) mocker.patch("swh.storage.cassandra.storage.Content", ContentWithXor) mocker.patch("swh.storage.cassandra.storage.ContentRow", ContentRowWithXor) mocker.patch("swh.storage.cassandra.model.ContentRow", ContentRowWithXor) mocker.patch("swh.storage.cassandra.storage.CqlRunner", CqlRunnerWithXor) # Forge new objects with this extra hash: new_content = ContentWithXor.from_dict( { "byte_xor": byte_xor_hash(StorageData.content.data), **StorageData.content.to_dict(), } ) new_content2 = ContentWithXor.from_dict( { "byte_xor": byte_xor_hash(StorageData.content2.data), **StorageData.content2.to_dict(), } ) # Simulates a restart: swh_storage._set_cql_runner() # Old algos still works, and return the new object type: assert swh_storage.content_get([StorageData.content.sha1]) == [ attr.evolve(new_content, data=None, byte_xor=None) ] # The new algo does not work, we did not backfill it yet: assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [] # A normal storage would not overwrite, because the object already exists, # as it is not aware it is missing a field: swh_storage.content_add([new_content, new_content2]) assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [] # Backfill (in production this would be done with a replayer reading from # the journal): overwriting_swh_storage = get_storage( allow_overwrite=True, **swh_storage_backend_config ) overwriting_swh_storage.content_add([new_content, new_content2]) # Now, the object can be found: assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [ attr.evolve(new_content, data=None) ] ############################## # Test complex migrations @dataclasses.dataclass class ContentRowWithXorPK(ContentRow): """An hypothetical upgrade of ContentRow with an extra "hash", in the primary key.""" TABLE = "content_v2" PARTITION_KEY = ("sha1", "sha1_git", "sha256", "blake2s256", "byte_xor") byte_xor: bytes class CqlRunnerWithXorPK(CqlRunner): """An hypothetical upgrade of ContentRow with an extra "hash", but not in the primary key.""" @_prepared_select_statement( ContentRowWithXorPK, f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))} AND byte_xor=?", ) def content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement ) -> Optional[ContentRow]: rows = list( self._execute_with_retries( statement, [content_hashes[algo] for algo in HASH_ALGORITHMS + 
["byte_xor"]], ) ) assert len(rows) <= 1 if rows: return ContentRowWithXorPK(**rows[0]) else: return None @_prepared_select_statement( ContentRowWithXorPK, f"WHERE token({', '.join(ContentRowWithXorPK.PARTITION_KEY)}) = ?", ) def content_get_from_tokens( self, tokens, *, statement ) -> Iterable[ContentRowWithXorPK]: return map( ContentRowWithXorPK.from_dict, self._execute_many_with_retries(statement, [(token,) for token in tokens]), ) # Redecorate content_add_prepare with the new ContentRow class content_add_prepare = _prepared_insert_statement(ContentRowWithXorPK)( # type: ignore # noqa CqlRunner.content_add_prepare.__wrapped__ # type: ignore ) def test_change_content_pk( swh_storage: CassandraStorage, swh_storage_backend_config, mocker # noqa ) -> None: """Adds a column to the 'content' table and a new matching index; and make this new column part of the primary key This is a complex migration, as it requires copying the whole table """ content_xor_hash = byte_xor_hash(StorageData.content.data) + session = swh_storage._cql_runner._cluster.connect(swh_storage._cql_runner.keyspace) # First insert some existing data swh_storage.content_add([StorageData.content, StorageData.content2]) # Then add a new table and a new index - swh_storage._cql_runner._session.execute( + session.execute( """ CREATE TABLE IF NOT EXISTS content_v2 ( sha1 blob, sha1_git blob, sha256 blob, blake2s256 blob, byte_xor blob, length bigint, ctime timestamp, -- creation time, i.e. time of (first) injection into the storage status ascii, PRIMARY KEY ((sha1, sha1_git, sha256, blake2s256, byte_xor)) );""" ) for statement in CONTENT_INDEX_TEMPLATE.split("\n\n"): - swh_storage._cql_runner._session.execute(statement.format(main_algo="byte_xor")) + session.execute(statement.format(main_algo="byte_xor")) # Should not affect the running code at all: assert swh_storage.content_get([StorageData.content.sha1]) == [ attr.evolve(StorageData.content, data=None) ] with pytest.raises(StorageArgumentException): swh_storage.content_find({"byte_xor": content_xor_hash}) # Then update the running code: new_hash_algos = HASH_ALGORITHMS + ["byte_xor"] mocker.patch("swh.storage.cassandra.storage.HASH_ALGORITHMS", new_hash_algos) mocker.patch("swh.storage.cassandra.cql.HASH_ALGORITHMS", new_hash_algos) mocker.patch("swh.model.model.DEFAULT_ALGORITHMS", new_hash_algos) mocker.patch("swh.storage.cassandra.storage.Content", ContentWithXor) mocker.patch("swh.storage.cassandra.storage.ContentRow", ContentRowWithXorPK) mocker.patch("swh.storage.cassandra.model.ContentRow", ContentRowWithXorPK) mocker.patch("swh.storage.cassandra.storage.CqlRunner", CqlRunnerWithXorPK) # Forge new objects with this extra hash: new_content = ContentWithXor.from_dict( { "byte_xor": byte_xor_hash(StorageData.content.data), **StorageData.content.to_dict(), } ) new_content2 = ContentWithXor.from_dict( { "byte_xor": byte_xor_hash(StorageData.content2.data), **StorageData.content2.to_dict(), } ) # Replay to the new table. # In production this would be done with a replayer reading from the journal, # while loaders would still write to the DB. 
    overwriting_swh_storage = get_storage(
        allow_overwrite=True, **swh_storage_backend_config
    )
    overwriting_swh_storage.content_add([new_content, new_content2])

    # Old algos still work, and return the new object type;
    # but the byte_xor value is None because it is only available in the new
    # table, which this storage is not yet configured to use
    assert swh_storage.content_get([StorageData.content.sha1]) == [
        attr.evolve(new_content, data=None, byte_xor=None)
    ]

    # When the replayer gets close to the end of the logs, loaders are stopped
    # to allow the replayer to catch up with the end of the log.
    # When it does, we can switch over to the new swh-storage's code.

    # Simulates a restart:
    swh_storage._set_cql_runner()

    # Now, the object can be found with the new hash:
    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [
        attr.evolve(new_content, data=None)
    ]

    # Remove the old table:
-    swh_storage._cql_runner._session.execute("DROP TABLE content")
+    session.execute("DROP TABLE content")

    # Object is still available, because we don't use it anymore
    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [
        attr.evolve(new_content, data=None)
    ]

    # THE END.

    # Test teardown expects a table with this name to exist:
-    swh_storage._cql_runner._session.execute(
-        "CREATE TABLE content (foo blob PRIMARY KEY);"
-    )
+    session.execute("CREATE TABLE content (foo blob PRIMARY KEY);")

    # Clean up this table, test teardown does not know about it:
-    swh_storage._cql_runner._session.execute("DROP TABLE content_v2;")
+    session.execute("DROP TABLE content_v2;")
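

# Illustrative sketch, not part of the original patch: the ad-hoc DDL in these
# migration tests is always run through a session connected to the CqlRunner's
# keyspace. A small helper like the one below (hypothetical, relying on the same
# private _cql_runner/_cluster attributes used above) would capture that pattern
# in one place:
def _ddl_session(storage: CassandraStorage):
    """Return a Cassandra session bound to ``storage``'s keyspace, for raw DDL."""
    runner = storage._cql_runner
    return runner._cluster.connect(runner.keyspace)
# e.g.:  _ddl_session(swh_storage).execute("ALTER TABLE content ADD byte_xor blob")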