diff --git a/requirements.txt b/requirements.txt index c712cdd2..619b1325 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,10 @@ click flask psycopg2 vcversioner aiohttp tenacity cassandra-driver >= 3.19.0, != 3.21.0 deprecated typing-extensions +mypy_extensions diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py index 5f216304..e10e19c7 100644 --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -1,1017 +1,1034 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import dataclasses import datetime import functools import logging import random from typing import ( Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, ) from cassandra import CoordinationFailure from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy from cassandra.query import PreparedStatement, BoundStatement, dict_factory from tenacity import ( retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type, ) +from mypy_extensions import NamedArg from swh.model.model import ( Content, SkippedContent, Sha1Git, TimestampWithTimezone, Timestamp, Person, ) from swh.storage.interface import ListOrder from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys from .model import ( BaseRow, ContentRow, DirectoryEntryRow, DirectoryRow, MetadataAuthorityRow, MetadataFetcherRow, ObjectCountRow, OriginRow, OriginVisitRow, OriginVisitStatusRow, RawExtrinsicMetadataRow, ReleaseRow, RevisionParentRow, RevisionRow, SkippedContentRow, SnapshotBranchRow, SnapshotRow, ) from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS logger = logging.getLogger(__name__) _execution_profiles = { EXEC_PROFILE_DEFAULT: ExecutionProfile( load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()), row_factory=dict_factory, ), } # Configuration for cassandra-driver's access to servers: # * hit the right server directly when sending a query (TokenAwarePolicy), # * if there's more than one, then pick one at random that's in the same # datacenter as the client (DCAwareRoundRobinPolicy) def create_keyspace( hosts: List[str], keyspace: str, port: int = 9042, *, durable_writes=True ): cluster = Cluster(hosts, port=port, execution_profiles=_execution_profiles) session = cluster.connect() extra_params = "" if not durable_writes: extra_params = "AND durable_writes = false" session.execute( """CREATE KEYSPACE IF NOT EXISTS "%s" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 } %s; """ % (keyspace, extra_params) ) session.execute('USE "%s"' % keyspace) for query in CREATE_TABLES_QUERIES: session.execute(query) -T = TypeVar("T") +TRet = TypeVar("TRet") -def _prepared_statement(query: str) -> Callable[[Callable[..., T]], Callable[..., T]]: +def _prepared_statement( + query: str, +) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]: """Returns a decorator usable on methods of CqlRunner, to inject them with a 'statement' argument, that is a prepared statement corresponding to the query. 
This only works on methods of CqlRunner, as preparing a statement requires a connection to a Cassandra server.""" def decorator(f): @functools.wraps(f) - def newf(self, *args, **kwargs) -> T: + def newf(self, *args, **kwargs) -> TRet: if f.__name__ not in self._prepared_statements: statement: PreparedStatement = self._session.prepare(query) self._prepared_statements[f.__name__] = statement return f( self, *args, **kwargs, statement=self._prepared_statements[f.__name__] ) return newf return decorator -def _prepared_insert_statement(table_name: str, columns: List[str]): +TArg = TypeVar("TArg") +TSelf = TypeVar("TSelf") + + +def _prepared_insert_statement( + table_name: str, columns: List[str] +) -> Callable[ + [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa + Callable[[TSelf, TArg], TRet], +]: """Shorthand for using `_prepared_statement` for `INSERT INTO` statements.""" return _prepared_statement( "INSERT INTO %s (%s) VALUES (%s)" % (table_name, ", ".join(columns), ", ".join("?" for _ in columns),) ) -def _prepared_exists_statement(table_name: str): +def _prepared_exists_statement( + table_name: str, +) -> Callable[ + [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa + Callable[[TSelf, TArg], TRet], +]: """Shorthand for using `_prepared_statement` for queries that only check which ids in a list exist in the table.""" return _prepared_statement(f"SELECT id FROM {table_name} WHERE id IN ?") class CqlRunner: """Class managing prepared statements and building queries to be sent to Cassandra.""" def __init__(self, hosts: List[str], keyspace: str, port: int): self._cluster = Cluster( hosts, port=port, execution_profiles=_execution_profiles ) self._session = self._cluster.connect(keyspace) self._cluster.register_user_type( keyspace, "microtimestamp_with_timezone", TimestampWithTimezone ) self._cluster.register_user_type(keyspace, "microtimestamp", Timestamp) self._cluster.register_user_type(keyspace, "person", Person) self._prepared_statements: Dict[str, PreparedStatement] = {} ########################## # Common utility functions ########################## MAX_RETRIES = 3 @retry( wait=wait_random_exponential(multiplier=1, max=10), stop=stop_after_attempt(MAX_RETRIES), retry=retry_if_exception_type(CoordinationFailure), ) def _execute_with_retries(self, statement, args) -> ResultSet: return self._session.execute(statement, args, timeout=1000.0) @_prepared_statement( "UPDATE object_count SET count = count + ? " "WHERE partition_key = 0 AND object_type = ?" ) def _increment_counter( self, object_type: str, nb: int, *, statement: PreparedStatement ) -> None: self._execute_with_retries(statement, [nb, object_type]) def _add_one(self, statement, object_type: Optional[str], obj: BaseRow) -> None: if object_type: self._increment_counter(object_type, 1) self._execute_with_retries(statement, dataclasses.astuple(obj)) _T = TypeVar("_T", bound=BaseRow) def _get_random_row(self, row_class: Type[_T], statement) -> Optional[_T]: # noqa """Takes a prepared statement of the form "SELECT * FROM <table> WHERE token(<keys>) > ?
LIMIT 1" and uses it to return a random row""" token = random.randint(TOKEN_BEGIN, TOKEN_END) rows = self._execute_with_retries(statement, [token]) if not rows: # There are no row with a greater token; wrap around to get # the row with the smallest token rows = self._execute_with_retries(statement, [TOKEN_BEGIN]) if rows: return row_class.from_dict(rows.one()) # type: ignore else: return None def _missing(self, statement, ids): rows = self._execute_with_retries(statement, [ids]) found_ids = {row["id"] for row in rows} return [id_ for id_ in ids if id_ not in found_ids] ########################## # 'content' table ########################## _content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"] def _content_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by content_add_prepare, to be called when the content row should be added to the primary table.""" self._execute_with_retries(statement, None) self._increment_counter("content", 1) @_prepared_insert_statement("content", ContentRow.cols()) def content_add_prepare( self, content: ContentRow, *, statement ) -> Tuple[int, Callable[[], None]]: """Prepares insertion of a Content to the main 'content' table. Returns a token (to be used in secondary tables), and a function to be called to perform the insertion in the main table.""" statement = statement.bind(dataclasses.astuple(content)) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) token_class = self._cluster.metadata.token_map.token_class # Token of the row when it will be inserted. This is equivalent to # "SELECT token({', '.join(self._content_pk)}) FROM content WHERE ..." # after the row is inserted; but we need the token to insert in the # index tables *before* inserting to the main 'content' table token = token_class.from_key(statement.routing_key).value assert TOKEN_BEGIN <= token <= TOKEN_END # Function to be called after the indexes contain their respective # row finalizer = functools.partial(self._content_add_finalize, statement) return (token, finalizer) @_prepared_statement( "SELECT * FROM content WHERE " + " AND ".join(map("%s = ?".__mod__, HASH_ALGORITHMS)) ) def content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement ) -> Optional[ContentRow]: rows = list( self._execute_with_retries( statement, [content_hashes[algo] for algo in HASH_ALGORITHMS] ) ) assert len(rows) <= 1 if rows: return ContentRow(**rows[0]) else: return None @_prepared_statement( "SELECT * FROM content WHERE token(" + ", ".join(_content_pk) + ") = ?" ) def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]: return map(ContentRow.from_dict, self._execute_with_retries(statement, [token])) @_prepared_statement( "SELECT * FROM content WHERE token(%s) > ? LIMIT 1" % ", ".join(_content_pk) ) def content_get_random(self, *, statement) -> Optional[ContentRow]: return self._get_random_row(ContentRow, statement) @_prepared_statement( ( "SELECT token({0}) AS tok, {1} FROM content " "WHERE token({0}) >= ? AND token({0}) <= ? LIMIT ?" 
).format(", ".join(_content_pk), ", ".join(ContentRow.cols())) ) def content_get_token_range( self, start: int, end: int, limit: int, *, statement ) -> Iterable[Tuple[int, ContentRow]]: """Returns an iterable of (token, row)""" return ( (row["tok"], ContentRow.from_dict(remove_keys(row, ("tok",)))) for row in self._execute_with_retries(statement, [start, end, limit]) ) ########################## # 'content_by_*' tables ########################## @_prepared_statement( "SELECT sha1_git AS id FROM content_by_sha1_git WHERE sha1_git IN ?" ) def content_missing_by_sha1_git( self, ids: List[bytes], *, statement ) -> List[bytes]: return self._missing(statement, ids) def content_index_add_one(self, algo: str, content: Content, token: int) -> None: """Adds a row mapping content[algo] to the token of the Content in the main 'content' table.""" query = ( f"INSERT INTO content_by_{algo} ({algo}, target_token) " f"VALUES (%s, %s)" ) self._execute_with_retries(query, [content.get_hash(algo), token]) def content_get_tokens_from_single_hash( self, algo: str, hash_: bytes ) -> Iterable[int]: assert algo in HASH_ALGORITHMS query = f"SELECT target_token FROM content_by_{algo} WHERE {algo} = %s" return ( row["target_token"] for row in self._execute_with_retries(query, [hash_]) ) ########################## # 'skipped_content' table ########################## _skipped_content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"] _magic_null_pk = b"" """ NULLs (or all-empty blobs) are not allowed in primary keys; instead use a special value that can't possibly be a valid hash. """ def _skipped_content_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by skipped_content_add_prepare, to be called when the content row should be added to the primary table.""" self._execute_with_retries(statement, None) self._increment_counter("skipped_content", 1) @_prepared_insert_statement("skipped_content", SkippedContentRow.cols()) def skipped_content_add_prepare( self, content, *, statement ) -> Tuple[int, Callable[[], None]]: """Prepares insertion of a Content to the main 'skipped_content' table. Returns a token (to be used in secondary tables), and a function to be called to perform the insertion in the main table.""" # Replace NULLs (which are not allowed in the partition key) with # an empty byte string for key in self._skipped_content_pk: if getattr(content, key) is None: setattr(content, key, self._magic_null_pk) statement = statement.bind(dataclasses.astuple(content)) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) token_class = self._cluster.metadata.token_map.token_class # Token of the row when it will be inserted. This is equivalent to # "SELECT token({', '.join(self._content_pk)}) # FROM skipped_content WHERE ..." 
# after the row is inserted; but we need the token to insert in the # index tables *before* inserting to the main 'skipped_content' table token = token_class.from_key(statement.routing_key).value assert TOKEN_BEGIN <= token <= TOKEN_END # Function to be called after the indexes contain their respective # row finalizer = functools.partial(self._skipped_content_add_finalize, statement) return (token, finalizer) @_prepared_statement( "SELECT * FROM skipped_content WHERE " + " AND ".join(map("%s = ?".__mod__, HASH_ALGORITHMS)) ) def skipped_content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement ) -> Optional[SkippedContentRow]: rows = list( self._execute_with_retries( statement, [ content_hashes[algo] or self._magic_null_pk for algo in HASH_ALGORITHMS ], ) ) assert len(rows) <= 1 if rows: # TODO: convert _magic_null_pk back to None? return SkippedContentRow.from_dict(rows[0]) else: return None ########################## # 'skipped_content_by_*' tables ########################## def skipped_content_index_add_one( self, algo: str, content: SkippedContent, token: int ) -> None: """Adds a row mapping content[algo] to the token of the SkippedContent in the main 'skipped_content' table.""" query = ( f"INSERT INTO skipped_content_by_{algo} ({algo}, target_token) " f"VALUES (%s, %s)" ) self._execute_with_retries( query, [content.get_hash(algo) or self._magic_null_pk, token] ) ########################## # 'revision' table ########################## @_prepared_exists_statement("revision") def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement("revision", RevisionRow.cols()) def revision_add_one(self, revision: RevisionRow, *, statement) -> None: self._add_one(statement, "revision", revision) @_prepared_statement("SELECT id FROM revision WHERE id IN ?") def revision_get_ids(self, revision_ids, *, statement) -> Iterable[Sha1Git]: return ( row["id"] for row in self._execute_with_retries(statement, [revision_ids]) ) @_prepared_statement("SELECT * FROM revision WHERE id IN ?") def revision_get(self, revision_ids, *, statement) -> Iterable[RevisionRow]: return map( RevisionRow.from_dict, self._execute_with_retries(statement, [revision_ids]) ) @_prepared_statement("SELECT * FROM revision WHERE token(id) > ?
LIMIT 1") def revision_get_random(self, *, statement) -> Optional[RevisionRow]: return self._get_random_row(RevisionRow, statement) ########################## # 'revision_parent' table ########################## @_prepared_insert_statement("revision_parent", RevisionParentRow.cols()) def revision_parent_add_one( self, revision_parent: RevisionParentRow, *, statement ) -> None: self._add_one(statement, None, revision_parent) @_prepared_statement("SELECT parent_id FROM revision_parent WHERE id = ?") def revision_parent_get( self, revision_id: Sha1Git, *, statement ) -> Iterable[bytes]: return ( row["parent_id"] for row in self._execute_with_retries(statement, [revision_id]) ) ########################## # 'release' table ########################## @_prepared_exists_statement("release") def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement("release", ReleaseRow.cols()) def release_add_one(self, release: ReleaseRow, *, statement) -> None: self._add_one(statement, "release", release) @_prepared_statement("SELECT * FROM release WHERE id in ?") def release_get(self, release_ids: List[str], *, statement) -> Iterable[ReleaseRow]: return map( ReleaseRow.from_dict, self._execute_with_retries(statement, [release_ids]) ) @_prepared_statement("SELECT * FROM release WHERE token(id) > ? LIMIT 1") def release_get_random(self, *, statement) -> Optional[ReleaseRow]: return self._get_random_row(ReleaseRow, statement) ########################## # 'directory' table ########################## @_prepared_exists_statement("directory") def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement("directory", DirectoryRow.cols()) def directory_add_one(self, directory: DirectoryRow, *, statement) -> None: """Called after all calls to directory_entry_add_one, to commit/finalize the directory.""" self._add_one(statement, "directory", directory) @_prepared_statement("SELECT * FROM directory WHERE token(id) > ? LIMIT 1") def directory_get_random(self, *, statement) -> Optional[DirectoryRow]: return self._get_random_row(DirectoryRow, statement) ########################## # 'directory_entry' table ########################## @_prepared_insert_statement("directory_entry", DirectoryEntryRow.cols()) def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None: self._add_one(statement, None, entry) @_prepared_statement("SELECT * FROM directory_entry WHERE directory_id IN ?") def directory_entry_get( self, directory_ids, *, statement ) -> Iterable[DirectoryEntryRow]: return map( DirectoryEntryRow.from_dict, self._execute_with_retries(statement, [directory_ids]), ) ########################## # 'snapshot' table ########################## @_prepared_exists_statement("snapshot") def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement("snapshot", SnapshotRow.cols()) def snapshot_add_one(self, snapshot: SnapshotRow, *, statement) -> None: self._add_one(statement, "snapshot", snapshot) @_prepared_statement("SELECT * FROM snapshot WHERE id = ?") def snapshot_get(self, snapshot_id: Sha1Git, *, statement) -> ResultSet: return map( SnapshotRow.from_dict, self._execute_with_retries(statement, [snapshot_id]) ) @_prepared_statement("SELECT * FROM snapshot WHERE token(id) > ? 
LIMIT 1") def snapshot_get_random(self, *, statement) -> Optional[SnapshotRow]: return self._get_random_row(SnapshotRow, statement) ########################## # 'snapshot_branch' table ########################## @_prepared_insert_statement("snapshot_branch", SnapshotBranchRow.cols()) def snapshot_branch_add_one(self, branch: SnapshotBranchRow, *, statement) -> None: self._add_one(statement, None, branch) @_prepared_statement( "SELECT ascii_bins_count(target_type) AS counts " "FROM snapshot_branch " "WHERE snapshot_id = ? " ) def snapshot_count_branches( self, snapshot_id: Sha1Git, *, statement ) -> Dict[Optional[str], int]: """Returns a dictionary from type names to the number of branches of that type.""" row = self._execute_with_retries(statement, [snapshot_id]).one() (nb_none, counts) = row["counts"] return {None: nb_none, **counts} @_prepared_statement( "SELECT * FROM snapshot_branch WHERE snapshot_id = ? AND name >= ? LIMIT ?" ) def snapshot_branch_get( self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement ) -> Iterable[SnapshotBranchRow]: return map( SnapshotBranchRow.from_dict, self._execute_with_retries(statement, [snapshot_id, from_, limit]), ) ########################## # 'origin' table ########################## @_prepared_insert_statement("origin", OriginRow.cols()) def origin_add_one(self, origin: OriginRow, *, statement) -> None: self._add_one(statement, "origin", origin) @_prepared_statement("SELECT * FROM origin WHERE sha1 = ?") def origin_get_by_sha1(self, sha1: bytes, *, statement) -> Iterable[OriginRow]: return map(OriginRow.from_dict, self._execute_with_retries(statement, [sha1])) def origin_get_by_url(self, url: str) -> Iterable[OriginRow]: return self.origin_get_by_sha1(hash_url(url)) @_prepared_statement( f'SELECT token(sha1) AS tok, {", ".join(OriginRow.cols())} ' f"FROM origin WHERE token(sha1) >= ? LIMIT ?" ) def origin_list( self, start_token: int, limit: int, *, statement ) -> Iterable[Tuple[int, OriginRow]]: """Returns an iterable of (token, origin)""" return ( (row["tok"], OriginRow.from_dict(remove_keys(row, ("tok",)))) for row in self._execute_with_retries(statement, [start_token, limit]) ) @_prepared_statement("SELECT * FROM origin") def origin_iter_all(self, *, statement) -> Iterable[OriginRow]: return map(OriginRow.from_dict, self._execute_with_retries(statement, [])) @_prepared_statement("SELECT next_visit_id FROM origin WHERE sha1 = ?") def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int: rows = list(self._execute_with_retries(statement, [origin_sha1])) assert len(rows) == 1 # TODO: error handling return rows[0]["next_visit_id"] @_prepared_statement( "UPDATE origin SET next_visit_id=? WHERE sha1 = ? IF next_visit_id=?" ) def origin_generate_unique_visit_id(self, origin_url: str, *, statement) -> int: origin_sha1 = hash_url(origin_url) next_id = self._origin_get_next_visit_id(origin_sha1) while True: res = list( self._execute_with_retries( statement, [next_id + 1, origin_sha1, next_id] ) ) assert len(res) == 1 if res[0]["[applied]"]: # No data race return next_id else: # Someone else updated it before we did, let's try again next_id = res[0]["next_visit_id"] # TODO: abort after too many attempts return next_id ########################## # 'origin_visit' table ########################## @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? 
" "ORDER BY visit ASC" ) def _origin_visit_get_pagination_asc_no_limit( self, origin_url: str, last_visit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? " "ORDER BY visit ASC " "LIMIT ?" ) def _origin_visit_get_pagination_asc_limit( self, origin_url: str, last_visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit, limit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? AND visit < ? " "ORDER BY visit DESC" ) def _origin_visit_get_pagination_desc_no_limit( self, origin_url: str, last_visit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? AND visit < ? " "ORDER BY visit DESC " "LIMIT ?" ) def _origin_visit_get_pagination_desc_limit( self, origin_url: str, last_visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit, limit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit ASC LIMIT ?" ) def _origin_visit_get_no_pagination_asc_limit( self, origin_url: str, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, limit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit ASC " ) def _origin_visit_get_no_pagination_asc_no_limit( self, origin_url: str, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit DESC" ) def _origin_visit_get_no_pagination_desc_no_limit( self, origin_url: str, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit DESC LIMIT ?" ) def _origin_visit_get_no_pagination_desc_limit( self, origin_url: str, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, limit]) def origin_visit_get( self, origin_url: str, last_visit: Optional[int], limit: Optional[int], order: ListOrder, ) -> Iterable[OriginVisitRow]: args: List[Any] = [origin_url] if last_visit is not None: page_name = "pagination" args.append(last_visit) else: page_name = "no_pagination" if limit is not None: limit_name = "limit" args.append(limit) else: limit_name = "no_limit" method_name = f"_origin_visit_get_{page_name}_{order.value}_{limit_name}" origin_visit_get_method = getattr(self, method_name) return map(OriginVisitRow.from_dict, origin_visit_get_method(*args)) @_prepared_statement( "SELECT * FROM origin_visit_status WHERE origin = ? " "AND visit = ? AND date >= ? " "ORDER BY date ASC " "LIMIT ?" ) def _origin_visit_status_get_with_date_asc_limit( self, origin: str, visit: int, date_from: datetime.datetime, limit: int, *, statement, ) -> ResultSet: return self._execute_with_retries(statement, [origin, visit, date_from, limit]) @_prepared_statement( "SELECT * FROM origin_visit_status WHERE origin = ? " "AND visit = ? AND date <= ? " "ORDER BY visit DESC " "LIMIT ?" 
) def _origin_visit_status_get_with_date_desc_limit( self, origin: str, visit: int, date_from: datetime.datetime, limit: int, *, statement, ) -> ResultSet: return self._execute_with_retries(statement, [origin, visit, date_from, limit]) @_prepared_statement( "SELECT * FROM origin_visit_status WHERE origin = ? AND visit = ? " "ORDER BY visit ASC " "LIMIT ?" ) def _origin_visit_status_get_with_no_date_asc_limit( self, origin: str, visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin, visit, limit]) @_prepared_statement( "SELECT * FROM origin_visit_status WHERE origin = ? AND visit = ? " "ORDER BY visit DESC " "LIMIT ?" ) def _origin_visit_status_get_with_no_date_desc_limit( self, origin: str, visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin, visit, limit]) def origin_visit_status_get_range( self, origin: str, visit: int, date_from: Optional[datetime.datetime], limit: int, order: ListOrder, ) -> Iterable[OriginVisitStatusRow]: args: List[Any] = [origin, visit] if date_from is not None: date_name = "date" args.append(date_from) else: date_name = "no_date" args.append(limit) method_name = f"_origin_visit_status_get_with_{date_name}_{order.value}_limit" origin_visit_status_get_method = getattr(self, method_name) return map( OriginVisitStatusRow.from_dict, origin_visit_status_get_method(*args) ) @_prepared_insert_statement("origin_visit", OriginVisitRow.cols()) def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None: self._add_one(statement, "origin_visit", visit) @_prepared_insert_statement("origin_visit_status", OriginVisitStatusRow.cols()) def origin_visit_status_add_one( self, visit_update: OriginVisitStatusRow, *, statement ) -> None: self._add_one(statement, None, visit_update) def origin_visit_status_get_latest( self, origin: str, visit: int, ) -> Optional[OriginVisitStatusRow]: """Given an origin visit id, return its latest origin_visit_status """ return next(self.origin_visit_status_get(origin, visit), None) @_prepared_statement( "SELECT * FROM origin_visit_status " "WHERE origin = ? AND visit = ? " "ORDER BY date DESC" ) def origin_visit_status_get( self, origin: str, visit: int, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, *, statement, ) -> Iterator[OriginVisitStatusRow]: """Return all origin visit statuses for a given visit """ return map( OriginVisitStatusRow.from_dict, self._execute_with_retries(statement, [origin, visit]), ) @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? 
AND visit = ?") def origin_visit_get_one( self, origin_url: str, visit_id: int, *, statement ) -> Optional[OriginVisitRow]: # TODO: error handling rows = list(self._execute_with_retries(statement, [origin_url, visit_id])) if rows: return OriginVisitRow.from_dict(rows[0]) else: return None @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ?") def origin_visit_get_all( self, origin_url: str, *, statement ) -> Iterable[OriginVisitRow]: return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, [origin_url]), ) @_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) >= ?") def _origin_visit_iter_from( self, min_token: int, *, statement ) -> Iterable[OriginVisitRow]: return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, [min_token]) ) @_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) < ?") def _origin_visit_iter_to( self, max_token: int, *, statement ) -> Iterable[OriginVisitRow]: return map( OriginVisitRow.from_dict, self._execute_with_retries(statement, [max_token]) ) def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]: """Returns all origin visits in order from this token, and wraps around the token space.""" yield from self._origin_visit_iter_from(start_token) yield from self._origin_visit_iter_to(start_token) ########################## # 'metadata_authority' table ########################## @_prepared_insert_statement("metadata_authority", MetadataAuthorityRow.cols()) def metadata_authority_add(self, authority: MetadataAuthorityRow, *, statement): self._add_one(statement, None, authority) @_prepared_statement("SELECT * from metadata_authority WHERE type = ? AND url = ?") def metadata_authority_get( self, type, url, *, statement ) -> Optional[MetadataAuthorityRow]: rows = list(self._execute_with_retries(statement, [type, url])) if rows: return MetadataAuthorityRow.from_dict(rows[0]) else: return None ########################## # 'metadata_fetcher' table ########################## @_prepared_insert_statement("metadata_fetcher", MetadataFetcherRow.cols()) def metadata_fetcher_add(self, fetcher, *, statement): self._add_one(statement, None, fetcher) @_prepared_statement( "SELECT * from metadata_fetcher WHERE name = ? AND version = ?" ) def metadata_fetcher_get( self, name, version, *, statement ) -> Optional[MetadataFetcherRow]: rows = list(self._execute_with_retries(statement, [name, version])) if rows: return MetadataFetcherRow.from_dict(rows[0]) else: return None ######################### # 'raw_extrinsic_metadata' table ######################### @_prepared_insert_statement( "raw_extrinsic_metadata", RawExtrinsicMetadataRow.cols() ) def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata, *, statement): self._add_one(statement, None, raw_extrinsic_metadata) @_prepared_statement( "SELECT * from raw_extrinsic_metadata " "WHERE id=? AND authority_url=? AND discovery_date>? AND authority_type=?" ) def raw_extrinsic_metadata_get_after_date( self, id: str, authority_type: str, authority_url: str, after: datetime.datetime, *, statement, ) -> Iterable[RawExtrinsicMetadataRow]: return map( RawExtrinsicMetadataRow.from_dict, self._execute_with_retries( statement, [id, authority_url, after, authority_type] ), ) @_prepared_statement( "SELECT * from raw_extrinsic_metadata " "WHERE id=? AND authority_type=? AND authority_url=? 
" "AND (discovery_date, fetcher_name, fetcher_version) > (?, ?, ?)" ) def raw_extrinsic_metadata_get_after_date_and_fetcher( self, id: str, authority_type: str, authority_url: str, after_date: datetime.datetime, after_fetcher_name: str, after_fetcher_version: str, *, statement, ) -> Iterable[RawExtrinsicMetadataRow]: return map( RawExtrinsicMetadataRow.from_dict, self._execute_with_retries( statement, [ id, authority_type, authority_url, after_date, after_fetcher_name, after_fetcher_version, ], ), ) @_prepared_statement( "SELECT * from raw_extrinsic_metadata " "WHERE id=? AND authority_url=? AND authority_type=?" ) def raw_extrinsic_metadata_get( self, id: str, authority_type: str, authority_url: str, *, statement ) -> Iterable[RawExtrinsicMetadataRow]: return map( RawExtrinsicMetadataRow.from_dict, self._execute_with_retries(statement, [id, authority_url, authority_type]), ) ########################## # Miscellaneous ########################## @_prepared_statement("SELECT uuid() FROM revision LIMIT 1;") def check_read(self, *, statement): self._execute_with_retries(statement, []) @_prepared_statement("SELECT * FROM object_count WHERE partition_key=0") def stat_counters(self, *, statement) -> ResultSet: return map(ObjectCountRow.from_dict, self._execute_with_retries(statement, [])) diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py index f074a62b..ce2ab783 100644 --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -1,1299 +1,1311 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime import itertools import json import random import re -from typing import Any, Dict, List, Iterable, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple, Union import attr from swh.core.api.serializers import msgpack_loads, msgpack_dumps from swh.model.identifiers import parse_swhid, SWHID from swh.model.hashutil import DEFAULT_ALGORITHMS from swh.model.model import ( Revision, Release, Directory, DirectoryEntry, Content, SkippedContent, OriginVisit, OriginVisitStatus, Snapshot, SnapshotBranch, TargetType, Origin, MetadataAuthority, MetadataAuthorityType, MetadataFetcher, MetadataTargetType, RawExtrinsicMetadata, Sha1Git, ) from swh.storage.interface import ( ListOrder, PagedResult, PartialBranches, Sha1, VISIT_STATUSES, ) from swh.storage.objstorage import ObjStorage from swh.storage.writer import JournalWriter from swh.storage.utils import map_optional, now from ..exc import StorageArgumentException, HashCollision from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys from . 
import converters from .cql import CqlRunner from .schema import HASH_ALGORITHMS from .model import ( ContentRow, DirectoryEntryRow, DirectoryRow, MetadataAuthorityRow, MetadataFetcherRow, OriginRow, OriginVisitRow, RawExtrinsicMetadataRow, RevisionParentRow, SkippedContentRow, SnapshotBranchRow, SnapshotRow, ) # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 class CassandraStorage: def __init__(self, hosts, keyspace, objstorage, port=9042, journal_writer=None): - self._cql_runner = CqlRunner(hosts, keyspace, port) - self.journal_writer = JournalWriter(journal_writer) - self.objstorage = ObjStorage(objstorage) + self._cql_runner: CqlRunner = CqlRunner(hosts, keyspace, port) + self.journal_writer: JournalWriter = JournalWriter(journal_writer) + self.objstorage: ObjStorage = ObjStorage(objstorage) def check_config(self, *, check_write: bool) -> bool: self._cql_runner.check_read() return True def _content_get_from_hash(self, algo, hash_) -> Iterable: """From the name of a hash algorithm and a value of that hash, looks up the "hash -> token" secondary table (content_by_{algo}) to get tokens. Then, looks up the main table (content) to get all contents with that token, and filters out contents whose hash doesn't match.""" found_tokens = self._cql_runner.content_get_tokens_from_single_hash(algo, hash_) for token in found_tokens: + assert isinstance(token, int), found_tokens # Query the main table ('content'). res = self._cql_runner.content_get_from_token(token) for row in res: # re-check the hash (in case of murmur3 collision) if getattr(row, algo) == hash_: yield row def _content_add(self, contents: List[Content], with_data: bool) -> Dict: # Filter out content already in the database. contents = [ c for c in contents if not self._cql_runner.content_get_from_pk(c.to_dict()) ] self.journal_writer.content_add(contents) if with_data: # First insert to the objstorage, if the endpoint is # `content_add` (as opposed to `content_add_metadata`). # TODO: this should probably be done concurrently with inserting # in index tables (but still before the main table, so an entry is # only added to the main table after everything else was # successfully inserted). summary = self.objstorage.content_add( c for c in contents if c.status != "absent" ) content_add_bytes = summary["content:add:bytes"] content_add = 0 for content in contents: content_add += 1 # Check for sha1 or sha1_git collisions. This test is not atomic # with the insertion, so it won't detect a collision if both # contents are inserted at the same time, but it's good enough. # # The proper way to do it would probably be a BATCH, but this # would be inefficient because of the number of partitions we # need to affect (len(HASH_ALGORITHMS)+1, which is currently 5) for algo in {"sha1", "sha1_git"}: collisions = [] # Get tokens of 'content' rows with the same value for # sha1/sha1_git rows = self._content_get_from_hash(algo, content.get_hash(algo)) for row in rows: if getattr(row, algo) != content.get_hash(algo): # collision of token(partition key), ignore this # row continue for other_algo in HASH_ALGORITHMS: if getattr(row, other_algo) != content.get_hash(other_algo): # This other hash didn't match; record the row as a # collision on `algo`.
collisions.append( {algo: getattr(row, algo) for algo in HASH_ALGORITHMS} ) if collisions: collisions.append(content.hashes()) raise HashCollision(algo, content.get_hash(algo), collisions) (token, insertion_finalizer) = self._cql_runner.content_add_prepare( ContentRow(**remove_keys(content.to_dict(), ("data",))) ) # Then add to index tables for algo in HASH_ALGORITHMS: self._cql_runner.content_index_add_one(algo, content, token) # Then to the main table insertion_finalizer() summary = { "content:add": content_add, } if with_data: summary["content:add:bytes"] = content_add_bytes return summary def content_add(self, content: List[Content]) -> Dict: contents = [attr.evolve(c, ctime=now()) for c in content] return self._content_add(list(contents), with_data=True) def content_update( self, contents: List[Dict[str, Any]], keys: List[str] = [] ) -> None: raise NotImplementedError( "content_update is not supported by the Cassandra backend" ) def content_add_metadata(self, content: List[Content]) -> Dict: return self._content_add(content, with_data=False) def content_get_data(self, content: Sha1) -> Optional[bytes]: # FIXME: Make this method support slicing the `data` return self.objstorage.content_get(content) def content_get_partition( self, partition_id: int, nb_partitions: int, page_token: Optional[str] = None, limit: int = 1000, ) -> PagedResult[Content]: if limit is None: raise StorageArgumentException("limit should not be None") # Compute start and end of the range of tokens covered by the # requested partition partition_size = (TOKEN_END - TOKEN_BEGIN) // nb_partitions range_start = TOKEN_BEGIN + partition_id * partition_size range_end = TOKEN_BEGIN + (partition_id + 1) * partition_size # offset the range start according to the `page_token`. if page_token is not None: if not (range_start <= int(page_token) <= range_end): raise StorageArgumentException("Invalid page_token.") range_start = int(page_token) next_page_token: Optional[str] = None rows = self._cql_runner.content_get_token_range( range_start, range_end, limit + 1 ) contents = [] for counter, (tok, row) in enumerate(rows): if row.status == "absent": continue row_d = row.to_dict() if counter >= limit: next_page_token = str(tok) break contents.append(Content(**row_d)) assert len(contents) <= limit return PagedResult(results=contents, next_page_token=next_page_token) def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]: contents_by_sha1: Dict[Sha1, Optional[Content]] = {} for sha1 in contents: # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1 # matches the argument, from the index table ('content_by_sha1') for row in self._content_get_from_hash("sha1", sha1): row_d = row.to_dict() row_d.pop("ctime") content = Content(**row_d) contents_by_sha1[content.sha1] = content return [contents_by_sha1.get(sha1) for sha1 in contents] def content_find(self, content: Dict[str, Any]) -> List[Content]: # Find an algorithm that is common to all the requested contents. # It will be used to do an initial filtering efficiently. 
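# (Illustrative sketch, hypothetical values: content_find({"sha1": b"\x12" * 20,
# "length": 42}) picks filter_algos == ["sha1"], so the lookup below goes
# through the content_by_sha1 index table first, then re-checks every hash on
# the candidate rows.)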
filter_algos = list(set(content).intersection(HASH_ALGORITHMS)) if not filter_algos: raise StorageArgumentException( "content keys must contain at least one " f"of: {', '.join(sorted(HASH_ALGORITHMS))}" ) common_algo = filter_algos[0] results = [] rows = self._content_get_from_hash(common_algo, content[common_algo]) for row in rows: # Re-check all the hashes, in case of collisions (either of the # hash of the partition key, or the hashes in it) for algo in HASH_ALGORITHMS: if content.get(algo) and getattr(row, algo) != content[algo]: # This hash didn't match; discard the row. break else: # All hashes match, keep this row. row_d = row.to_dict() row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc) results.append(Content(**row_d)) return results def content_missing( self, contents: List[Dict[str, Any]], key_hash: str = "sha1" ) -> Iterable[bytes]: if key_hash not in DEFAULT_ALGORITHMS: raise StorageArgumentException( f"key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}" ) for content in contents: res = self.content_find(content) if not res: yield content[key_hash] def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]: return self.content_missing([{"sha1": c} for c in contents]) def content_missing_per_sha1_git( self, contents: List[Sha1Git] ) -> Iterable[Sha1Git]: return self.content_missing( [{"sha1_git": c} for c in contents], key_hash="sha1_git" ) def content_get_random(self) -> Sha1Git: - return self._cql_runner.content_get_random().sha1_git + content = self._cql_runner.content_get_random() + assert content, "Could not find any content" + return content.sha1_git def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict: # Filter out content already in the database. contents = [ c for c in contents if not self._cql_runner.skipped_content_get_from_pk(c.to_dict()) ] self.journal_writer.skipped_content_add(contents) for content in contents: # Compute token of the row in the main table (token, insertion_finalizer) = self._cql_runner.skipped_content_add_prepare( SkippedContentRow.from_dict({"origin": None, **content.to_dict()}) ) # Then add to index tables for algo in HASH_ALGORITHMS: self._cql_runner.skipped_content_index_add_one(algo, content, token) # Then to the main table insertion_finalizer() return {"skipped_content:add": len(contents)} def skipped_content_add(self, content: List[SkippedContent]) -> Dict: contents = [attr.evolve(c, ctime=now()) for c in content] return self._skipped_content_add(contents) def skipped_content_missing( self, contents: List[Dict[str, Any]] ) -> Iterable[Dict[str, Any]]: for content in contents: if not self._cql_runner.skipped_content_get_from_pk(content): yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} def directory_add(self, directories: List[Directory]) -> Dict: # Filter out directories that are already inserted. missing = self.directory_missing([dir_.id for dir_ in directories]) directories = [dir_ for dir_ in directories if dir_.id in missing] self.journal_writer.directory_add(directories) for directory in directories: # Add directory entries to the 'directory_entry' table for entry in directory.entries: self._cql_runner.directory_entry_add_one( DirectoryEntryRow(directory_id=directory.id, **entry.to_dict()) ) # Add the directory *after* adding all the entries, so someone # calling directory_ls in the meantime won't end up # with half the entries.
self._cql_runner.directory_add_one(DirectoryRow(id=directory.id)) return {"directory:add": len(directories)} def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.directory_missing(directories) def _join_dentry_to_content(self, dentry: DirectoryEntry) -> Dict[str, Any]: keys = ( "status", "sha1", "sha1_git", "sha256", "length", ) ret = dict.fromkeys(keys) ret.update(dentry.to_dict()) if ret["type"] == "file": contents = self.content_find({"sha1_git": ret["target"]}) if contents: content = contents[0] for key in keys: ret[key] = getattr(content, key) return ret def _directory_ls( self, directory_id: Sha1Git, recursive: bool, prefix: bytes = b"" ) -> Iterable[Dict[str, Any]]: if self.directory_missing([directory_id]): return rows = list(self._cql_runner.directory_entry_get([directory_id])) for row in rows: entry_d = row.to_dict() # Build and yield the directory entry dict del entry_d["directory_id"] entry = DirectoryEntry.from_dict(entry_d) ret = self._join_dentry_to_content(entry) ret["name"] = prefix + ret["name"] ret["dir_id"] = directory_id yield ret if recursive and ret["type"] == "dir": yield from self._directory_ls( ret["target"], True, prefix + ret["name"] + b"/" ) def directory_entry_get_by_path( self, directory: Sha1Git, paths: List[bytes] ) -> Optional[Dict[str, Any]]: return self._directory_entry_get_by_path(directory, paths, b"") def _directory_entry_get_by_path( self, directory: Sha1Git, paths: List[bytes], prefix: bytes ) -> Optional[Dict[str, Any]]: if not paths: return None contents = list(self.directory_ls(directory)) if not contents: return None def _get_entry(entries, name): """Finds the entry with the requested name, prepends the prefix (to get its full path), and returns it. If no entry has that name, returns None.""" for entry in entries: if entry["name"] == name: entry = entry.copy() entry["name"] = prefix + entry["name"] return entry first_item = _get_entry(contents, paths[0]) if len(paths) == 1: return first_item if not first_item or first_item["type"] != "dir": return None return self._directory_entry_get_by_path( first_item["target"], paths[1:], prefix + paths[0] + b"/" ) def directory_ls( self, directory: Sha1Git, recursive: bool = False ) -> Iterable[Dict[str, Any]]: yield from self._directory_ls(directory, recursive) def directory_get_random(self) -> Sha1Git: - return self._cql_runner.directory_get_random().id + directory = self._cql_runner.directory_get_random() + assert directory, "Could not find any directory" + return directory.id def revision_add(self, revisions: List[Revision]) -> Dict: # Filter-out revisions already in the database missing = self.revision_missing([rev.id for rev in revisions]) revisions = [rev for rev in revisions if rev.id in missing] self.journal_writer.revision_add(revisions) for revision in revisions: revobject = converters.revision_to_db(revision) if revobject: # Add parents first for (rank, parent) in enumerate(revision.parents): self._cql_runner.revision_parent_add_one( RevisionParentRow( id=revobject.id, parent_rank=rank, parent_id=parent ) ) # Then write the main revision row. 
# Writing this after all parents were written ensures that # read endpoints don't return a partial view while writing # the parents self._cql_runner.revision_add_one(revobject) return {"revision:add": len(revisions)} def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.revision_missing(revisions) def revision_get( self, revisions: List[Sha1Git] ) -> Iterable[Optional[Dict[str, Any]]]: rows = self._cql_runner.revision_get(revisions) revs = {} for row in rows: # TODO: use a single query to get all parents? # (it might have lower latency, but requires more code and more # bandwidth, because revision id would be part of each returned # row) parents = tuple(self._cql_runner.revision_parent_get(row.id)) # parent_rank is the clustering key, so results are already # sorted by rank. rev = converters.revision_from_db(row, parents=parents) revs[rev.id] = rev.to_dict() for rev_id in revisions: yield revs.get(rev_id) def _get_parent_revs( self, rev_ids: Iterable[Sha1Git], seen: Set[Sha1Git], limit: Optional[int], short: bool, ) -> Union[ Iterable[Dict[str, Any]], Iterable[Tuple[Sha1Git, Tuple[Sha1Git, ...]]], ]: if limit and len(seen) >= limit: return rev_ids = [id_ for id_ in rev_ids if id_ not in seen] if not rev_ids: return seen |= set(rev_ids) # We need this query, even if short=True, to return consistent # results (i.e. not return only a subset of a revision's parents # if it is being written) if short: ids = self._cql_runner.revision_get_ids(rev_ids) for id_ in ids: # TODO: use a single query to get all parents? # (it might have lower latency, but requires more code and more # bandwidth, because revision id would be part of each returned # row) parents = tuple(self._cql_runner.revision_parent_get(id_)) # parent_rank is the clustering key, so results are already # sorted by rank. yield (id_, parents) yield from self._get_parent_revs(parents, seen, limit, short) else: rows = self._cql_runner.revision_get(rev_ids) for row in rows: # TODO: use a single query to get all parents? # (it might have lower latency, but requires more code and more # bandwidth, because revision id would be part of each returned # row) parents = tuple(self._cql_runner.revision_parent_get(row.id)) # parent_rank is the clustering key, so results are already # sorted by rank.
rev = converters.revision_from_db(row, parents=parents) yield rev.to_dict() yield from self._get_parent_revs(parents, seen, limit, short) def revision_log( self, revisions: List[Sha1Git], limit: Optional[int] = None ) -> Iterable[Optional[Dict[str, Any]]]: seen: Set[Sha1Git] = set() yield from self._get_parent_revs(revisions, seen, limit, False) def revision_shortlog( self, revisions: List[Sha1Git], limit: Optional[int] = None ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: seen: Set[Sha1Git] = set() yield from self._get_parent_revs(revisions, seen, limit, True) def revision_get_random(self) -> Sha1Git: - return self._cql_runner.revision_get_random().id + revision = self._cql_runner.revision_get_random() + assert revision, "Could not find any revision" + return revision.id def release_add(self, releases: List[Release]) -> Dict: to_add = [] for rel in releases: if rel not in to_add: to_add.append(rel) missing = set(self.release_missing([rel.id for rel in to_add])) to_add = [rel for rel in to_add if rel.id in missing] self.journal_writer.release_add(to_add) for release in to_add: if release: self._cql_runner.release_add_one(converters.release_to_db(release)) return {"release:add": len(to_add)} def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.release_missing(releases) def release_get( self, releases: List[Sha1Git] ) -> Iterable[Optional[Dict[str, Any]]]: rows = self._cql_runner.release_get(releases) rels = {} for row in rows: release = converters.release_from_db(row) rels[row.id] = release.to_dict() for rel_id in releases: yield rels.get(rel_id) def release_get_random(self) -> Sha1Git: - return self._cql_runner.release_get_random().id + release = self._cql_runner.release_get_random() + assert release, "Could not find any release" + return release.id def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: missing = self._cql_runner.snapshot_missing([snp.id for snp in snapshots]) snapshots = [snp for snp in snapshots if snp.id in missing] for snapshot in snapshots: self.journal_writer.snapshot_add([snapshot]) # Add branches for (branch_name, branch) in snapshot.branches.items(): if branch is None: target_type: Optional[str] = None target: Optional[bytes] = None else: target_type = branch.target_type.value target = branch.target self._cql_runner.snapshot_branch_add_one( SnapshotBranchRow( snapshot_id=snapshot.id, name=branch_name, target_type=target_type, target=target, ) ) # Add the snapshot *after* adding all the branches, so someone # calling snapshot_get_branch in the meantime won't end up # with half the branches. 
self._cql_runner.snapshot_add_one(SnapshotRow(id=snapshot.id)) return {"snapshot:add": len(snapshots)} def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.snapshot_missing(snapshots) def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: d = self.snapshot_get_branches(snapshot_id) if d is None: return None return { "id": d["id"], "branches": { name: branch.to_dict() if branch else None for (name, branch) in d["branches"].items() }, "next_branch": d["next_branch"], } def snapshot_get_by_origin_visit( self, origin: str, visit: int ) -> Optional[Dict[str, Any]]: visit_status = self.origin_visit_status_get_latest( origin, visit, require_snapshot=True ) if visit_status and visit_status.snapshot: return self.snapshot_get(visit_status.snapshot) return None def snapshot_count_branches( self, snapshot_id: Sha1Git ) -> Optional[Dict[Optional[str], int]]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. return None return self._cql_runner.snapshot_count_branches(snapshot_id) def snapshot_get_branches( self, snapshot_id: Sha1Git, branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, ) -> Optional[PartialBranches]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. return None branches: List = [] while len(branches) < branches_count + 1: new_branches = list( self._cql_runner.snapshot_branch_get( snapshot_id, branches_from, branches_count + 1 ) ) if not new_branches: break branches_from = new_branches[-1].name new_branches_filtered = new_branches # Filter by target_type if target_types: new_branches_filtered = [ branch for branch in new_branches_filtered if branch.target is not None and branch.target_type in target_types ] branches.extend(new_branches_filtered) if len(new_branches) < branches_count + 1: break if len(branches) > branches_count: last_branch = branches.pop(-1).name else: last_branch = None return PartialBranches( id=snapshot_id, branches={ branch.name: None if branch.target is None else SnapshotBranch( target=branch.target, target_type=TargetType(branch.target_type) ) for branch in branches }, next_branch=last_branch, ) def snapshot_get_random(self) -> Sha1Git: - return self._cql_runner.snapshot_get_random().id + snapshot = self._cql_runner.snapshot_get_random() + assert snapshot, "Could not find any snapshot" + return snapshot.id def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]: results: Dict[Sha1Git, List[Dict]] = {id_: [] for id_ in ids} missing_ids = set(ids) # Mind the order, revision is the most likely one for a given ID, # so we check revisions first. - queries = [ + queries: List[Tuple[str, Callable[[List[Sha1Git]], List[Sha1Git]]]] = [ ("revision", self._cql_runner.revision_missing), ("release", self._cql_runner.release_missing), ("content", self._cql_runner.content_missing_by_sha1_git), ("directory", self._cql_runner.directory_missing), ] for (object_type, query_fn) in queries: - found_ids = missing_ids - set(query_fn(missing_ids)) + found_ids = missing_ids - set(query_fn(list(missing_ids))) for sha1_git in found_ids: results[sha1_git].append( {"sha1_git": sha1_git, "type": object_type,} ) missing_ids.remove(sha1_git) if not missing_ids: # We found everything, skipping the next queries. 
break return results def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: return [self.origin_get_one(origin) for origin in origins] def origin_get_one(self, origin_url: str) -> Optional[Origin]: """Given an origin url, return the origin if it exists, None otherwise """ rows = list(self._cql_runner.origin_get_by_url(origin_url)) if rows: assert len(rows) == 1 return Origin(url=rows[0].url) else: return None def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]: results = [] for sha1 in sha1s: rows = list(self._cql_runner.origin_get_by_sha1(sha1)) origin = {"url": rows[0].url} if rows else None results.append(origin) return results def origin_list( self, page_token: Optional[str] = None, limit: int = 100 ) -> PagedResult[Origin]: # Compute what token to begin the listing from start_token = TOKEN_BEGIN if page_token: start_token = int(page_token) if not (TOKEN_BEGIN <= start_token <= TOKEN_END): raise StorageArgumentException("Invalid page_token.") next_page_token = None origins = [] # Take one more origin so we can reuse it as the next page token if any for (tok, row) in self._cql_runner.origin_list(start_token, limit + 1): origins.append(Origin(url=row.url)) # keep reference of the last id for pagination purposes last_id = tok if len(origins) > limit: # last origin id is the next page token next_page_token = str(last_id) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) def origin_search( self, url_pattern: str, page_token: Optional[str] = None, limit: int = 50, regexp: bool = False, with_visit: bool = False, ) -> PagedResult[Origin]: # TODO: remove this endpoint, swh-search should be used instead. 
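# (Sketch of the client-side behaviour, hypothetical calls:
#   origin_search("github.com/", limit=10)      # substring filtering
#   origin_search(r"^https://", regexp=True)    # regexp filtering
# Both load every origin row in memory before filtering, i.e. O(number of
# origins) per call -- hence the TODO above about moving to swh-search.)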
next_page_token = None offset = int(page_token) if page_token else 0 origin_rows = [row for row in self._cql_runner.origin_iter_all()] if regexp: pat = re.compile(url_pattern) origin_rows = [row for row in origin_rows if pat.search(row.url)] else: origin_rows = [row for row in origin_rows if url_pattern in row.url] if with_visit: origin_rows = [row for row in origin_rows if row.next_visit_id > 1] origins = [Origin(url=row.url) for row in origin_rows] origins = origins[offset : offset + limit + 1] if len(origins) > limit: # next offset next_page_token = str(offset + limit) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) def origin_add(self, origins: List[Origin]) -> Dict[str, int]: to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None] self.journal_writer.origin_add(to_add) for origin in to_add: self._cql_runner.origin_add_one( OriginRow(sha1=hash_url(origin.url), url=origin.url, next_visit_id=1) ) return {"origin:add": len(to_add)} def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: for visit in visits: origin = self.origin_get_one(visit.origin) if not origin: # Cannot add a visit without an origin raise StorageArgumentException(f"Unknown origin {visit.origin}") all_visits = [] nb_visits = 0 for visit in visits: nb_visits += 1 if not visit.visit: visit_id = self._cql_runner.origin_generate_unique_visit_id( visit.origin ) visit = attr.evolve(visit, visit=visit_id) self.journal_writer.origin_visit_add([visit]) self._cql_runner.origin_visit_add_one(OriginVisitRow(**visit.to_dict())) assert visit.visit is not None all_visits.append(visit) self._origin_visit_status_add( OriginVisitStatus( origin=visit.origin, visit=visit.visit, date=visit.date, status="created", snapshot=None, ) ) return all_visits def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None: """Add an origin visit status""" self.journal_writer.origin_visit_status_add([visit_status]) self._cql_runner.origin_visit_status_add_one( converters.visit_status_to_row(visit_status) ) def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus]) -> None: # First round to check existence (fail early if any origin is unknown) for visit_status in visit_statuses: origin_url = self.origin_get_one(visit_status.origin) if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") for visit_status in visit_statuses: self._origin_visit_status_add(visit_status) def _origin_visit_apply_last_status(self, visit: Dict[str, Any]) -> Dict[str, Any]: """Retrieve the latest visit status information for the origin visit. Then merge it with the visit and return it. """ row = self._cql_runner.origin_visit_status_get_latest( visit["origin"], visit["visit"] ) assert row is not None visit_status = converters.row_to_visit_status(row) return { # default to the values in visit **visit, # override with the last update **visit_status.to_dict(), # visit['origin'] is the URL (via a join), while # visit_status['origin'] is only an id. "origin": visit["origin"], # but keep the date of the creation of the origin visit "date": visit["date"], } def _origin_visit_get_latest_status(self, visit: OriginVisit) -> OriginVisitStatus: """Retrieve the latest visit status information for the origin visit object.
""" + assert visit.visit row = self._cql_runner.origin_visit_status_get_latest(visit.origin, visit.visit) assert row is not None visit_status = converters.row_to_visit_status(row) return attr.evolve(visit_status, origin=visit.origin) @staticmethod def _format_origin_visit_row(visit): return { **visit.to_dict(), "origin": visit.origin, "date": visit.date.replace(tzinfo=datetime.timezone.utc), } def origin_visit_get( self, origin: str, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisit]: if not isinstance(order, ListOrder): raise StorageArgumentException("order must be a ListOrder value") if page_token and not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") next_page_token = None - visit_from = page_token and int(page_token) + visit_from = None if page_token is None else int(page_token) visits: List[OriginVisit] = [] extra_limit = limit + 1 rows = self._cql_runner.origin_visit_get(origin, visit_from, extra_limit, order) for row in rows: visits.append(converters.row_to_visit(row)) assert len(visits) <= extra_limit if len(visits) == extra_limit: visits = visits[:limit] next_page_token = str(visits[-1].visit) return PagedResult(results=visits, next_page_token=next_page_token) def origin_visit_status_get( self, origin: str, visit: int, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisitStatus]: next_page_token = None date_from = None if page_token is not None: date_from = datetime.datetime.fromisoformat(page_token) # Take one more visit status so we can reuse it as the next page token if any rows = self._cql_runner.origin_visit_status_get_range( origin, visit, date_from, limit + 1, order ) visit_statuses = [converters.row_to_visit_status(row) for row in rows] if len(visit_statuses) > limit: # last visit status date is the next page token next_page_token = str(visit_statuses[-1].date) # excluding that visit status from the result to respect the limit size visit_statuses = visit_statuses[:limit] return PagedResult(results=visit_statuses, next_page_token=next_page_token) def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime ) -> Optional[OriginVisit]: # Iterator over all the visits of the origin # This should be ok for now, as there aren't too many visits # per origin. 
    def origin_visit_find_by_date(
        self, origin: str, visit_date: datetime.datetime
    ) -> Optional[OriginVisit]:
        # Iterate over all the visits of the origin.
        # This should be ok for now, as there aren't too many visits per origin.
        rows = list(self._cql_runner.origin_visit_get_all(origin))

        def key(visit):
            dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date
            return (abs(dt), -visit.visit)

        if rows:
            return converters.row_to_visit(min(rows, key=key))

        return None

    def origin_visit_get_by(self, origin: str, visit: int) -> Optional[OriginVisit]:
        row = self._cql_runner.origin_visit_get_one(origin, visit)
        if row:
            return converters.row_to_visit(row)
        return None

    def origin_visit_get_latest(
        self,
        origin: str,
        type: Optional[str] = None,
        allowed_statuses: Optional[List[str]] = None,
        require_snapshot: bool = False,
    ) -> Optional[OriginVisit]:
        if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES):
            raise StorageArgumentException(
                f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
                f"{','.join(VISIT_STATUSES)} authorized"
            )

        # TODO: Do not fetch all visits
        rows = self._cql_runner.origin_visit_get_all(origin)
        latest_visit = None
        for row in rows:
            visit = self._format_origin_visit_row(row)
            updated_visit = self._origin_visit_apply_last_status(visit)
            if type is not None and updated_visit["type"] != type:
                continue
            if allowed_statuses and updated_visit["status"] not in allowed_statuses:
                continue
            if require_snapshot and updated_visit["snapshot"] is None:
                continue

            # updated_visit is a candidate
            if latest_visit is not None:
                if updated_visit["date"] < latest_visit["date"]:
                    continue
                if updated_visit["visit"] < latest_visit["visit"]:
                    continue

            latest_visit = updated_visit

        if latest_visit is None:
            return None

        return OriginVisit(
            origin=latest_visit["origin"],
            visit=latest_visit["visit"],
            date=latest_visit["date"],
            type=latest_visit["type"],
        )

    def origin_visit_status_get_latest(
        self,
        origin_url: str,
        visit: int,
        allowed_statuses: Optional[List[str]] = None,
        require_snapshot: bool = False,
    ) -> Optional[OriginVisitStatus]:
        if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES):
            raise StorageArgumentException(
                f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
                f"{','.join(VISIT_STATUSES)} authorized"
            )
        rows = list(
            self._cql_runner.origin_visit_status_get(
                origin_url, visit, allowed_statuses, require_snapshot
            )
        )
        # filtering is done Python-side, as we cannot do it server-side
        if allowed_statuses:
            rows = [row for row in rows if row.status in allowed_statuses]
        if require_snapshot:
            rows = [row for row in rows if row.snapshot is not None]
        if not rows:
            return None
        return converters.row_to_visit_status(rows[0])

    def origin_visit_status_get_random(
        self, type: str
    ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]:
        back_in_the_day = now() - datetime.timedelta(weeks=12)  # 3 months back

        # Random position to start iteration at
        start_token = random.randint(TOKEN_BEGIN, TOKEN_END)

        # Iterator over all visits, ordered by token(origins) then visit_id
        rows = self._cql_runner.origin_visit_iter(start_token)
        for row in rows:
            visit = converters.row_to_visit(row)
            visit_status = self._origin_visit_get_latest_status(visit)
            if visit.date > back_in_the_day and visit_status.status == "full":
                return visit, visit_status
        return None

    def stat_counters(self):
        rows = self._cql_runner.stat_counters()
        keys = (
            "content",
            "directory",
            "origin",
            "origin_visit",
            "release",
            "revision",
            "skipped_content",
            "snapshot",
        )
        stats = {key: 0 for key in keys}
        stats.update({row.object_type: row.count for row in rows})
        return stats

    def refresh_stat_counters(self):
        pass
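    # stat_counters() returns a plain mapping; e.g. (illustrative values only):
    #     {"content": 1234, "directory": 567, "origin": 89, "origin_visit": 90,
    #      "release": 12, "revision": 345, "skipped_content": 0, "snapshot": 67}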
    def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata]) -> None:
        self.journal_writer.raw_extrinsic_metadata_add(metadata)
        for metadata_entry in metadata:
            if not self._cql_runner.metadata_authority_get(
                metadata_entry.authority.type.value, metadata_entry.authority.url
            ):
                raise StorageArgumentException(
                    f"Unknown authority {metadata_entry.authority}"
                )
            if not self._cql_runner.metadata_fetcher_get(
                metadata_entry.fetcher.name, metadata_entry.fetcher.version
            ):
                raise StorageArgumentException(
                    f"Unknown fetcher {metadata_entry.fetcher}"
                )

            try:
                row = RawExtrinsicMetadataRow(
                    type=metadata_entry.type.value,
                    id=str(metadata_entry.id),
                    authority_type=metadata_entry.authority.type.value,
                    authority_url=metadata_entry.authority.url,
                    discovery_date=metadata_entry.discovery_date,
                    fetcher_name=metadata_entry.fetcher.name,
                    fetcher_version=metadata_entry.fetcher.version,
                    format=metadata_entry.format,
                    metadata=metadata_entry.metadata,
                    origin=metadata_entry.origin,
                    visit=metadata_entry.visit,
                    snapshot=map_optional(str, metadata_entry.snapshot),
                    release=map_optional(str, metadata_entry.release),
                    revision=map_optional(str, metadata_entry.revision),
                    path=metadata_entry.path,
                    directory=map_optional(str, metadata_entry.directory),
                )
                self._cql_runner.raw_extrinsic_metadata_add(row)
            except TypeError as e:
                raise StorageArgumentException(*e.args)
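    # The page token used below is opaque to callers: a msgpack-serialized
    # (discovery_date, fetcher_name, fetcher_version) triple, base64-encoded
    # so it stays printable.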
    def raw_extrinsic_metadata_get(
        self,
        type: MetadataTargetType,
        id: Union[str, SWHID],
        authority: MetadataAuthority,
        after: Optional[datetime.datetime] = None,
        page_token: Optional[bytes] = None,
        limit: int = 1000,
    ) -> PagedResult[RawExtrinsicMetadata]:
        if type == MetadataTargetType.ORIGIN:
            if isinstance(id, SWHID):
                raise StorageArgumentException(
                    f"raw_extrinsic_metadata_get called with type='origin', "
                    f"but provided id is an SWHID: {id!r}"
                )
        else:
            if not isinstance(id, SWHID):
                raise StorageArgumentException(
                    f"raw_extrinsic_metadata_get called with type!='origin', "
                    f"but provided id is not an SWHID: {id!r}"
                )

        if page_token is not None:
            (after_date, after_fetcher_name, after_fetcher_version) = msgpack_loads(
                base64.b64decode(page_token)
            )
            if after and after_date < after:
                raise StorageArgumentException(
                    "page_token is inconsistent with the value of 'after'."
                )
            entries = self._cql_runner.raw_extrinsic_metadata_get_after_date_and_fetcher(  # noqa
                str(id),
                authority.type.value,
                authority.url,
                after_date,
                after_fetcher_name,
                after_fetcher_version,
            )
        elif after is not None:
            entries = self._cql_runner.raw_extrinsic_metadata_get_after_date(
                str(id), authority.type.value, authority.url, after
            )
        else:
            entries = self._cql_runner.raw_extrinsic_metadata_get(
                str(id), authority.type.value, authority.url
            )

        if limit:
            entries = itertools.islice(entries, 0, limit + 1)

        results = []
        for entry in entries:
            discovery_date = entry.discovery_date.replace(tzinfo=datetime.timezone.utc)

            assert str(id) == entry.id
            result = RawExtrinsicMetadata(
                type=MetadataTargetType(entry.type),
                id=id,
                authority=MetadataAuthority(
                    type=MetadataAuthorityType(entry.authority_type),
                    url=entry.authority_url,
                ),
                fetcher=MetadataFetcher(
                    name=entry.fetcher_name, version=entry.fetcher_version,
                ),
                discovery_date=discovery_date,
                format=entry.format,
                metadata=entry.metadata,
                origin=entry.origin,
                visit=entry.visit,
                snapshot=map_optional(parse_swhid, entry.snapshot),
                release=map_optional(parse_swhid, entry.release),
                revision=map_optional(parse_swhid, entry.revision),
                path=entry.path,
                directory=map_optional(parse_swhid, entry.directory),
            )
            results.append(result)

        if len(results) > limit:
            results.pop()
            assert len(results) == limit
            last_result = results[-1]
            next_page_token: Optional[str] = base64.b64encode(
                msgpack_dumps(
                    (
                        last_result.discovery_date,
                        last_result.fetcher.name,
                        last_result.fetcher.version,
                    )
                )
            ).decode()
        else:
            next_page_token = None

        return PagedResult(next_page_token=next_page_token, results=results,)

    def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> None:
        self.journal_writer.metadata_fetcher_add(fetchers)
        for fetcher in fetchers:
            self._cql_runner.metadata_fetcher_add(
                MetadataFetcherRow(
                    name=fetcher.name,
                    version=fetcher.version,
                    metadata=json.dumps(map_optional(dict, fetcher.metadata)),
                )
            )

    def metadata_fetcher_get(
        self, name: str, version: str
    ) -> Optional[MetadataFetcher]:
        fetcher = self._cql_runner.metadata_fetcher_get(name, version)
        if fetcher:
            return MetadataFetcher(
                name=fetcher.name,
                version=fetcher.version,
                metadata=json.loads(fetcher.metadata),
            )
        else:
            return None

    def metadata_authority_add(self, authorities: List[MetadataAuthority]) -> None:
        self.journal_writer.metadata_authority_add(authorities)
        for authority in authorities:
            self._cql_runner.metadata_authority_add(
                MetadataAuthorityRow(
                    url=authority.url,
                    type=authority.type.value,
                    metadata=json.dumps(map_optional(dict, authority.metadata)),
                )
            )

    def metadata_authority_get(
        self, type: MetadataAuthorityType, url: str
    ) -> Optional[MetadataAuthority]:
        authority = self._cql_runner.metadata_authority_get(type.value, url)
        if authority:
            return MetadataAuthority(
                type=MetadataAuthorityType(authority.type),
                url=authority.url,
                metadata=json.loads(authority.metadata),
            )
        else:
            return None

    def clear_buffers(self, object_types: Optional[List[str]] = None) -> None:
        """Do nothing"""
        return None

    def flush(self, object_types: Optional[List[str]] = None) -> Dict:
        return {}
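# A minimal paging sketch for raw_extrinsic_metadata_get (assuming a `storage`
# instance of this class and an `authority` that was previously added; the
# values are illustrative):
#
#     page = storage.raw_extrinsic_metadata_get(
#         MetadataTargetType.ORIGIN, "https://example.org/repo", authority
#     )
#     while page.next_page_token is not None:
#         page = storage.raw_extrinsic_metadata_get(
#             MetadataTargetType.ORIGIN,
#             "https://example.org/repo",
#             authority,
#             page_token=page.next_page_token,
#         )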