diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py index c5cf1953..789991cd 100644 --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -1,1018 +1,1020 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import dataclasses import datetime import functools import logging import random from typing import ( Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, ) from cassandra import CoordinationFailure from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy from cassandra.query import PreparedStatement, BoundStatement, dict_factory from tenacity import ( retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type, ) from mypy_extensions import NamedArg from swh.model.model import ( Content, SkippedContent, Sha1Git, TimestampWithTimezone, Timestamp, Person, ) from swh.storage.interface import ListOrder from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys from .model import ( BaseRow, ContentRow, DirectoryEntryRow, DirectoryRow, MetadataAuthorityRow, MetadataFetcherRow, ObjectCountRow, OriginRow, OriginVisitRow, OriginVisitStatusRow, RawExtrinsicMetadataRow, ReleaseRow, RevisionParentRow, RevisionRow, SkippedContentRow, SnapshotBranchRow, SnapshotRow, ) from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS logger = logging.getLogger(__name__) _execution_profiles = { EXEC_PROFILE_DEFAULT: ExecutionProfile( load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()), row_factory=dict_factory, ), } # Configuration for cassandra-driver's access to servers: # * hit the right server directly when sending a query (TokenAwarePolicy), # * if there's more than one, then pick one at random that's in the same # datacenter as the client (DCAwareRoundRobinPolicy) def create_keyspace( hosts: List[str], keyspace: str, port: int = 9042, *, durable_writes=True ): cluster = Cluster(hosts, port=port, execution_profiles=_execution_profiles) session = cluster.connect() extra_params = "" if not durable_writes: extra_params = "AND durable_writes = false" session.execute( """CREATE KEYSPACE IF NOT EXISTS "%s" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 } %s; """ % (keyspace, extra_params) ) session.execute('USE "%s"' % keyspace) for query in CREATE_TABLES_QUERIES: session.execute(query) TRet = TypeVar("TRet") def _prepared_statement( query: str, ) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]: """Returns a decorator usable on methods of CqlRunner, to inject them with a 'statement' argument, that is a prepared statement corresponding to the query. This only works on methods of CqlRunner, as preparing a statement requires a connection to a Cassandra server.""" def decorator(f): @functools.wraps(f) def newf(self, *args, **kwargs) -> TRet: if f.__name__ not in self._prepared_statements: statement: PreparedStatement = self._session.prepare(query) self._prepared_statements[f.__name__] = statement return f( self, *args, **kwargs, statement=self._prepared_statements[f.__name__] ) return newf return decorator TArg = TypeVar("TArg") TSelf = TypeVar("TSelf") def _prepared_insert_statement( row_class: Type[BaseRow], ) -> Callable[ [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa Callable[[TSelf, TArg], TRet], ]: """Shorthand for using `_prepared_statement` for `INSERT INTO` statements.""" columns = row_class.cols() return _prepared_statement( "INSERT INTO %s (%s) VALUES (%s)" % (row_class.TABLE, ", ".join(columns), ", ".join("?" for _ in columns),) ) def _prepared_exists_statement( table_name: str, ) -> Callable[ [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa Callable[[TSelf, TArg], TRet], ]: """Shorthand for using `_prepared_statement` for queries that only check which ids in a list exist in the table.""" return _prepared_statement(f"SELECT id FROM {table_name} WHERE id IN ?") def _prepared_select_statement( row_class: Type[BaseRow], clauses: str = "", cols: Optional[List[str]] = None, ) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]: if cols is None: cols = row_class.cols() return _prepared_statement( f"SELECT {', '.join(cols)} FROM {row_class.TABLE} {clauses}" ) class CqlRunner: """Class managing prepared statements and building queries to be sent to Cassandra.""" def __init__(self, hosts: List[str], keyspace: str, port: int): self._cluster = Cluster( hosts, port=port, execution_profiles=_execution_profiles ) self._session = self._cluster.connect(keyspace) self._cluster.register_user_type( keyspace, "microtimestamp_with_timezone", TimestampWithTimezone ) self._cluster.register_user_type(keyspace, "microtimestamp", Timestamp) self._cluster.register_user_type(keyspace, "person", Person) self._prepared_statements: Dict[str, PreparedStatement] = {} ########################## # Common utility functions ########################## MAX_RETRIES = 3 @retry( wait=wait_random_exponential(multiplier=1, max=10), stop=stop_after_attempt(MAX_RETRIES), retry=retry_if_exception_type(CoordinationFailure), ) def _execute_with_retries(self, statement, args) -> ResultSet: return self._session.execute(statement, args, timeout=1000.0) @_prepared_statement( "UPDATE object_count SET count = count + ? " "WHERE partition_key = 0 AND object_type = ?" ) def _increment_counter( self, object_type: str, nb: int, *, statement: PreparedStatement ) -> None: self._execute_with_retries(statement, [nb, object_type]) def _add_one(self, statement, obj: BaseRow) -> None: self._increment_counter(obj.TABLE, 1) self._execute_with_retries(statement, dataclasses.astuple(obj)) _T = TypeVar("_T", bound=BaseRow) def _get_random_row(self, row_class: Type[_T], statement) -> Optional[_T]: # noqa """Takes a prepared statement of the form "SELECT * FROM WHERE token() > ? LIMIT 1" and uses it to return a random row""" token = random.randint(TOKEN_BEGIN, TOKEN_END) rows = self._execute_with_retries(statement, [token]) if not rows: # There are no row with a greater token; wrap around to get # the row with the smallest token rows = self._execute_with_retries(statement, [TOKEN_BEGIN]) if rows: return row_class.from_dict(rows.one()) # type: ignore else: return None def _missing(self, statement, ids): rows = self._execute_with_retries(statement, [ids]) found_ids = {row["id"] for row in rows} return [id_ for id_ in ids if id_ not in found_ids] ########################## # 'content' table ########################## def _content_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by content_add_prepare, to be called when the content row should be added to the primary table.""" self._execute_with_retries(statement, None) self._increment_counter("content", 1) @_prepared_insert_statement(ContentRow) def content_add_prepare( self, content: ContentRow, *, statement ) -> Tuple[int, Callable[[], None]]: """Prepares insertion of a Content to the main 'content' table. Returns a token (to be used in secondary tables), and a function to be called to perform the insertion in the main table.""" statement = statement.bind(dataclasses.astuple(content)) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) token_class = self._cluster.metadata.token_map.token_class # Token of the row when it will be inserted. This is equivalent to # "SELECT token({', '.join(ContentRow.PARTITION_KEY)}) FROM content WHERE ..." # after the row is inserted; but we need the token to insert in the # index tables *before* inserting to the main 'content' table token = token_class.from_key(statement.routing_key).value assert TOKEN_BEGIN <= token <= TOKEN_END # Function to be called after the indexes contain their respective # row finalizer = functools.partial(self._content_add_finalize, statement) return (token, finalizer) @_prepared_select_statement( ContentRow, f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}" ) def content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement ) -> Optional[ContentRow]: rows = list( self._execute_with_retries( statement, [content_hashes[algo] for algo in HASH_ALGORITHMS] ) ) assert len(rows) <= 1 if rows: return ContentRow(**rows[0]) else: return None @_prepared_select_statement( ContentRow, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) = ?" ) def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]: return map(ContentRow.from_dict, self._execute_with_retries(statement, [token])) @_prepared_select_statement( ContentRow, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) > ? LIMIT 1" ) def content_get_random(self, *, statement) -> Optional[ContentRow]: return self._get_random_row(ContentRow, statement) @_prepared_statement( ( "SELECT token({0}) AS tok, {1} FROM content " "WHERE token({0}) >= ? AND token({0}) <= ? LIMIT ?" ).format(", ".join(ContentRow.PARTITION_KEY), ", ".join(ContentRow.cols())) ) def content_get_token_range( self, start: int, end: int, limit: int, *, statement ) -> Iterable[Tuple[int, ContentRow]]: """Returns an iterable of (token, row)""" return ( (row["tok"], ContentRow.from_dict(remove_keys(row, ("tok",)))) for row in self._execute_with_retries(statement, [start, end, limit]) ) ########################## # 'content_by_*' tables ########################## @_prepared_statement( "SELECT sha1_git AS id FROM content_by_sha1_git WHERE sha1_git IN ?" ) def content_missing_by_sha1_git( self, ids: List[bytes], *, statement ) -> List[bytes]: return self._missing(statement, ids) def content_index_add_one(self, algo: str, content: Content, token: int) -> None: """Adds a row mapping content[algo] to the token of the Content in the main 'content' table.""" query = ( f"INSERT INTO content_by_{algo} ({algo}, target_token) " f"VALUES (%s, %s)" ) self._execute_with_retries(query, [content.get_hash(algo), token]) def content_get_tokens_from_single_hash( self, algo: str, hash_: bytes ) -> Iterable[int]: assert algo in HASH_ALGORITHMS query = f"SELECT target_token FROM content_by_{algo} WHERE {algo} = %s" return ( row["target_token"] for row in self._execute_with_retries(query, [hash_]) ) ########################## # 'skipped_content' table ########################## _magic_null_pk = b"" """ NULLs (or all-empty blobs) are not allowed in primary keys; instead use a special value that can't possibly be a valid hash. """ def _skipped_content_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by skipped_content_add_prepare, to be called when the content row should be added to the primary table.""" self._execute_with_retries(statement, None) self._increment_counter("skipped_content", 1) @_prepared_insert_statement(SkippedContentRow) def skipped_content_add_prepare( self, content, *, statement ) -> Tuple[int, Callable[[], None]]: """Prepares insertion of a Content to the main 'skipped_content' table. Returns a token (to be used in secondary tables), and a function to be called to perform the insertion in the main table.""" # Replace NULLs (which are not allowed in the partition key) with # an empty byte string for key in SkippedContentRow.PARTITION_KEY: if getattr(content, key) is None: setattr(content, key, self._magic_null_pk) statement = statement.bind(dataclasses.astuple(content)) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) token_class = self._cluster.metadata.token_map.token_class # Token of the row when it will be inserted. This is equivalent to # "SELECT token({', '.join(SkippedContentRow.PARTITION_KEY)}) # FROM skipped_content WHERE ..." # after the row is inserted; but we need the token to insert in the # index tables *before* inserting to the main 'skipped_content' table token = token_class.from_key(statement.routing_key).value assert TOKEN_BEGIN <= token <= TOKEN_END # Function to be called after the indexes contain their respective # row finalizer = functools.partial(self._skipped_content_add_finalize, statement) return (token, finalizer) @_prepared_select_statement( SkippedContentRow, f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}", ) def skipped_content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement ) -> Optional[SkippedContentRow]: rows = list( self._execute_with_retries( statement, [ content_hashes[algo] or self._magic_null_pk for algo in HASH_ALGORITHMS ], ) ) assert len(rows) <= 1 if rows: # TODO: convert _magic_null_pk back to None? return SkippedContentRow.from_dict(rows[0]) else: return None ########################## # 'skipped_content_by_*' tables ########################## def skipped_content_index_add_one( self, algo: str, content: SkippedContent, token: int ) -> None: """Adds a row mapping content[algo] to the token of the SkippedContent in the main 'skipped_content' table.""" query = ( f"INSERT INTO skipped_content_by_{algo} ({algo}, target_token) " f"VALUES (%s, %s)" ) self._execute_with_retries( query, [content.get_hash(algo) or self._magic_null_pk, token] ) ########################## # 'revision' table ########################## @_prepared_exists_statement("revision") def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement(RevisionRow) def revision_add_one(self, revision: RevisionRow, *, statement) -> None: self._add_one(statement, revision) @_prepared_statement("SELECT id FROM revision WHERE id IN ?") def revision_get_ids(self, revision_ids, *, statement) -> Iterable[int]: return ( row["id"] for row in self._execute_with_retries(statement, [revision_ids]) ) @_prepared_select_statement(RevisionRow, "WHERE id IN ?") - def revision_get(self, revision_ids, *, statement) -> Iterable[RevisionRow]: + def revision_get( + self, revision_ids: List[Sha1Git], *, statement + ) -> Iterable[RevisionRow]: return map( RevisionRow.from_dict, self._execute_with_retries(statement, [revision_ids]) ) @_prepared_select_statement(RevisionRow, "WHERE token(id) > ? LIMIT 1") def revision_get_random(self, *, statement) -> Optional[RevisionRow]: return self._get_random_row(RevisionRow, statement) ########################## # 'revision_parent' table ########################## @_prepared_insert_statement(RevisionParentRow) def revision_parent_add_one( self, revision_parent: RevisionParentRow, *, statement ) -> None: self._add_one(statement, revision_parent) @_prepared_statement("SELECT parent_id FROM revision_parent WHERE id = ?") def revision_parent_get( self, revision_id: Sha1Git, *, statement ) -> Iterable[bytes]: return ( row["parent_id"] for row in self._execute_with_retries(statement, [revision_id]) ) ########################## # 'release' table ########################## @_prepared_exists_statement("release") def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement(ReleaseRow) def release_add_one(self, release: ReleaseRow, *, statement) -> None: self._add_one(statement, release) @_prepared_select_statement(ReleaseRow, "WHERE id in ?") def release_get(self, release_ids: List[str], *, statement) -> Iterable[ReleaseRow]: return map( ReleaseRow.from_dict, self._execute_with_retries(statement, [release_ids]) ) @_prepared_select_statement(ReleaseRow, "WHERE token(id) > ? LIMIT 1") def release_get_random(self, *, statement) -> Optional[ReleaseRow]: return self._get_random_row(ReleaseRow, statement) ########################## # 'directory' table ########################## @_prepared_exists_statement("directory") def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement(DirectoryRow) def directory_add_one(self, directory: DirectoryRow, *, statement) -> None: """Called after all calls to directory_entry_add_one, to commit/finalize the directory.""" self._add_one(statement, directory) @_prepared_select_statement(DirectoryRow, "WHERE token(id) > ? LIMIT 1") def directory_get_random(self, *, statement) -> Optional[DirectoryRow]: return self._get_random_row(DirectoryRow, statement) ########################## # 'directory_entry' table ########################## @_prepared_insert_statement(DirectoryEntryRow) def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None: self._add_one(statement, entry) @_prepared_select_statement(DirectoryEntryRow, "WHERE directory_id IN ?") def directory_entry_get( self, directory_ids, *, statement ) -> Iterable[DirectoryEntryRow]: return map( DirectoryEntryRow.from_dict, self._execute_with_retries(statement, [directory_ids]), ) ########################## # 'snapshot' table ########################## @_prepared_exists_statement("snapshot") def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement(SnapshotRow) def snapshot_add_one(self, snapshot: SnapshotRow, *, statement) -> None: self._add_one(statement, snapshot) @_prepared_select_statement(SnapshotRow, "WHERE id = ?") def snapshot_get(self, snapshot_id: Sha1Git, *, statement) -> ResultSet: return map( SnapshotRow.from_dict, self._execute_with_retries(statement, [snapshot_id]) ) @_prepared_select_statement(SnapshotRow, "WHERE token(id) > ? LIMIT 1") def snapshot_get_random(self, *, statement) -> Optional[SnapshotRow]: return self._get_random_row(SnapshotRow, statement) ########################## # 'snapshot_branch' table ########################## @_prepared_insert_statement(SnapshotBranchRow) def snapshot_branch_add_one(self, branch: SnapshotBranchRow, *, statement) -> None: self._add_one(statement, branch) @_prepared_statement( "SELECT ascii_bins_count(target_type) AS counts " "FROM snapshot_branch " "WHERE snapshot_id = ? " ) def snapshot_count_branches( self, snapshot_id: Sha1Git, *, statement ) -> Dict[Optional[str], int]: """Returns a dictionary from type names to the number of branches of that type.""" row = self._execute_with_retries(statement, [snapshot_id]).one() (nb_none, counts) = row["counts"] return {None: nb_none, **counts} @_prepared_select_statement( SnapshotBranchRow, "WHERE snapshot_id = ? AND name >= ? LIMIT ?" ) def snapshot_branch_get( self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement ) -> Iterable[SnapshotBranchRow]: return map( SnapshotBranchRow.from_dict, self._execute_with_retries(statement, [snapshot_id, from_, limit]), ) ########################## # 'origin' table ########################## @_prepared_insert_statement(OriginRow) def origin_add_one(self, origin: OriginRow, *, statement) -> None: self._add_one(statement, origin) @_prepared_select_statement(OriginRow, "WHERE sha1 = ?") def origin_get_by_sha1(self, sha1: bytes, *, statement) -> Iterable[OriginRow]: return map(OriginRow.from_dict, self._execute_with_retries(statement, [sha1])) def origin_get_by_url(self, url: str) -> Iterable[OriginRow]: return self.origin_get_by_sha1(hash_url(url)) @_prepared_statement( f'SELECT token(sha1) AS tok, {", ".join(OriginRow.cols())} ' f"FROM origin WHERE token(sha1) >= ? LIMIT ?" ) def origin_list( self, start_token: int, limit: int, *, statement ) -> Iterable[Tuple[int, OriginRow]]: """Returns an iterable of (token, origin)""" return ( (row["tok"], OriginRow.from_dict(remove_keys(row, ("tok",)))) for row in self._execute_with_retries(statement, [start_token, limit]) ) @_prepared_select_statement(OriginRow) def origin_iter_all(self, *, statement) -> Iterable[OriginRow]: return map(OriginRow.from_dict, self._execute_with_retries(statement, [])) @_prepared_statement("SELECT next_visit_id FROM origin WHERE sha1 = ?") def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int: rows = list(self._execute_with_retries(statement, [origin_sha1])) assert len(rows) == 1 # TODO: error handling return rows[0]["next_visit_id"] @_prepared_statement( "UPDATE origin SET next_visit_id=? WHERE sha1 = ? IF next_visit_id=?" ) def origin_generate_unique_visit_id(self, origin_url: str, *, statement) -> int: origin_sha1 = hash_url(origin_url) next_id = self._origin_get_next_visit_id(origin_sha1) while True: res = list( self._execute_with_retries( statement, [next_id + 1, origin_sha1, next_id] ) ) assert len(res) == 1 if res[0]["[applied]"]: # No data race return next_id else: # Someone else updated it before we did, let's try again next_id = res[0]["next_visit_id"] # TODO: abort after too many attempts return next_id ########################## # 'origin_visit' table ########################## @_prepared_select_statement( OriginVisitRow, "WHERE origin = ? AND visit > ? ORDER BY visit ASC" ) def _origin_visit_get_pagination_asc_no_limit( self, origin_url: str, last_visit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit]) @_prepared_select_statement( OriginVisitRow, "WHERE origin = ? AND visit > ? ORDER BY visit ASC LIMIT ?" ) def _origin_visit_get_pagination_asc_limit( self, origin_url: str, last_visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit, limit]) @_prepared_select_statement( OriginVisitRow, "WHERE origin = ? AND visit < ? ORDER BY visit DESC" ) def _origin_visit_get_pagination_desc_no_limit( self, origin_url: str, last_visit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit]) @_prepared_select_statement( OriginVisitRow, "WHERE origin = ? AND visit < ? ORDER BY visit DESC LIMIT ?" ) def _origin_visit_get_pagination_desc_limit( self, origin_url: str, last_visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit, limit]) @_prepared_select_statement( OriginVisitRow, "WHERE origin = ? ORDER BY visit ASC LIMIT ?" ) def _origin_visit_get_no_pagination_asc_limit( self, origin_url: str, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, limit]) @_prepared_select_statement(OriginVisitRow, "WHERE origin = ? ORDER BY visit ASC ") def _origin_visit_get_no_pagination_asc_no_limit( self, origin_url: str, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url]) @_prepared_select_statement(OriginVisitRow, "WHERE origin = ? ORDER BY visit DESC") def _origin_visit_get_no_pagination_desc_no_limit( self, origin_url: str, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url]) @_prepared_select_statement( OriginVisitRow, "WHERE origin = ? ORDER BY visit DESC LIMIT ?" ) def _origin_visit_get_no_pagination_desc_limit( self, origin_url: str, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, limit]) def origin_visit_get( self, origin_url: str, last_visit: Optional[int], limit: Optional[int], order: ListOrder, ) -> Iterable[OriginVisitRow]: args: List[Any] = [origin_url] if last_visit is not None: page_name = "pagination" args.append(last_visit) else: page_name = "no_pagination" if limit is not None: limit_name = "limit" args.append(limit) else: limit_name = "no_limit" method_name = f"_origin_visit_get_{page_name}_{order.value}_{limit_name}" origin_visit_get_method = getattr(self, method_name) return map(OriginVisitRow.from_dict, origin_visit_get_method(*args)) @_prepared_select_statement( OriginVisitStatusRow, "WHERE origin = ? AND visit = ? AND date >= ? ORDER BY date ASC LIMIT ?", ) def _origin_visit_status_get_with_date_asc_limit( self, origin: str, visit: int, date_from: datetime.datetime, limit: int, *, statement, ) -> ResultSet: return self._execute_with_retries(statement, [origin, visit, date_from, limit]) @_prepared_select_statement( OriginVisitStatusRow, "WHERE origin = ? AND visit = ? AND date <= ? ORDER BY visit DESC LIMIT ?", ) def _origin_visit_status_get_with_date_desc_limit( self, origin: str, visit: int, date_from: datetime.datetime, limit: int, *, statement, ) -> ResultSet: return self._execute_with_retries(statement, [origin, visit, date_from, limit]) @_prepared_select_statement( OriginVisitStatusRow, "WHERE origin = ? AND visit = ? ORDER BY visit ASC LIMIT ?", ) def _origin_visit_status_get_with_no_date_asc_limit( self, origin: str, visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin, visit, limit]) @_prepared_select_statement( OriginVisitStatusRow, "WHERE origin = ? AND visit = ? ORDER BY visit DESC LIMIT ?", ) def _origin_visit_status_get_with_no_date_desc_limit( self, origin: str, visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin, visit, limit]) def origin_visit_status_get_range( self, origin: str, visit: int, date_from: Optional[datetime.datetime], limit: int, order: ListOrder, ) -> Iterable[OriginVisitStatusRow]: args: List[Any] = [origin, visit] if date_from is not None: date_name = "date" args.append(date_from) else: date_name = "no_date" args.append(limit) method_name = f"_origin_visit_status_get_with_{date_name}_{order.value}_limit" origin_visit_status_get_method = getattr(self, method_name) return map( OriginVisitStatusRow.from_dict, origin_visit_status_get_method(*args) ) @_prepared_insert_statement(OriginVisitRow) def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None: self._add_one(statement, visit) @_prepared_insert_statement(OriginVisitStatusRow) def origin_visit_status_add_one( self, visit_update: OriginVisitStatusRow, *, statement ) -> None: self._add_one(statement, visit_update) def origin_visit_status_get_latest( self, origin: str, visit: int, ) -> Optional[OriginVisitStatusRow]: """Given an origin visit id, return its latest origin_visit_status """ return next(self.origin_visit_status_get(origin, visit), None) @_prepared_select_statement( OriginVisitStatusRow, "WHERE origin = ? AND visit = ? ORDER BY date DESC" ) def origin_visit_status_get( self, origin: str, visit: int, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, *, statement, ) -> Iterator[OriginVisitStatusRow]: """Return all origin visit statuses for a given visit """ return map( OriginVisitStatusRow.from_dict, self._execute_with_retries(statement, [origin, visit]), ) @_prepared_select_statement(OriginVisitRow, "WHERE origin = ? AND visit = ?") def origin_visit_get_one( self, origin_url: str, visit_id: int, *, statement ) -> Optional[OriginVisitRow]: # TODO: error handling rows = list(self._execute_with_retries(statement, [origin_url, visit_id])) if rows: return OriginVisitRow.from_dict(rows[0]) else: return None @_prepared_select_statement(OriginVisitRow, "WHERE origin = ?") def origin_visit_get_all( self, origin_url: str, *, statement ) -> Iterable[OriginVisitRow]: return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, [origin_url]), ) @_prepared_select_statement(OriginVisitRow, "WHERE token(origin) >= ?") def _origin_visit_iter_from( self, min_token: int, *, statement ) -> Iterable[OriginVisitRow]: return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, [min_token]) ) @_prepared_select_statement(OriginVisitRow, "WHERE token(origin) < ?") def _origin_visit_iter_to( self, max_token: int, *, statement ) -> Iterable[OriginVisitRow]: return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, [max_token]) ) def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]: """Returns all origin visits in order from this token, and wraps around the token space.""" yield from self._origin_visit_iter_from(start_token) yield from self._origin_visit_iter_to(start_token) ########################## # 'metadata_authority' table ########################## @_prepared_insert_statement(MetadataAuthorityRow) def metadata_authority_add(self, authority: MetadataAuthorityRow, *, statement): self._add_one(statement, authority) @_prepared_select_statement(MetadataAuthorityRow, "WHERE type = ? AND url = ?") def metadata_authority_get( self, type, url, *, statement ) -> Optional[MetadataAuthorityRow]: rows = list(self._execute_with_retries(statement, [type, url])) if rows: return MetadataAuthorityRow.from_dict(rows[0]) else: return None ########################## # 'metadata_fetcher' table ########################## @_prepared_insert_statement(MetadataFetcherRow) def metadata_fetcher_add(self, fetcher, *, statement): self._add_one(statement, fetcher) @_prepared_select_statement(MetadataFetcherRow, "WHERE name = ? AND version = ?") def metadata_fetcher_get( self, name, version, *, statement ) -> Optional[MetadataFetcherRow]: rows = list(self._execute_with_retries(statement, [name, version])) if rows: return MetadataFetcherRow.from_dict(rows[0]) else: return None ######################### # 'raw_extrinsic_metadata' table ######################### @_prepared_insert_statement(RawExtrinsicMetadataRow) def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata, *, statement): self._add_one(statement, raw_extrinsic_metadata) @_prepared_select_statement( RawExtrinsicMetadataRow, "WHERE id=? AND authority_url=? AND discovery_date>? AND authority_type=?", ) def raw_extrinsic_metadata_get_after_date( self, id: str, authority_type: str, authority_url: str, after: datetime.datetime, *, statement, ) -> Iterable[RawExtrinsicMetadataRow]: return map( RawExtrinsicMetadataRow.from_dict, self._execute_with_retries( statement, [id, authority_url, after, authority_type] ), ) @_prepared_select_statement( RawExtrinsicMetadataRow, "WHERE id=? AND authority_type=? AND authority_url=? " "AND (discovery_date, fetcher_name, fetcher_version) > (?, ?, ?)", ) def raw_extrinsic_metadata_get_after_date_and_fetcher( self, id: str, authority_type: str, authority_url: str, after_date: datetime.datetime, after_fetcher_name: str, after_fetcher_version: str, *, statement, ) -> Iterable[RawExtrinsicMetadataRow]: return map( RawExtrinsicMetadataRow.from_dict, self._execute_with_retries( statement, [ id, authority_type, authority_url, after_date, after_fetcher_name, after_fetcher_version, ], ), ) @_prepared_select_statement( RawExtrinsicMetadataRow, "WHERE id=? AND authority_url=? AND authority_type=?" ) def raw_extrinsic_metadata_get( self, id: str, authority_type: str, authority_url: str, *, statement ) -> Iterable[RawExtrinsicMetadataRow]: return map( RawExtrinsicMetadataRow.from_dict, self._execute_with_retries(statement, [id, authority_url, authority_type]), ) ########################## # Miscellaneous ########################## @_prepared_statement("SELECT uuid() FROM revision LIMIT 1;") def check_read(self, *, statement): self._execute_with_retries(statement, []) @_prepared_select_statement(ObjectCountRow, "WHERE partition_key=0") def stat_counters(self, *, statement) -> Iterable[ObjectCountRow]: return map(ObjectCountRow.from_dict, self._execute_with_retries(statement, [])) diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py index 5fa8404e..1f6fbc82 100644 --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -1,1217 +1,1186 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import bisect import collections import copy import datetime import functools import itertools import random import re from collections import defaultdict from datetime import timedelta from typing import ( Any, Callable, Dict, Generic, Hashable, Iterable, Iterator, List, Optional, - Set, Tuple, Type, TypeVar, Union, ) import attr from swh.core.api.serializers import msgpack_loads, msgpack_dumps from swh.model.identifiers import SWHID from swh.model.model import ( Content, SkippedContent, - Revision, Release, Snapshot, OriginVisit, OriginVisitStatus, Origin, MetadataAuthority, MetadataAuthorityType, MetadataFetcher, MetadataTargetType, RawExtrinsicMetadata, Sha1Git, ) from swh.storage.cassandra import CassandraStorage from swh.storage.cassandra.model import ( BaseRow, ContentRow, DirectoryRow, DirectoryEntryRow, ObjectCountRow, + RevisionRow, + RevisionParentRow, SkippedContentRow, ) from swh.storage.interface import ( ListOrder, PagedResult, PartialBranches, VISIT_STATUSES, ) from swh.storage.objstorage import ObjStorage from swh.storage.utils import now from .converters import origin_url_to_sha1 from .exc import StorageArgumentException from .writer import JournalWriter # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 SortedListItem = TypeVar("SortedListItem") SortedListKey = TypeVar("SortedListKey") FetcherKey = Tuple[str, str] class SortedList(collections.UserList, Generic[SortedListKey, SortedListItem]): data: List[Tuple[SortedListKey, SortedListItem]] # https://github.com/python/mypy/issues/708 # key: Callable[[SortedListItem], SortedListKey] def __init__( self, data: List[SortedListItem] = None, key: Optional[Callable[[SortedListItem], SortedListKey]] = None, ): if key is None: def key(item): return item assert key is not None # for mypy super().__init__(sorted((key(x), x) for x in data or [])) self.key: Callable[[SortedListItem], SortedListKey] = key def add(self, item: SortedListItem): k = self.key(item) bisect.insort(self.data, (k, item)) def __iter__(self) -> Iterator[SortedListItem]: for (k, item) in self.data: yield item def iter_from(self, start_key: Any) -> Iterator[SortedListItem]: """Returns an iterator over all the elements whose key is greater or equal to `start_key`. (This is an efficient equivalent to: `(x for x in L if key(x) >= start_key)`) """ from_index = bisect.bisect_left(self.data, (start_key,)) for (k, item) in itertools.islice(self.data, from_index, None): yield item def iter_after(self, start_key: Any) -> Iterator[SortedListItem]: """Same as iter_from, but using a strict inequality.""" it = self.iter_from(start_key) for item in it: if self.key(item) > start_key: # type: ignore yield item break yield from it TRow = TypeVar("TRow", bound=BaseRow) class Table(Generic[TRow]): def __init__(self, row_class: Type[TRow]): self.row_class = row_class self.primary_key_cols = row_class.PARTITION_KEY + row_class.CLUSTERING_KEY # Map from tokens to clustering keys to rows # These are not actually partitions (or rather, there is one partition # for each token) and they aren't sorted. # But it is good enough if we don't care about performance; # and makes the code a lot simpler. self.data: Dict[int, Dict[Tuple, TRow]] = defaultdict(dict) def __repr__(self): return f"<__module__.Table[{self.row_class.__name__}] object>" def partition_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple: """Returns the partition key of a row (ie. the cells which get hashed into the token.""" if isinstance(row, dict): row_d = row else: row_d = row.to_dict() return tuple(row_d[col] for col in self.row_class.PARTITION_KEY) def clustering_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple: """Returns the clustering key of a row (ie. the cells which are used for sorting rows within a partition.""" if isinstance(row, dict): row_d = row else: row_d = row.to_dict() return tuple(row_d[col] for col in self.row_class.CLUSTERING_KEY) def primary_key(self, row): return self.partition_key(row) + self.clustering_key(row) def primary_key_from_dict(self, d: Dict[str, Any]) -> Tuple: """Returns the primary key (ie. concatenation of partition key and clustering key) of the given dictionary interpreted as a row.""" return tuple(d[col] for col in self.primary_key_cols) def token(self, key: Tuple): """Returns the token of a row (ie. the hash of its partition key).""" return hash(key) def get_partition(self, token: int) -> Dict[Tuple, TRow]: """Returns the partition that contains this token.""" return self.data[token] def insert(self, row: TRow): partition = self.data[self.token(self.partition_key(row))] partition[self.clustering_key(row)] = row def split_primary_key(self, key: Tuple) -> Tuple[Tuple, Tuple]: """Returns (partition_key, clustering_key) from a partition key""" assert len(key) == len(self.primary_key_cols) partition_key = key[0 : len(self.row_class.PARTITION_KEY)] clustering_key = key[len(self.row_class.PARTITION_KEY) :] return (partition_key, clustering_key) def get_from_partition_key(self, partition_key: Tuple) -> Iterable[TRow]: """Returns at most one row, from its partition key.""" token = self.token(partition_key) for row in self.get_from_token(token): if self.partition_key(row) == partition_key: yield row def get_from_primary_key(self, primary_key: Tuple) -> Optional[TRow]: """Returns at most one row, from its primary key.""" (partition_key, clustering_key) = self.split_primary_key(primary_key) token = self.token(partition_key) partition = self.get_partition(token) return partition.get(clustering_key) def get_from_token(self, token: int) -> Iterable[TRow]: """Returns all rows whose token (ie. non-cryptographic hash of the partition key) is the one passed as argument.""" return (v for (k, v) in sorted(self.get_partition(token).items())) def iter_all(self) -> Iterator[Tuple[Tuple, TRow]]: return ( (self.primary_key(row), row) for (token, partition) in self.data.items() for (clustering_key, row) in partition.items() ) def get_random(self) -> Optional[TRow]: return random.choice([row for (pk, row) in self.iter_all()]) class InMemoryCqlRunner: def __init__(self): self._contents = Table(ContentRow) self._content_indexes = defaultdict(lambda: defaultdict(set)) self._skipped_contents = Table(ContentRow) self._skipped_content_indexes = defaultdict(lambda: defaultdict(set)) self._directories = Table(DirectoryRow) self._directory_entries = Table(DirectoryEntryRow) + self._revisions = Table(RevisionRow) + self._revision_parents = Table(RevisionParentRow) self._stat_counters = defaultdict(int) def increment_counter(self, object_type: str, nb: int): self._stat_counters[object_type] += nb def stat_counters(self) -> Iterable[ObjectCountRow]: for (object_type, count) in self._stat_counters.items(): yield ObjectCountRow(partition_key=0, object_type=object_type, count=count) ########################## # 'content' table ########################## def _content_add_finalize(self, content: ContentRow) -> None: self._contents.insert(content) self.increment_counter("content", 1) def content_add_prepare(self, content: ContentRow): finalizer = functools.partial(self._content_add_finalize, content) return (self._contents.token(self._contents.partition_key(content)), finalizer) def content_get_from_pk( self, content_hashes: Dict[str, bytes] ) -> Optional[ContentRow]: primary_key = self._contents.primary_key_from_dict(content_hashes) return self._contents.get_from_primary_key(primary_key) def content_get_from_token(self, token: int) -> Iterable[ContentRow]: return self._contents.get_from_token(token) def content_get_random(self) -> Optional[ContentRow]: return self._contents.get_random() def content_get_token_range( self, start: int, end: int, limit: int, ) -> Iterable[Tuple[int, ContentRow]]: matches = [ (token, row) for (token, partition) in self._contents.data.items() for (clustering_key, row) in partition.items() if start <= token <= end ] matches.sort() return matches[0:limit] ########################## # 'content_by_*' tables ########################## def content_missing_by_sha1_git(self, ids: List[bytes]) -> List[bytes]: missing = [] for id_ in ids: if id_ not in self._content_indexes["sha1_git"]: missing.append(id_) - return missing def content_index_add_one(self, algo: str, content: Content, token: int) -> None: self._content_indexes[algo][content.get_hash(algo)].add(token) def content_get_tokens_from_single_hash( self, algo: str, hash_: bytes ) -> Iterable[int]: return self._content_indexes[algo][hash_] ########################## # 'skipped_content' table ########################## def _skipped_content_add_finalize(self, content: SkippedContentRow) -> None: self._skipped_contents.insert(content) self.increment_counter("skipped_content", 1) def skipped_content_add_prepare(self, content: SkippedContentRow): finalizer = functools.partial(self._skipped_content_add_finalize, content) return ( self._skipped_contents.token(self._contents.partition_key(content)), finalizer, ) def skipped_content_get_from_pk( self, content_hashes: Dict[str, bytes] ) -> Optional[SkippedContentRow]: primary_key = self._skipped_contents.primary_key_from_dict(content_hashes) return self._skipped_contents.get_from_primary_key(primary_key) ########################## # 'skipped_content_by_*' tables ########################## def skipped_content_index_add_one( self, algo: str, content: SkippedContent, token: int ) -> None: self._skipped_content_indexes[algo][content.get_hash(algo)].add(token) ########################## # 'directory' table ########################## def directory_missing(self, ids: List[bytes]) -> List[bytes]: missing = [] for id_ in ids: if self._directories.get_from_primary_key((id_,)) is None: missing.append(id_) - return missing def directory_add_one(self, directory: DirectoryRow) -> None: self._directories.insert(directory) self.increment_counter("directory", 1) def directory_get_random(self) -> Optional[DirectoryRow]: return self._directories.get_random() ########################## # 'directory_entry' table ########################## def directory_entry_add_one(self, entry: DirectoryEntryRow) -> None: self._directory_entries.insert(entry) def directory_entry_get( self, directory_ids: List[Sha1Git] ) -> Iterable[DirectoryEntryRow]: for id_ in directory_ids: yield from self._directory_entries.get_from_partition_key((id_,)) ########################## # 'revision' table ########################## - def revision_missing(self, ids: List[bytes]) -> List[bytes]: - return ids + def revision_missing(self, ids: List[bytes]) -> Iterable[bytes]: + missing = [] + for id_ in ids: + if self._revisions.get_from_primary_key((id_,)) is None: + missing.append(id_) + return missing + + def revision_add_one(self, revision: RevisionRow) -> None: + self._revisions.insert(revision) + self.increment_counter("revision", 1) + + def revision_get_ids(self, revision_ids) -> Iterable[int]: + for id_ in revision_ids: + if self._revisions.get_from_primary_key((id_,)) is not None: + yield id_ + + def revision_get(self, revision_ids: List[Sha1Git]) -> Iterable[RevisionRow]: + for id_ in revision_ids: + row = self._revisions.get_from_primary_key((id_,)) + if row: + yield row + + def revision_get_random(self) -> Optional[RevisionRow]: + return self._revisions.get_random() + + ########################## + # 'revision_parent' table + ########################## + + def revision_parent_add_one(self, revision_parent: RevisionParentRow) -> None: + self._revision_parents.insert(revision_parent) + + def revision_parent_get(self, revision_id: Sha1Git) -> Iterable[bytes]: + for parent in self._revision_parents.get_from_partition_key((revision_id,)): + yield parent.parent_id ########################## # 'release' table ########################## def release_missing(self, ids: List[bytes]) -> List[bytes]: return ids class InMemoryStorage(CassandraStorage): _cql_runner: InMemoryCqlRunner # type: ignore def __init__(self, journal_writer=None): self.reset() self.journal_writer = JournalWriter(journal_writer) def reset(self): self._cql_runner = InMemoryCqlRunner() - self._revisions = {} self._releases = {} self._snapshots = {} self._origins = {} self._origins_by_sha1 = {} self._origin_visits = {} self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {} self._persons = {} # {object_type: {id: {authority: [metadata]}}} self._raw_extrinsic_metadata: Dict[ MetadataTargetType, Dict[ Union[str, SWHID], Dict[ Hashable, SortedList[ Tuple[datetime.datetime, FetcherKey], RawExtrinsicMetadata ], ], ], ] = defaultdict( lambda: defaultdict( lambda: defaultdict( lambda: SortedList( key=lambda x: ( x.discovery_date, self._metadata_fetcher_key(x.fetcher), ) ) ) ) ) # noqa self._metadata_fetchers: Dict[FetcherKey, MetadataFetcher] = {} self._metadata_authorities: Dict[Hashable, MetadataAuthority] = {} self._objects = defaultdict(list) self._sorted_sha1s = SortedList[bytes, bytes]() self.objstorage = ObjStorage({"cls": "memory", "args": {}}) def check_config(self, *, check_write: bool) -> bool: return True - def revision_add(self, revisions: List[Revision]) -> Dict: - revisions = [rev for rev in revisions if rev.id not in self._revisions] - self.journal_writer.revision_add(revisions) - - count = 0 - for revision in revisions: - revision = attr.evolve( - revision, - committer=self._person_add(revision.committer), - author=self._person_add(revision.author), - ) - self._revisions[revision.id] = revision - self._objects[revision.id].append(("revision", revision.id)) - count += 1 - - self._cql_runner.increment_counter("revision", len(revisions)) - - return {"revision:add": count} - - def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: - for id in revisions: - if id not in self._revisions: - yield id - - def revision_get( - self, revisions: List[Sha1Git] - ) -> Iterable[Optional[Dict[str, Any]]]: - for id in revisions: - if id in self._revisions: - yield self._revisions.get(id).to_dict() - else: - yield None - - def __get_parent_revs( - self, rev_id: Sha1Git, seen: Set[Sha1Git], limit: Optional[int] - ) -> Iterable[Dict[str, Any]]: - if limit and len(seen) >= limit: - return - if rev_id in seen or rev_id not in self._revisions: - return - seen.add(rev_id) - yield self._revisions[rev_id].to_dict() - for parent in self._revisions[rev_id].parents: - yield from self.__get_parent_revs(parent, seen, limit) - - def revision_log( - self, revisions: List[Sha1Git], limit: Optional[int] = None - ) -> Iterable[Optional[Dict[str, Any]]]: - seen: Set[Sha1Git] = set() - for rev_id in revisions: - yield from self.__get_parent_revs(rev_id, seen, limit) - - def revision_shortlog( - self, revisions: List[Sha1Git], limit: Optional[int] = None - ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: - yield from ( - (rev["id"], rev["parents"]) if rev else None - for rev in self.revision_log(revisions, limit) - ) - - def revision_get_random(self) -> Sha1Git: - return random.choice(list(self._revisions)) - def release_add(self, releases: List[Release]) -> Dict: to_add = [] for rel in releases: if rel.id not in self._releases and rel not in to_add: to_add.append(rel) self.journal_writer.release_add(to_add) for rel in to_add: if rel.author: self._person_add(rel.author) self._objects[rel.id].append(("release", rel.id)) self._releases[rel.id] = rel self._cql_runner.increment_counter("release", len(to_add)) return {"release:add": len(to_add)} def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]: yield from (rel for rel in releases if rel not in self._releases) def release_get( self, releases: List[Sha1Git] ) -> Iterable[Optional[Dict[str, Any]]]: for rel_id in releases: if rel_id in self._releases: yield self._releases[rel_id].to_dict() else: yield None def release_get_random(self) -> Sha1Git: return random.choice(list(self._releases)) def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: count = 0 snapshots = [snap for snap in snapshots if snap.id not in self._snapshots] for snapshot in snapshots: self.journal_writer.snapshot_add([snapshot]) self._snapshots[snapshot.id] = snapshot self._objects[snapshot.id].append(("snapshot", snapshot.id)) count += 1 self._cql_runner.increment_counter("snapshot", len(snapshots)) return {"snapshot:add": count} def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: for id in snapshots: if id not in self._snapshots: yield id def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: d = self.snapshot_get_branches(snapshot_id) if d is None: return None return { "id": d["id"], "branches": { name: branch.to_dict() if branch else None for (name, branch) in d["branches"].items() }, "next_branch": d["next_branch"], } def snapshot_get_by_origin_visit( self, origin: str, visit: int ) -> Optional[Dict[str, Any]]: origin_url = self._get_origin_url(origin) if not origin_url: return None if origin_url not in self._origins or visit > len( self._origin_visits[origin_url] ): return None visit_d = self._origin_visit_get_updated(origin_url, visit) snapshot_id = visit_d["snapshot"] if snapshot_id: return self.snapshot_get(snapshot_id) else: return None def snapshot_count_branches( self, snapshot_id: Sha1Git ) -> Optional[Dict[Optional[str], int]]: snapshot = self._snapshots[snapshot_id] return collections.Counter( branch.target_type.value if branch else None for branch in snapshot.branches.values() ) def snapshot_get_branches( self, snapshot_id: Sha1Git, branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, ) -> Optional[PartialBranches]: snapshot = self._snapshots.get(snapshot_id) if snapshot is None: return None sorted_branches = sorted(snapshot.branches.items()) sorted_branch_names = [k for (k, v) in sorted_branches] from_index = bisect.bisect_left(sorted_branch_names, branches_from) if target_types: next_branch = None branches: Dict = {} for (branch_name, branch) in sorted_branches: if branch_name in sorted_branch_names[from_index:]: if branch and branch.target_type.value in target_types: if len(branches) < branches_count: branches[branch_name] = branch else: next_branch = branch_name break else: # As there is no 'target_types', we can do that much faster to_index = from_index + branches_count returned_branch_names = frozenset(sorted_branch_names[from_index:to_index]) branches = dict( (branch_name, branch) for (branch_name, branch) in snapshot.branches.items() if branch_name in returned_branch_names ) if to_index >= len(sorted_branch_names): next_branch = None else: next_branch = sorted_branch_names[to_index] return PartialBranches( id=snapshot_id, branches=branches, next_branch=next_branch, ) def snapshot_get_random(self) -> Sha1Git: return random.choice(list(self._snapshots)) def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]: ret = super().object_find_by_sha1_git(ids) for id_ in ids: objs = self._objects.get(id_, []) ret[id_].extend([{"sha1_git": id_, "type": obj[0],} for obj in objs]) return ret def _convert_origin(self, t): if t is None: return None return t.to_dict() def origin_get_one(self, origin_url: str) -> Optional[Origin]: return self._origins.get(origin_url) def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: return [self.origin_get_one(origin_url) for origin_url in origins] def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]: return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s] def origin_list( self, page_token: Optional[str] = None, limit: int = 100 ) -> PagedResult[Origin]: origin_urls = sorted(self._origins) from_ = bisect.bisect_left(origin_urls, page_token) if page_token else 0 next_page_token = None # Take one more origin so we can reuse it as the next page token if any origins = [Origin(url=url) for url in origin_urls[from_ : from_ + limit + 1]] if len(origins) > limit: # last origin id is the next page token next_page_token = str(origins[-1].url) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) def origin_search( self, url_pattern: str, page_token: Optional[str] = None, limit: int = 50, regexp: bool = False, with_visit: bool = False, ) -> PagedResult[Origin]: next_page_token = None offset = int(page_token) if page_token else 0 origins = self._origins.values() if regexp: pat = re.compile(url_pattern) origins = [orig for orig in origins if pat.search(orig.url)] else: origins = [orig for orig in origins if url_pattern in orig.url] if with_visit: filtered_origins = [] for orig in origins: visits = ( self._origin_visit_get_updated(ov.origin, ov.visit) for ov in self._origin_visits[orig.url] ) for ov in visits: snapshot = ov["snapshot"] if snapshot and snapshot in self._snapshots: filtered_origins.append(orig) break else: filtered_origins = origins # Take one more origin so we can reuse it as the next page token if any origins = filtered_origins[offset : offset + limit + 1] if len(origins) > limit: # next offset next_page_token = str(offset + limit) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) def origin_count( self, url_pattern: str, regexp: bool = False, with_visit: bool = False ) -> int: actual_page = self.origin_search( url_pattern, regexp=regexp, with_visit=with_visit, limit=len(self._origins), ) assert actual_page.next_page_token is None return len(actual_page.results) def origin_add(self, origins: List[Origin]) -> Dict[str, int]: added = 0 for origin in origins: if origin.url not in self._origins: self.origin_add_one(origin) added += 1 self._cql_runner.increment_counter("origin", added) return {"origin:add": added} def origin_add_one(self, origin: Origin) -> str: if origin.url not in self._origins: self.journal_writer.origin_add([origin]) self._origins[origin.url] = origin self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin self._origin_visits[origin.url] = [] self._objects[origin.url].append(("origin", origin.url)) return origin.url def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: for visit in visits: origin = self.origin_get_one(visit.origin) if not origin: # Cannot add a visit without an origin raise StorageArgumentException("Unknown origin %s", visit.origin) all_visits = [] for visit in visits: origin_url = visit.origin if origin_url in self._origins: origin = self._origins[origin_url] if visit.visit: self.journal_writer.origin_visit_add([visit]) while len(self._origin_visits[origin_url]) < visit.visit: self._origin_visits[origin_url].append(None) self._origin_visits[origin_url][visit.visit - 1] = visit else: # visit ids are in the range [1, +inf[ visit_id = len(self._origin_visits[origin_url]) + 1 visit = attr.evolve(visit, visit=visit_id) self.journal_writer.origin_visit_add([visit]) self._origin_visits[origin_url].append(visit) visit_key = (origin_url, visit.visit) self._objects[visit_key].append(("origin_visit", None)) assert visit.visit is not None self._origin_visit_status_add_one( OriginVisitStatus( origin=visit.origin, visit=visit.visit, date=visit.date, status="created", snapshot=None, ) ) all_visits.append(visit) self._cql_runner.increment_counter("origin_visit", len(all_visits)) return all_visits def _origin_visit_status_add_one(self, visit_status: OriginVisitStatus) -> None: """Add an origin visit status without checks. If already present, do nothing. """ self.journal_writer.origin_visit_status_add([visit_status]) visit_key = (visit_status.origin, visit_status.visit) self._origin_visit_statuses.setdefault(visit_key, []) visit_statuses = self._origin_visit_statuses[visit_key] if visit_status not in visit_statuses: visit_statuses.append(visit_status) def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus],) -> None: # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: origin_url = self.origin_get_one(visit_status.origin) if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") for visit_status in visit_statuses: self._origin_visit_status_add_one(visit_status) def _origin_visit_status_get_latest( self, origin: str, visit_id: int ) -> Tuple[OriginVisit, OriginVisitStatus]: """Return a tuple of OriginVisit, latest associated OriginVisitStatus. """ assert visit_id >= 1 visit = self._origin_visits[origin][visit_id - 1] assert visit is not None visit_key = (origin, visit_id) visit_update = max(self._origin_visit_statuses[visit_key], key=lambda v: v.date) return visit, visit_update def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]: """Merge origin visit and latest origin visit status """ visit, visit_update = self._origin_visit_status_get_latest(origin, visit_id) assert visit is not None and visit_update is not None return { # default to the values in visit **visit.to_dict(), # override with the last update **visit_update.to_dict(), # but keep the date of the creation of the origin visit "date": visit.date, } def origin_visit_get( self, origin: str, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisit]: next_page_token = None page_token = page_token or "0" if not isinstance(order, ListOrder): raise StorageArgumentException("order must be a ListOrder value") if not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") visit_from = int(page_token) origin_url = self._get_origin_url(origin) extra_limit = limit + 1 visits = sorted( self._origin_visits.get(origin_url, []), key=lambda v: v.visit, reverse=(order == ListOrder.DESC), ) if visit_from > 0 and order == ListOrder.ASC: visits = [v for v in visits if v.visit > visit_from] elif visit_from > 0 and order == ListOrder.DESC: visits = [v for v in visits if v.visit < visit_from] visits = visits[:extra_limit] assert len(visits) <= extra_limit if len(visits) == extra_limit: visits = visits[:limit] next_page_token = str(visits[-1].visit) return PagedResult(results=visits, next_page_token=next_page_token) def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime ) -> Optional[OriginVisit]: origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits: visits = self._origin_visits[origin_url] visit = min(visits, key=lambda v: (abs(v.date - visit_date), -v.visit)) return visit return None def origin_visit_get_by(self, origin: str, visit: int) -> Optional[OriginVisit]: origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits and visit <= len( self._origin_visits[origin_url] ): found_visit, _ = self._origin_visit_status_get_latest(origin, visit) return found_visit return None def origin_visit_get_latest( self, origin: str, type: Optional[str] = None, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[OriginVisit]: if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): raise StorageArgumentException( f"Unknown allowed statuses {','.join(allowed_statuses)}, only " f"{','.join(VISIT_STATUSES)} authorized" ) ori = self._origins.get(origin) if not ori: return None visits = sorted( self._origin_visits[ori.url], key=lambda v: (v.date, v.visit), reverse=True, ) for visit in visits: if type is not None and visit.type != type: continue visit_statuses = self._origin_visit_statuses[origin, visit.visit] if allowed_statuses is not None: visit_statuses = [ vs for vs in visit_statuses if vs.status in allowed_statuses ] if require_snapshot: visit_statuses = [vs for vs in visit_statuses if vs.snapshot] if visit_statuses: # we found visit statuses matching criteria visit_status = max(visit_statuses, key=lambda vs: (vs.date, vs.visit)) assert visit.origin == visit_status.origin assert visit.visit == visit_status.visit return visit return None def origin_visit_status_get( self, origin: str, visit: int, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisitStatus]: next_page_token = None date_from = None if page_token is not None: date_from = datetime.datetime.fromisoformat(page_token) visit_statuses = sorted( self._origin_visit_statuses.get((origin, visit), []), key=lambda v: v.date, reverse=(order == ListOrder.DESC), ) if date_from is not None: if order == ListOrder.ASC: visit_statuses = [v for v in visit_statuses if v.date >= date_from] elif order == ListOrder.DESC: visit_statuses = [v for v in visit_statuses if v.date <= date_from] # Take one more visit status so we can reuse it as the next page token if any visit_statuses = visit_statuses[: limit + 1] if len(visit_statuses) > limit: # last visit status date is the next page token next_page_token = str(visit_statuses[-1].date) # excluding that visit status from the result to respect the limit size visit_statuses = visit_statuses[:limit] return PagedResult(results=visit_statuses, next_page_token=next_page_token) def origin_visit_status_get_latest( self, origin_url: str, visit: int, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[OriginVisitStatus]: if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): raise StorageArgumentException( f"Unknown allowed statuses {','.join(allowed_statuses)}, only " f"{','.join(VISIT_STATUSES)} authorized" ) ori = self._origins.get(origin_url) if not ori: return None visit_key = (origin_url, visit) visits = self._origin_visit_statuses.get(visit_key) if not visits: return None if allowed_statuses is not None: visits = [visit for visit in visits if visit.status in allowed_statuses] if require_snapshot: visits = [visit for visit in visits if visit.snapshot] visit_status = max(visits, key=lambda v: (v.date, v.visit), default=None) return visit_status def _select_random_origin_visit_by_type(self, type: str) -> str: while True: url = random.choice(list(self._origin_visits.keys())) random_origin_visits = self._origin_visits[url] if random_origin_visits[0].type == type: return url def origin_visit_status_get_random( self, type: str ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]: url = self._select_random_origin_visit_by_type(type) random_origin_visits = copy.deepcopy(self._origin_visits[url]) random_origin_visits.reverse() back_in_the_day = now() - timedelta(weeks=12) # 3 months back # This should be enough for tests for visit in random_origin_visits: origin_visit, latest_visit_status = self._origin_visit_status_get_latest( url, visit.visit ) assert latest_visit_status is not None if ( origin_visit.date > back_in_the_day and latest_visit_status.status == "full" ): return origin_visit, latest_visit_status else: return None def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata],) -> None: self.journal_writer.raw_extrinsic_metadata_add(metadata) for metadata_entry in metadata: authority_key = self._metadata_authority_key(metadata_entry.authority) if authority_key not in self._metadata_authorities: raise StorageArgumentException( f"Unknown authority {metadata_entry.authority}" ) fetcher_key = self._metadata_fetcher_key(metadata_entry.fetcher) if fetcher_key not in self._metadata_fetchers: raise StorageArgumentException( f"Unknown fetcher {metadata_entry.fetcher}" ) raw_extrinsic_metadata_list = self._raw_extrinsic_metadata[ metadata_entry.type ][metadata_entry.id][authority_key] for existing_raw_extrinsic_metadata in raw_extrinsic_metadata_list: if ( self._metadata_fetcher_key(existing_raw_extrinsic_metadata.fetcher) == fetcher_key and existing_raw_extrinsic_metadata.discovery_date == metadata_entry.discovery_date ): # Duplicate of an existing one; ignore it. break else: raw_extrinsic_metadata_list.add(metadata_entry) def raw_extrinsic_metadata_get( self, type: MetadataTargetType, id: Union[str, SWHID], authority: MetadataAuthority, after: Optional[datetime.datetime] = None, page_token: Optional[bytes] = None, limit: int = 1000, ) -> PagedResult[RawExtrinsicMetadata]: authority_key = self._metadata_authority_key(authority) if type == MetadataTargetType.ORIGIN: if isinstance(id, SWHID): raise StorageArgumentException( f"raw_extrinsic_metadata_get called with type='origin', " f"but provided id is an SWHID: {id!r}" ) else: if not isinstance(id, SWHID): raise StorageArgumentException( f"raw_extrinsic_metadata_get called with type!='origin', " f"but provided id is not an SWHID: {id!r}" ) if page_token is not None: (after_time, after_fetcher) = msgpack_loads(base64.b64decode(page_token)) after_fetcher = tuple(after_fetcher) if after is not None and after > after_time: raise StorageArgumentException( "page_token is inconsistent with the value of 'after'." ) entries = self._raw_extrinsic_metadata[type][id][authority_key].iter_after( (after_time, after_fetcher) ) elif after is not None: entries = self._raw_extrinsic_metadata[type][id][authority_key].iter_from( (after,) ) entries = (entry for entry in entries if entry.discovery_date > after) else: entries = iter(self._raw_extrinsic_metadata[type][id][authority_key]) if limit: entries = itertools.islice(entries, 0, limit + 1) results = [] for entry in entries: entry_authority = self._metadata_authorities[ self._metadata_authority_key(entry.authority) ] entry_fetcher = self._metadata_fetchers[ self._metadata_fetcher_key(entry.fetcher) ] if after: assert entry.discovery_date > after results.append( attr.evolve( entry, authority=attr.evolve(entry_authority, metadata=None), fetcher=attr.evolve(entry_fetcher, metadata=None), ) ) if len(results) > limit: results.pop() assert len(results) == limit last_result = results[-1] next_page_token: Optional[str] = base64.b64encode( msgpack_dumps( ( last_result.discovery_date, self._metadata_fetcher_key(last_result.fetcher), ) ) ).decode() else: next_page_token = None return PagedResult(next_page_token=next_page_token, results=results,) def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> None: self.journal_writer.metadata_fetcher_add(fetchers) for fetcher in fetchers: if fetcher.metadata is None: raise StorageArgumentException( "MetadataFetcher.metadata may not be None in metadata_fetcher_add." ) key = self._metadata_fetcher_key(fetcher) if key not in self._metadata_fetchers: self._metadata_fetchers[key] = fetcher def metadata_fetcher_get( self, name: str, version: str ) -> Optional[MetadataFetcher]: return self._metadata_fetchers.get( self._metadata_fetcher_key(MetadataFetcher(name=name, version=version)) ) def metadata_authority_add(self, authorities: List[MetadataAuthority]) -> None: self.journal_writer.metadata_authority_add(authorities) for authority in authorities: if authority.metadata is None: raise StorageArgumentException( "MetadataAuthority.metadata may not be None in " "metadata_authority_add." ) key = self._metadata_authority_key(authority) self._metadata_authorities[key] = authority def metadata_authority_get( self, type: MetadataAuthorityType, url: str ) -> Optional[MetadataAuthority]: return self._metadata_authorities.get( self._metadata_authority_key(MetadataAuthority(type=type, url=url)) ) def _get_origin_url(self, origin): if isinstance(origin, str): return origin else: raise TypeError("origin must be a string.") def _person_add(self, person): key = ("person", person.fullname) if key not in self._objects: self._persons[person.fullname] = person self._objects[key].append(key) return self._persons[person.fullname] @staticmethod def _metadata_fetcher_key(fetcher: MetadataFetcher) -> FetcherKey: return (fetcher.name, fetcher.version) @staticmethod def _metadata_authority_key(authority: MetadataAuthority) -> Hashable: return (authority.type, authority.url) def diff_directories(self, from_dir, to_dir, track_renaming=False): raise NotImplementedError("InMemoryStorage.diff_directories") def diff_revisions(self, from_rev, to_rev, track_renaming=False): raise NotImplementedError("InMemoryStorage.diff_revisions") def diff_revision(self, revision, track_renaming=False): raise NotImplementedError("InMemoryStorage.diff_revision") def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: """Do nothing """ return None def flush(self, object_types: Optional[List[str]] = None) -> Dict: return {} diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py index 40a858a7..7343326b 100644 --- a/swh/storage/tests/test_api_client.py +++ b/swh/storage/tests/test_api_client.py @@ -1,66 +1,73 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest import swh.storage.api.server as server import swh.storage.storage from swh.storage import get_storage from swh.storage.tests.test_storage import TestStorageGeneratedData # noqa from swh.storage.tests.test_storage import TestStorage as _TestStorage # tests are executed using imported classes (TestStorage and # TestStorageGeneratedData) using overloaded swh_storage fixture # below @pytest.fixture def app_server(): server.storage = swh.storage.get_storage( cls="memory", journal_writer={"cls": "memory"} ) yield server @pytest.fixture def app(app_server): return app_server.app @pytest.fixture def swh_rpc_client_class(): def storage_factory(**kwargs): storage_config = { "cls": "remote", **kwargs, } return get_storage(**storage_config) return storage_factory @pytest.fixture def swh_storage(swh_rpc_client, app_server): # This version of the swh_storage fixture uses the swh_rpc_client fixture # to instantiate a RemoteStorage (see swh_rpc_client_class above) that # proxies, via the swh.core RPC mechanism, the local (in memory) storage # configured in the app_server fixture above. # # Also note that, for the sake of # making it easier to write tests, the in-memory journal writer of the # in-memory backend storage is attached to the RemoteStorage as its # journal_writer attribute. storage = swh_rpc_client journal_writer = getattr(storage, "journal_writer", None) storage.journal_writer = app_server.storage.journal_writer yield storage storage.journal_writer = journal_writer class TestStorage(_TestStorage): @pytest.mark.skip("content_update is not yet implemented for Cassandra") def test_content_update(self): pass + + @pytest.mark.skip( + 'The "person" table of the pgsql is a legacy thing, and not ' + "supported by the cassandra backend." + ) + def test_person_fullname_unicity(self): + pass diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py index 09c16950..8ddc0622 100644 --- a/swh/storage/tests/test_in_memory.py +++ b/swh/storage/tests/test_in_memory.py @@ -1,163 +1,170 @@ # Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import dataclasses import pytest from swh.storage.cassandra.model import BaseRow from swh.storage.in_memory import SortedList, Table from swh.storage.tests.test_storage import TestStorage as _TestStorage from swh.storage.tests.test_storage import TestStorageGeneratedData # noqa # tests are executed using imported classes (TestStorage and # TestStorageGeneratedData) using overloaded swh_storage fixture # below @pytest.fixture def swh_storage_backend_config(): yield { "cls": "memory", "journal_writer": {"cls": "memory",}, } parametrize = pytest.mark.parametrize( "items", [ [1, 2, 3, 4, 5, 6, 10, 100], [10, 100, 6, 5, 4, 3, 2, 1], [10, 4, 5, 6, 1, 2, 3, 100], ], ) @parametrize def test_sorted_list_iter(items): list1 = SortedList() for item in items: list1.add(item) assert list(list1) == sorted(items) list2 = SortedList(items) assert list(list2) == sorted(items) @parametrize def test_sorted_list_iter__key(items): list1 = SortedList(key=lambda item: -item) for item in items: list1.add(item) assert list(list1) == list(reversed(sorted(items))) list2 = SortedList(items, key=lambda item: -item) assert list(list2) == list(reversed(sorted(items))) @parametrize def test_sorted_list_iter_from(items): list_ = SortedList(items) for split in items: expected = sorted(item for item in items if item >= split) assert list(list_.iter_from(split)) == expected, f"split: {split}" @parametrize def test_sorted_list_iter_from__key(items): list_ = SortedList(items, key=lambda item: -item) for split in items: expected = reversed(sorted(item for item in items if item <= split)) assert list(list_.iter_from(-split)) == list(expected), f"split: {split}" @parametrize def test_sorted_list_iter_after(items): list_ = SortedList(items) for split in items: expected = sorted(item for item in items if item > split) assert list(list_.iter_after(split)) == expected, f"split: {split}" @parametrize def test_sorted_list_iter_after__key(items): list_ = SortedList(items, key=lambda item: -item) for split in items: expected = reversed(sorted(item for item in items if item < split)) assert list(list_.iter_after(-split)) == list(expected), f"split: {split}" @dataclasses.dataclass class Row(BaseRow): PARTITION_KEY = ("col1", "col2") CLUSTERING_KEY = ("col3", "col4") col1: str col2: str col3: str col4: str col5: str col6: int def test_table_keys(): table = Table(Row) primary_key = ("foo", "bar", "baz", "qux") partition_key = ("foo", "bar") clustering_key = ("baz", "qux") row = Row(col1="foo", col2="bar", col3="baz", col4="qux", col5="quux", col6=4) assert table.partition_key(row) == partition_key assert table.clustering_key(row) == clustering_key assert table.primary_key(row) == primary_key assert table.primary_key_from_dict(row.to_dict()) == primary_key assert table.split_primary_key(primary_key) == (partition_key, clustering_key) def test_table(): table = Table(Row) row1 = Row(col1="foo", col2="bar", col3="baz", col4="qux", col5="quux", col6=4) row2 = Row(col1="foo", col2="bar", col3="baz", col4="qux2", col5="quux", col6=4) row3 = Row(col1="foo", col2="bar", col3="baz", col4="qux1", col5="quux", col6=4) row4 = Row(col1="foo", col2="bar2", col3="baz", col4="qux1", col5="quux", col6=4) partition_key = ("foo", "bar") partition_key4 = ("foo", "bar2") primary_key1 = ("foo", "bar", "baz", "qux") primary_key2 = ("foo", "bar", "baz", "qux2") primary_key3 = ("foo", "bar", "baz", "qux1") primary_key4 = ("foo", "bar2", "baz", "qux1") table.insert(row1) table.insert(row2) table.insert(row3) table.insert(row4) assert table.get_from_primary_key(primary_key1) == row1 assert table.get_from_primary_key(primary_key2) == row2 assert table.get_from_primary_key(primary_key3) == row3 assert table.get_from_primary_key(primary_key4) == row4 # order matters assert list(table.get_from_token(table.token(partition_key))) == [row1, row3, row2] # order matters assert list(table.get_from_partition_key(partition_key)) == [row1, row3, row2] assert list(table.get_from_partition_key(partition_key4)) == [row4] all_rows = list(table.iter_all()) assert len(all_rows) == 4 for row in (row1, row2, row3, row4): assert (table.primary_key(row), row) in all_rows class TestInMemoryStorage(_TestStorage): + @pytest.mark.skip( + 'The "person" table of the pgsql is a legacy thing, and not ' + "supported by the cassandra backend." + ) + def test_person_fullname_unicity(self): + pass + @pytest.mark.skip("content_update is not yet implemented for Cassandra") def test_content_update(self): pass diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py index c8ba472f..afec6ac5 100644 --- a/swh/storage/tests/test_replay.py +++ b/swh/storage/tests/test_replay.py @@ -1,411 +1,411 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import dataclasses import datetime import functools import logging from typing import Any, Container, Dict, Optional import attr import pytest from swh.model.hashutil import hash_to_hex, MultiHash, DEFAULT_ALGORITHMS from swh.storage import get_storage from swh.storage.cassandra.model import ContentRow, SkippedContentRow from swh.storage.in_memory import InMemoryStorage from swh.storage.replay import process_replay_objects from swh.journal.serializers import key_to_kafka, value_to_kafka from swh.journal.client import JournalClient from swh.journal.tests.journal_data import ( TEST_OBJECTS, DUPLICATE_CONTENTS, ) UTC = datetime.timezone.utc def nullify_ctime(obj): if isinstance(obj, (ContentRow, SkippedContentRow)): return dataclasses.replace(obj, ctime=None) else: return obj @pytest.fixture() def replayer_storage_and_client( kafka_prefix: str, kafka_consumer_group: str, kafka_server: str ): journal_writer_config = { "cls": "kafka", "brokers": [kafka_server], "client_id": "kafka_writer", "prefix": kafka_prefix, } storage_config: Dict[str, Any] = { "cls": "memory", "journal_writer": journal_writer_config, } storage = get_storage(**storage_config) replayer = JournalClient( brokers=kafka_server, group_id=kafka_consumer_group, prefix=kafka_prefix, stop_on_eof=True, ) yield storage, replayer def test_storage_replayer(replayer_storage_and_client, caplog): """Optimal replayer scenario. This: - writes objects to a source storage - replayer consumes objects from the topic and replays them - a destination storage is filled from this In the end, both storages should have the same content. """ src, replayer = replayer_storage_and_client # Fill Kafka using a source storage nb_sent = 0 for object_type, objects in TEST_OBJECTS.items(): method = getattr(src, object_type + "_add") method(objects) if object_type == "origin_visit": nb_sent += len(objects) # origin-visit-add adds origin-visit-status as well nb_sent += len(objects) caplog.set_level(logging.ERROR, "swh.journal.replay") # Fill the destination storage from Kafka dst = get_storage(cls="memory") worker_fn = functools.partial(process_replay_objects, storage=dst) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted _check_replayed(src, dst) collision = 0 for record in caplog.records: logtext = record.getMessage() if "Colliding contents:" in logtext: collision += 1 assert collision == 0, "No collision should be detected" def test_storage_play_with_collision(replayer_storage_and_client, caplog): """Another replayer scenario with collisions. This: - writes objects to the topic, including colliding contents - replayer consumes objects from the topic and replay them - This drops the colliding contents from the replay when detected """ src, replayer = replayer_storage_and_client # Fill Kafka using a source storage nb_sent = 0 for object_type, objects in TEST_OBJECTS.items(): method = getattr(src, object_type + "_add") method(objects) if object_type == "origin_visit": nb_sent += len(objects) # origin-visit-add adds origin-visit-status as well nb_sent += len(objects) # Create collision in input data # These should not be written in the destination producer = src.journal_writer.journal.producer prefix = src.journal_writer.journal._prefix for content in DUPLICATE_CONTENTS: topic = f"{prefix}.content" key = content.sha1 now = datetime.datetime.now(tz=UTC) content = attr.evolve(content, ctime=now) producer.produce( topic=topic, key=key_to_kafka(key), value=value_to_kafka(content.to_dict()), ) nb_sent += 1 producer.flush() caplog.set_level(logging.ERROR, "swh.journal.replay") # Fill the destination storage from Kafka dst = get_storage(cls="memory") worker_fn = functools.partial(process_replay_objects, storage=dst) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted # check the logs for the collision being properly detected nb_collisions = 0 actual_collision: Dict for record in caplog.records: logtext = record.getMessage() if "Collision detected:" in logtext: nb_collisions += 1 actual_collision = record.args["collision"] assert nb_collisions == 1, "1 collision should be detected" algo = "sha1" assert actual_collision["algo"] == algo expected_colliding_hash = hash_to_hex(DUPLICATE_CONTENTS[0].get_hash(algo)) assert actual_collision["hash"] == expected_colliding_hash actual_colliding_hashes = actual_collision["objects"] assert len(actual_colliding_hashes) == len(DUPLICATE_CONTENTS) for content in DUPLICATE_CONTENTS: expected_content_hashes = { k: hash_to_hex(v) for k, v in content.hashes().items() } assert expected_content_hashes in actual_colliding_hashes # all objects from the src should exists in the dst storage _check_replayed(src, dst, exclude=["contents"]) # but the dst has one content more (one of the 2 colliding ones) assert ( len(list(src._cql_runner._contents.iter_all())) == len(list(dst._cql_runner._contents.iter_all())) - 1 ) def test_replay_skipped_content(replayer_storage_and_client): """Test the 'skipped_content' topic is properly replayed.""" src, replayer = replayer_storage_and_client _check_replay_skipped_content(src, replayer, "skipped_content") def test_replay_skipped_content_bwcompat(replayer_storage_and_client): """Test the 'content' topic can be used to replay SkippedContent objects.""" src, replayer = replayer_storage_and_client _check_replay_skipped_content(src, replayer, "content") # utility functions def _check_replayed( src: InMemoryStorage, dst: InMemoryStorage, exclude: Optional[Container] = None ): """Simple utility function to compare the content of 2 in_memory storages """ expected_persons = set(src._persons.values()) got_persons = set(dst._persons.values()) assert got_persons == expected_persons for attr_ in ( - "revisions", "releases", "snapshots", "origins", "origin_visits", "origin_visit_statuses", ): if exclude and attr_ in exclude: continue expected_objects = sorted(getattr(src, f"_{attr_}").items()) got_objects = sorted(getattr(dst, f"_{attr_}").items()) assert got_objects == expected_objects, f"Mismatch object list for {attr_}" for attr_ in ( "contents", "skipped_contents", "directories", + "revisions", ): if exclude and attr_ in exclude: continue expected_objects = [ (id, nullify_ctime(obj)) for id, obj in sorted(getattr(src._cql_runner, f"_{attr_}").iter_all()) ] got_objects = [ (id, nullify_ctime(obj)) for id, obj in sorted(getattr(dst._cql_runner, f"_{attr_}").iter_all()) ] assert got_objects == expected_objects, f"Mismatch object list for {attr_}" def _check_replay_skipped_content(storage, replayer, topic): skipped_contents = _gen_skipped_contents(100) nb_sent = len(skipped_contents) producer = storage.journal_writer.journal.producer prefix = storage.journal_writer.journal._prefix for i, obj in enumerate(skipped_contents): producer.produce( topic=f"{prefix}.{topic}", key=key_to_kafka({"sha1": obj["sha1"]}), value=value_to_kafka(obj), ) producer.flush() dst_storage = get_storage(cls="memory") worker_fn = functools.partial(process_replay_objects, storage=dst_storage) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted for content in skipped_contents: assert not storage.content_find({"sha1": content["sha1"]}) # no skipped_content_find API endpoint, so use this instead assert not list(dst_storage.skipped_content_missing(skipped_contents)) def _updated(d1, d2): d1.update(d2) d1.pop("data", None) return d1 def _gen_skipped_contents(n=10): # we do not use the hypothesis strategy here because this does not play well with # pytest fixtures, and it makes test execution very slow algos = DEFAULT_ALGORITHMS | {"length"} now = datetime.datetime.now(tz=UTC) return [ _updated( MultiHash.from_data(data=f"foo{i}".encode(), hash_names=algos).digest(), { "status": "absent", "reason": "why not", "origin": f"https://somewhere/{i}", "ctime": now, }, ) for i in range(n) ] def test_storage_play_anonymized( kafka_prefix: str, kafka_consumer_group: str, kafka_server: str ): """Optimal replayer scenario. This: - writes objects to the topic - replayer consumes objects from the topic and replay them """ writer_config = { "cls": "kafka", "brokers": [kafka_server], "client_id": "kafka_writer", "prefix": kafka_prefix, "anonymize": True, } src_config: Dict[str, Any] = {"cls": "memory", "journal_writer": writer_config} storage = get_storage(**src_config) # Fill the src storage nb_sent = 0 for obj_type, objs in TEST_OBJECTS.items(): if obj_type in ("origin_visit", "origin_visit_status"): # these are unrelated with what we want to test here continue method = getattr(storage, obj_type + "_add") method(objs) nb_sent += len(objs) # Fill a destination storage from Kafka **using anonymized topics** dst_storage = get_storage(cls="memory") replayer = JournalClient( brokers=kafka_server, group_id=kafka_consumer_group, prefix=kafka_prefix, stop_after_objects=nb_sent, privileged=False, ) worker_fn = functools.partial(process_replay_objects, storage=dst_storage) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted check_replayed(storage, dst_storage, expected_anonymized=True) # Fill a destination storage from Kafka **with stock (non-anonymized) topics** dst_storage = get_storage(cls="memory") replayer = JournalClient( brokers=kafka_server, group_id=kafka_consumer_group, prefix=kafka_prefix, stop_after_objects=nb_sent, privileged=True, ) worker_fn = functools.partial(process_replay_objects, storage=dst_storage) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted check_replayed(storage, dst_storage, expected_anonymized=False) def check_replayed(src, dst, expected_anonymized=False): """Simple utility function to compare the content of 2 in_memory storages If expected_anonymized is True, objects from the source storage are anonymized before comparing with the destination storage. """ def maybe_anonymize(attr_, row): if expected_anonymized: if hasattr(row, "anonymize"): # for model objects; cases below are for BaseRow objects row = row.anonymize() or row elif attr_ == "releases": row = dataclasses.replace(row, author=row.author.anonymize()) elif attr_ == "revisions": row = dataclasses.replace( row, author=row.author.anonymize(), committer=row.committer.anonymize(), ) return row expected_persons = { maybe_anonymize("persons", person) for person in src._persons.values() } got_persons = set(dst._persons.values()) assert got_persons == expected_persons for attr_ in ( - "revisions", "releases", "snapshots", "origins", "origin_visit_statuses", ): expected_objects = [ (id, maybe_anonymize(attr_, obj)) for id, obj in sorted(getattr(src, f"_{attr_}").items()) ] got_objects = [ (id, obj) for id, obj in sorted(getattr(dst, f"_{attr_}").items()) ] assert got_objects == expected_objects, f"Mismatch object list for {attr_}" for attr_ in ( "contents", "skipped_contents", "directories", + "revisions", ): expected_objects = [ (id, nullify_ctime(maybe_anonymize(attr_, obj))) for id, obj in sorted(getattr(src._cql_runner, f"_{attr_}").iter_all()) ] got_objects = [ (id, nullify_ctime(obj)) for id, obj in sorted(getattr(dst._cql_runner, f"_{attr_}").iter_all()) ] assert got_objects == expected_objects, f"Mismatch object list for {attr_}"