diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py index a1943011..253ba3cd 100644 --- a/swh/storage/cassandra/converters.py +++ b/swh/storage/cassandra/converters.py @@ -1,94 +1,96 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json import attr from copy import deepcopy from typing import Any, Dict, Tuple from cassandra.cluster import ResultSet from swh.model.model import ( ObjectType, OriginVisitStatus, Revision, RevisionType, Release, Sha1Git, ) from swh.model.hashutil import DEFAULT_ALGORITHMS from .common import Row def revision_to_db(revision: Revision) -> Dict[str, Any]: # we use a deepcopy of the dict because we do not want to recurse the # Model->dict conversion (to keep Timestamp & al. entities), BUT we do not # want to modify original metadata (embedded in the Model entity), so we # non-recursively convert it as a dict but make a deep copy. db_revision = deepcopy(attr.asdict(revision, recurse=False)) metadata = revision.metadata extra_headers = revision.extra_headers if not extra_headers and metadata and "extra_headers" in metadata: extra_headers = db_revision["metadata"].pop("extra_headers") - db_revision["metadata"] = json.dumps(db_revision["metadata"]) + db_revision["metadata"] = json.dumps( + dict(db_revision["metadata"] if db_revision["metadata"] is not None else None) + ) db_revision["extra_headers"] = extra_headers db_revision["type"] = db_revision["type"].value return db_revision def revision_from_db(db_revision: Row, parents: Tuple[Sha1Git]) -> Revision: revision = db_revision._asdict() # type: ignore metadata = json.loads(revision.pop("metadata", None)) extra_headers = revision.pop("extra_headers", ()) if not extra_headers and metadata and "extra_headers" in metadata: extra_headers = metadata.pop("extra_headers") if extra_headers is None: extra_headers = () return Revision( parents=parents, type=RevisionType(revision.pop("type")), metadata=metadata, extra_headers=extra_headers, **revision, ) def release_to_db(release: Release) -> Dict[str, Any]: db_release = attr.asdict(release, recurse=False) db_release["target_type"] = db_release["target_type"].value return db_release def release_from_db(db_release: Row) -> Release: release = db_release._asdict() # type: ignore return Release(target_type=ObjectType(release.pop("target_type")), **release,) def row_to_content_hashes(row: Row) -> Dict[str, bytes]: """Convert cassandra row to a content hashes """ hashes = {} for algo in DEFAULT_ALGORITHMS: hashes[algo] = getattr(row, algo) return hashes def row_to_visit_status(row: ResultSet) -> OriginVisitStatus: """Format a row representing a visit_status to an actual dict representing an OriginVisitStatus. """ return OriginVisitStatus.from_dict( { **row._asdict(), "origin": row.origin, "date": row.date.replace(tzinfo=datetime.timezone.utc), "metadata": (json.loads(row.metadata) if row.metadata else None), } ) diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py index 6a2997f8..66a5ce8c 100644 --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -1,1000 +1,1002 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import functools import json import logging import random from typing import ( Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union, ) from cassandra import CoordinationFailure from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy from cassandra.query import PreparedStatement, BoundStatement from tenacity import ( retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type, ) from swh.model.model import ( Sha1Git, TimestampWithTimezone, Timestamp, Person, Content, SkippedContent, OriginVisit, OriginVisitStatus, Origin, ) from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS from .. import extrinsic_metadata logger = logging.getLogger(__name__) _execution_profiles = { EXEC_PROFILE_DEFAULT: ExecutionProfile( load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()) ), } # Configuration for cassandra-driver's access to servers: # * hit the right server directly when sending a query (TokenAwarePolicy), # * if there's more than one, then pick one at random that's in the same # datacenter as the client (DCAwareRoundRobinPolicy) def create_keyspace( hosts: List[str], keyspace: str, port: int = 9042, *, durable_writes=True ): cluster = Cluster(hosts, port=port, execution_profiles=_execution_profiles) session = cluster.connect() extra_params = "" if not durable_writes: extra_params = "AND durable_writes = false" session.execute( """CREATE KEYSPACE IF NOT EXISTS "%s" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 } %s; """ % (keyspace, extra_params) ) session.execute('USE "%s"' % keyspace) for query in CREATE_TABLES_QUERIES: session.execute(query) T = TypeVar("T") def _prepared_statement(query: str) -> Callable[[Callable[..., T]], Callable[..., T]]: """Returns a decorator usable on methods of CqlRunner, to inject them with a 'statement' argument, that is a prepared statement corresponding to the query. This only works on methods of CqlRunner, as preparing a statement requires a connection to a Cassandra server.""" def decorator(f): @functools.wraps(f) def newf(self, *args, **kwargs) -> T: if f.__name__ not in self._prepared_statements: statement: PreparedStatement = self._session.prepare(query) self._prepared_statements[f.__name__] = statement return f( self, *args, **kwargs, statement=self._prepared_statements[f.__name__] ) return newf return decorator def _prepared_insert_statement(table_name: str, columns: List[str]): """Shorthand for using `_prepared_statement` for `INSERT INTO` statements.""" return _prepared_statement( "INSERT INTO %s (%s) VALUES (%s)" % (table_name, ", ".join(columns), ", ".join("?" for _ in columns),) ) def _prepared_exists_statement(table_name: str): """Shorthand for using `_prepared_statement` for queries that only check which ids in a list exist in the table.""" return _prepared_statement(f"SELECT id FROM {table_name} WHERE id IN ?") class CqlRunner: """Class managing prepared statements and building queries to be sent to Cassandra.""" def __init__(self, hosts: List[str], keyspace: str, port: int): self._cluster = Cluster( hosts, port=port, execution_profiles=_execution_profiles ) self._session = self._cluster.connect(keyspace) self._cluster.register_user_type( keyspace, "microtimestamp_with_timezone", TimestampWithTimezone ) self._cluster.register_user_type(keyspace, "microtimestamp", Timestamp) self._cluster.register_user_type(keyspace, "person", Person) self._prepared_statements: Dict[str, PreparedStatement] = {} ########################## # Common utility functions ########################## MAX_RETRIES = 3 @retry( wait=wait_random_exponential(multiplier=1, max=10), stop=stop_after_attempt(MAX_RETRIES), retry=retry_if_exception_type(CoordinationFailure), ) def _execute_with_retries(self, statement, args) -> ResultSet: return self._session.execute(statement, args, timeout=1000.0) @_prepared_statement( "UPDATE object_count SET count = count + ? " "WHERE partition_key = 0 AND object_type = ?" ) def _increment_counter( self, object_type: str, nb: int, *, statement: PreparedStatement ) -> None: self._execute_with_retries(statement, [nb, object_type]) def _add_one(self, statement, object_type: str, obj, keys: List[str]) -> None: self._increment_counter(object_type, 1) self._execute_with_retries(statement, [getattr(obj, key) for key in keys]) def _get_random_row(self, statement) -> Optional[Row]: """Takes a prepared statement of the form "SELECT * FROM WHERE token() > ? LIMIT 1" and uses it to return a random row""" token = random.randint(TOKEN_BEGIN, TOKEN_END) rows = self._execute_with_retries(statement, [token]) if not rows: # There are no row with a greater token; wrap around to get # the row with the smallest token rows = self._execute_with_retries(statement, [TOKEN_BEGIN]) if rows: return rows.one() else: return None def _missing(self, statement, ids): res = self._execute_with_retries(statement, [ids]) found_ids = {id_ for (id_,) in res} return [id_ for id_ in ids if id_ not in found_ids] ########################## # 'content' table ########################## _content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"] _content_keys = [ "sha1", "sha1_git", "sha256", "blake2s256", "length", "ctime", "status", ] def _content_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by content_add_prepare, to be called when the content row should be added to the primary table.""" self._execute_with_retries(statement, None) self._increment_counter("content", 1) @_prepared_insert_statement("content", _content_keys) def content_add_prepare( self, content, *, statement ) -> Tuple[int, Callable[[], None]]: """Prepares insertion of a Content to the main 'content' table. Returns a token (to be used in secondary tables), and a function to be called to perform the insertion in the main table.""" statement = statement.bind( [getattr(content, key) for key in self._content_keys] ) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) token_class = self._cluster.metadata.token_map.token_class # Token of the row when it will be inserted. This is equivalent to # "SELECT token({', '.join(self._content_pk)}) FROM content WHERE ..." # after the row is inserted; but we need the token to insert in the # index tables *before* inserting to the main 'content' table token = token_class.from_key(statement.routing_key).value assert TOKEN_BEGIN <= token <= TOKEN_END # Function to be called after the indexes contain their respective # row finalizer = functools.partial(self._content_add_finalize, statement) return (token, finalizer) @_prepared_statement( "SELECT * FROM content WHERE " + " AND ".join(map("%s = ?".__mod__, HASH_ALGORITHMS)) ) def content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement ) -> Optional[Row]: rows = list( self._execute_with_retries( statement, [content_hashes[algo] for algo in HASH_ALGORITHMS] ) ) assert len(rows) <= 1 if rows: return rows[0] else: return None @_prepared_statement( "SELECT * FROM content WHERE token(" + ", ".join(_content_pk) + ") = ?" ) def content_get_from_token(self, token, *, statement) -> Iterable[Row]: return self._execute_with_retries(statement, [token]) @_prepared_statement( "SELECT * FROM content WHERE token(%s) > ? LIMIT 1" % ", ".join(_content_pk) ) def content_get_random(self, *, statement) -> Optional[Row]: return self._get_random_row(statement) @_prepared_statement( ( "SELECT token({0}) AS tok, {1} FROM content " "WHERE token({0}) >= ? AND token({0}) <= ? LIMIT ?" ).format(", ".join(_content_pk), ", ".join(_content_keys)) ) def content_get_token_range( self, start: int, end: int, limit: int, *, statement ) -> Iterable[Row]: return self._execute_with_retries(statement, [start, end, limit]) ########################## # 'content_by_*' tables ########################## @_prepared_statement("SELECT sha1_git FROM content_by_sha1_git WHERE sha1_git IN ?") def content_missing_by_sha1_git( self, ids: List[bytes], *, statement ) -> List[bytes]: return self._missing(statement, ids) def content_index_add_one(self, algo: str, content: Content, token: int) -> None: """Adds a row mapping content[algo] to the token of the Content in the main 'content' table.""" query = ( f"INSERT INTO content_by_{algo} ({algo}, target_token) " f"VALUES (%s, %s)" ) self._execute_with_retries(query, [content.get_hash(algo), token]) def content_get_tokens_from_single_hash( self, algo: str, hash_: bytes ) -> Iterable[int]: assert algo in HASH_ALGORITHMS query = f"SELECT target_token FROM content_by_{algo} WHERE {algo} = %s" return (tok for (tok,) in self._execute_with_retries(query, [hash_])) ########################## # 'skipped_content' table ########################## _skipped_content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"] _skipped_content_keys = [ "sha1", "sha1_git", "sha256", "blake2s256", "length", "ctime", "status", "reason", "origin", ] _magic_null_pk = b"" """ NULLs (or all-empty blobs) are not allowed in primary keys; instead use a special value that can't possibly be a valid hash. """ def _skipped_content_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by skipped_content_add_prepare, to be called when the content row should be added to the primary table.""" self._execute_with_retries(statement, None) self._increment_counter("skipped_content", 1) @_prepared_insert_statement("skipped_content", _skipped_content_keys) def skipped_content_add_prepare( self, content, *, statement ) -> Tuple[int, Callable[[], None]]: """Prepares insertion of a Content to the main 'skipped_content' table. Returns a token (to be used in secondary tables), and a function to be called to perform the insertion in the main table.""" # Replace NULLs (which are not allowed in the partition key) with # an empty byte string content = content.to_dict() for key in self._skipped_content_pk: if content[key] is None: content[key] = self._magic_null_pk statement = statement.bind( [content.get(key) for key in self._skipped_content_keys] ) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) token_class = self._cluster.metadata.token_map.token_class # Token of the row when it will be inserted. This is equivalent to # "SELECT token({', '.join(self._content_pk)}) # FROM skipped_content WHERE ..." # after the row is inserted; but we need the token to insert in the # index tables *before* inserting to the main 'skipped_content' table token = token_class.from_key(statement.routing_key).value assert TOKEN_BEGIN <= token <= TOKEN_END # Function to be called after the indexes contain their respective # row finalizer = functools.partial(self._skipped_content_add_finalize, statement) return (token, finalizer) @_prepared_statement( "SELECT * FROM skipped_content WHERE " + " AND ".join(map("%s = ?".__mod__, HASH_ALGORITHMS)) ) def skipped_content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement ) -> Optional[Row]: rows = list( self._execute_with_retries( statement, [ content_hashes[algo] or self._magic_null_pk for algo in HASH_ALGORITHMS ], ) ) assert len(rows) <= 1 if rows: # TODO: convert _magic_null_pk back to None? return rows[0] else: return None ########################## # 'skipped_content_by_*' tables ########################## def skipped_content_index_add_one( self, algo: str, content: SkippedContent, token: int ) -> None: """Adds a row mapping content[algo] to the token of the SkippedContent in the main 'skipped_content' table.""" query = ( f"INSERT INTO skipped_content_by_{algo} ({algo}, target_token) " f"VALUES (%s, %s)" ) self._execute_with_retries( query, [content.get_hash(algo) or self._magic_null_pk, token] ) ########################## # 'revision' table ########################## _revision_keys = [ "id", "date", "committer_date", "type", "directory", "message", "author", "committer", "synthetic", "metadata", "extra_headers", ] @_prepared_exists_statement("revision") def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement("revision", _revision_keys) def revision_add_one(self, revision: Dict[str, Any], *, statement) -> None: self._execute_with_retries( statement, [revision[key] for key in self._revision_keys] ) self._increment_counter("revision", 1) @_prepared_statement("SELECT id FROM revision WHERE id IN ?") def revision_get_ids(self, revision_ids, *, statement) -> ResultSet: return self._execute_with_retries(statement, [revision_ids]) @_prepared_statement("SELECT * FROM revision WHERE id IN ?") def revision_get(self, revision_ids, *, statement) -> ResultSet: return self._execute_with_retries(statement, [revision_ids]) @_prepared_statement("SELECT * FROM revision WHERE token(id) > ? LIMIT 1") def revision_get_random(self, *, statement) -> Optional[Row]: return self._get_random_row(statement) ########################## # 'revision_parent' table ########################## _revision_parent_keys = ["id", "parent_rank", "parent_id"] @_prepared_insert_statement("revision_parent", _revision_parent_keys) def revision_parent_add_one( self, id_: Sha1Git, parent_rank: int, parent_id: Sha1Git, *, statement ) -> None: self._execute_with_retries(statement, [id_, parent_rank, parent_id]) @_prepared_statement("SELECT parent_id FROM revision_parent WHERE id = ?") def revision_parent_get(self, revision_id: Sha1Git, *, statement) -> ResultSet: return self._execute_with_retries(statement, [revision_id]) ########################## # 'release' table ########################## _release_keys = [ "id", "target", "target_type", "date", "name", "message", "author", "synthetic", ] @_prepared_exists_statement("release") def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement("release", _release_keys) def release_add_one(self, release: Dict[str, Any], *, statement) -> None: self._execute_with_retries( statement, [release[key] for key in self._release_keys] ) self._increment_counter("release", 1) @_prepared_statement("SELECT * FROM release WHERE id in ?") def release_get(self, release_ids: List[str], *, statement) -> None: return self._execute_with_retries(statement, [release_ids]) @_prepared_statement("SELECT * FROM release WHERE token(id) > ? LIMIT 1") def release_get_random(self, *, statement) -> Optional[Row]: return self._get_random_row(statement) ########################## # 'directory' table ########################## _directory_keys = ["id"] @_prepared_exists_statement("directory") def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement("directory", _directory_keys) def directory_add_one(self, directory_id: Sha1Git, *, statement) -> None: """Called after all calls to directory_entry_add_one, to commit/finalize the directory.""" self._execute_with_retries(statement, [directory_id]) self._increment_counter("directory", 1) @_prepared_statement("SELECT * FROM directory WHERE token(id) > ? LIMIT 1") def directory_get_random(self, *, statement) -> Optional[Row]: return self._get_random_row(statement) ########################## # 'directory_entry' table ########################## _directory_entry_keys = ["directory_id", "name", "type", "target", "perms"] @_prepared_insert_statement("directory_entry", _directory_entry_keys) def directory_entry_add_one(self, entry: Dict[str, Any], *, statement) -> None: self._execute_with_retries( statement, [entry[key] for key in self._directory_entry_keys] ) @_prepared_statement("SELECT * FROM directory_entry WHERE directory_id IN ?") def directory_entry_get(self, directory_ids, *, statement) -> ResultSet: return self._execute_with_retries(statement, [directory_ids]) ########################## # 'snapshot' table ########################## _snapshot_keys = ["id"] @_prepared_exists_statement("snapshot") def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) @_prepared_insert_statement("snapshot", _snapshot_keys) def snapshot_add_one(self, snapshot_id: Sha1Git, *, statement) -> None: self._execute_with_retries(statement, [snapshot_id]) self._increment_counter("snapshot", 1) @_prepared_statement("SELECT * FROM snapshot WHERE id = ?") def snapshot_get(self, snapshot_id: Sha1Git, *, statement) -> ResultSet: return self._execute_with_retries(statement, [snapshot_id]) @_prepared_statement("SELECT * FROM snapshot WHERE token(id) > ? LIMIT 1") def snapshot_get_random(self, *, statement) -> Optional[Row]: return self._get_random_row(statement) ########################## # 'snapshot_branch' table ########################## _snapshot_branch_keys = ["snapshot_id", "name", "target_type", "target"] @_prepared_insert_statement("snapshot_branch", _snapshot_branch_keys) def snapshot_branch_add_one(self, branch: Dict[str, Any], *, statement) -> None: self._execute_with_retries( statement, [branch[key] for key in self._snapshot_branch_keys] ) @_prepared_statement( "SELECT ascii_bins_count(target_type) AS counts " "FROM snapshot_branch " "WHERE snapshot_id = ? " ) def snapshot_count_branches(self, snapshot_id: Sha1Git, *, statement) -> ResultSet: return self._execute_with_retries(statement, [snapshot_id]) @_prepared_statement( "SELECT * FROM snapshot_branch WHERE snapshot_id = ? AND name >= ? LIMIT ?" ) def snapshot_branch_get( self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement ) -> None: return self._execute_with_retries(statement, [snapshot_id, from_, limit]) ########################## # 'origin' table ########################## origin_keys = ["sha1", "url", "type", "next_visit_id"] @_prepared_statement( "INSERT INTO origin (sha1, url, next_visit_id) " "VALUES (?, ?, 1) IF NOT EXISTS" ) def origin_add_one(self, origin: Origin, *, statement) -> None: self._execute_with_retries(statement, [hash_url(origin.url), origin.url]) self._increment_counter("origin", 1) @_prepared_statement("SELECT * FROM origin WHERE sha1 = ?") def origin_get_by_sha1(self, sha1: bytes, *, statement) -> ResultSet: return self._execute_with_retries(statement, [sha1]) def origin_get_by_url(self, url: str) -> ResultSet: return self.origin_get_by_sha1(hash_url(url)) @_prepared_statement( f'SELECT token(sha1) AS tok, {", ".join(origin_keys)} ' f"FROM origin WHERE token(sha1) >= ? LIMIT ?" ) def origin_list(self, start_token: int, limit: int, *, statement) -> ResultSet: return self._execute_with_retries(statement, [start_token, limit]) @_prepared_statement("SELECT * FROM origin") def origin_iter_all(self, *, statement) -> ResultSet: return self._execute_with_retries(statement, []) @_prepared_statement("SELECT next_visit_id FROM origin WHERE sha1 = ?") def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int: rows = list(self._execute_with_retries(statement, [origin_sha1])) assert len(rows) == 1 # TODO: error handling return rows[0].next_visit_id @_prepared_statement( "UPDATE origin SET next_visit_id=? WHERE sha1 = ? IF next_visit_id=?" ) def origin_generate_unique_visit_id(self, origin_url: str, *, statement) -> int: origin_sha1 = hash_url(origin_url) next_id = self._origin_get_next_visit_id(origin_sha1) while True: res = list( self._execute_with_retries( statement, [next_id + 1, origin_sha1, next_id] ) ) assert len(res) == 1 if res[0].applied: # No data race return next_id else: # Someone else updated it before we did, let's try again next_id = res[0].next_visit_id # TODO: abort after too many attempts return next_id ########################## # 'origin_visit' table ########################## _origin_visit_keys = [ "origin", "visit", "type", "date", ] @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? " "ORDER BY visit ASC" ) def _origin_visit_get_pagination_asc_no_limit( self, origin_url: str, last_visit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? " "ORDER BY visit ASC " "LIMIT ?" ) def _origin_visit_get_pagination_asc_limit( self, origin_url: str, last_visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit, limit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? AND visit < ? " "ORDER BY visit DESC" ) def _origin_visit_get_pagination_desc_no_limit( self, origin_url: str, last_visit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? AND visit < ? " "ORDER BY visit DESC " "LIMIT ?" ) def _origin_visit_get_pagination_desc_limit( self, origin_url: str, last_visit: int, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, last_visit, limit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit ASC LIMIT ?" ) def _origin_visit_get_no_pagination_asc_limit( self, origin_url: str, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, limit]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit ASC " ) def _origin_visit_get_no_pagination_asc_no_limit( self, origin_url: str, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit DESC" ) def _origin_visit_get_no_pagination_desc_no_limit( self, origin_url: str, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url]) @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit DESC LIMIT ?" ) def _origin_visit_get_no_pagination_desc_limit( self, origin_url: str, limit: int, *, statement ) -> ResultSet: return self._execute_with_retries(statement, [origin_url, limit]) def origin_visit_get( self, origin_url: str, last_visit: Optional[int], limit: Optional[int], order: str = "asc", ) -> ResultSet: order = order.lower() assert order in ["asc", "desc"] args: List[Any] = [origin_url] if last_visit is not None: page_name = "pagination" args.append(last_visit) else: page_name = "no_pagination" if limit is not None: limit_name = "limit" args.append(limit) else: limit_name = "no_limit" method_name = f"_origin_visit_get_{page_name}_{order}_{limit_name}" origin_visit_get_method = getattr(self, method_name) return origin_visit_get_method(*args) @_prepared_insert_statement("origin_visit", _origin_visit_keys) def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None: self._add_one(statement, "origin_visit", visit, self._origin_visit_keys) _origin_visit_status_keys = [ "origin", "visit", "date", "status", "snapshot", "metadata", ] @_prepared_insert_statement("origin_visit_status", _origin_visit_status_keys) def origin_visit_status_add_one( self, visit_update: OriginVisitStatus, *, statement ) -> None: assert self._origin_visit_status_keys[-1] == "metadata" keys = self._origin_visit_status_keys - metadata = json.dumps(visit_update.metadata) + metadata = json.dumps( + dict(visit_update.metadata) if visit_update.metadata is not None else None + ) self._execute_with_retries( statement, [getattr(visit_update, key) for key in keys[:-1]] + [metadata] ) def origin_visit_status_get_latest(self, origin: str, visit: int,) -> Optional[Row]: """Given an origin visit id, return its latest origin_visit_status """ rows = self.origin_visit_status_get(origin, visit) return rows[0] if rows else None @_prepared_statement( "SELECT * FROM origin_visit_status " "WHERE origin = ? AND visit = ? " "ORDER BY date DESC" ) def origin_visit_status_get( self, origin: str, visit: int, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, *, statement, ) -> List[Row]: """Return all origin visit statuses for a given visit """ return list(self._execute_with_retries(statement, [origin, visit])) @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? AND visit = ?") def origin_visit_get_one( self, origin_url: str, visit_id: int, *, statement ) -> Optional[Row]: # TODO: error handling rows = list(self._execute_with_retries(statement, [origin_url, visit_id])) if rows: return rows[0] else: return None @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ?") def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet: return self._execute_with_retries(statement, [origin_url]) @_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) >= ?") def _origin_visit_iter_from(self, min_token: int, *, statement) -> Iterator[Row]: yield from self._execute_with_retries(statement, [min_token]) @_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) < ?") def _origin_visit_iter_to(self, max_token: int, *, statement) -> Iterator[Row]: yield from self._execute_with_retries(statement, [max_token]) def origin_visit_iter(self, start_token: int) -> Iterator[Row]: """Returns all origin visits in order from this token, and wraps around the token space.""" yield from self._origin_visit_iter_from(start_token) yield from self._origin_visit_iter_to(start_token) ########################## # 'metadata_authority' table ########################## _metadata_authority_keys = ["url", "type", "metadata"] @_prepared_insert_statement("metadata_authority", _metadata_authority_keys) def metadata_authority_add(self, url, type, metadata, *, statement): return self._execute_with_retries(statement, [url, type, metadata]) @_prepared_statement("SELECT * from metadata_authority WHERE type = ? AND url = ?") def metadata_authority_get(self, type, url, *, statement) -> Optional[Row]: return next(iter(self._execute_with_retries(statement, [type, url])), None) ########################## # 'metadata_fetcher' table ########################## _metadata_fetcher_keys = ["name", "version", "metadata"] @_prepared_insert_statement("metadata_fetcher", _metadata_fetcher_keys) def metadata_fetcher_add(self, name, version, metadata, *, statement): return self._execute_with_retries(statement, [name, version, metadata]) @_prepared_statement( "SELECT * from metadata_fetcher WHERE name = ? AND version = ?" ) def metadata_fetcher_get(self, name, version, *, statement) -> Optional[Row]: return next(iter(self._execute_with_retries(statement, [name, version])), None) ######################### # 'object_metadata' table ######################### _object_metadata_keys = [ "type", "id", "authority_type", "authority_url", "discovery_date", "fetcher_name", "fetcher_version", "format", "metadata", "origin", "visit", "snapshot", "release", "revision", "path", "directory", ] @_prepared_statement( f"INSERT INTO object_metadata ({', '.join(_object_metadata_keys)}) " f"VALUES ({', '.join('?' for _ in _object_metadata_keys)})" ) def object_metadata_add( self, object_type: str, id: str, authority_type, authority_url, discovery_date, fetcher_name, fetcher_version, format, metadata, context: Dict[str, Union[str, bytes, int]], *, statement, ): params = [ object_type, id, authority_type, authority_url, discovery_date, fetcher_name, fetcher_version, format, metadata, ] params.extend( context.get(key) for key in extrinsic_metadata.CONTEXT_KEYS[object_type] ) return self._execute_with_retries(statement, params,) @_prepared_statement( "SELECT * from object_metadata " "WHERE id=? AND authority_url=? AND discovery_date>? AND authority_type=?" ) def object_metadata_get_after_date( self, id: str, authority_type: str, authority_url: str, after: datetime.datetime, *, statement, ): return self._execute_with_retries( statement, [id, authority_url, after, authority_type] ) @_prepared_statement( "SELECT * from object_metadata " "WHERE id=? AND authority_type=? AND authority_url=? " "AND (discovery_date, fetcher_name, fetcher_version) > (?, ?, ?)" ) def object_metadata_get_after_date_and_fetcher( self, id: str, authority_type: str, authority_url: str, after_date: datetime.datetime, after_fetcher_name: str, after_fetcher_version: str, *, statement, ): return self._execute_with_retries( statement, [ id, authority_type, authority_url, after_date, after_fetcher_name, after_fetcher_version, ], ) @_prepared_statement( "SELECT * from object_metadata " "WHERE id=? AND authority_url=? AND authority_type=?" ) def object_metadata_get( self, id: str, authority_type: str, authority_url: str, *, statement ) -> Iterable[Row]: return self._execute_with_retries( statement, [id, authority_url, authority_type] ) ########################## # Miscellaneous ########################## @_prepared_statement("SELECT uuid() FROM revision LIMIT 1;") def check_read(self, *, statement): self._execute_with_retries(statement, []) @_prepared_statement( "SELECT object_type, count FROM object_count WHERE partition_key=0" ) def stat_counters(self, *, statement) -> ResultSet: return self._execute_with_retries(statement, []) diff --git a/swh/storage/db.py b/swh/storage/db.py index 41e83863..3728ed92 100644 --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -1,1288 +1,1292 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import random import select from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from swh.core.db import BaseDb -from swh.core.db.db_utils import stored_procedure, jsonize +from swh.core.db.db_utils import stored_procedure, jsonize as _jsonize from swh.core.db.db_utils import execute_values_generator from swh.model.model import OriginVisit, OriginVisitStatus, SHA1_SIZE +def jsonize(d): + return _jsonize(dict(d) if d is not None else None) + + class Db(BaseDb): """Proxy to the SWH DB, with wrappers around stored procedures """ def mktemp_dir_entry(self, entry_type, cur=None): self._cursor(cur).execute( "SELECT swh_mktemp_dir_entry(%s)", (("directory_entry_%s" % entry_type),) ) @stored_procedure("swh_mktemp_revision") def mktemp_revision(self, cur=None): pass @stored_procedure("swh_mktemp_release") def mktemp_release(self, cur=None): pass @stored_procedure("swh_mktemp_snapshot_branch") def mktemp_snapshot_branch(self, cur=None): pass def register_listener(self, notify_queue, cur=None): """Register a listener for NOTIFY queue `notify_queue`""" self._cursor(cur).execute("LISTEN %s" % notify_queue) def listen_notifies(self, timeout): """Listen to notifications for `timeout` seconds""" if select.select([self.conn], [], [], timeout) == ([], [], []): return else: self.conn.poll() while self.conn.notifies: yield self.conn.notifies.pop(0) @stored_procedure("swh_content_add") def content_add_from_temp(self, cur=None): pass @stored_procedure("swh_directory_add") def directory_add_from_temp(self, cur=None): pass @stored_procedure("swh_skipped_content_add") def skipped_content_add_from_temp(self, cur=None): pass @stored_procedure("swh_revision_add") def revision_add_from_temp(self, cur=None): pass @stored_procedure("swh_release_add") def release_add_from_temp(self, cur=None): pass def content_update_from_temp(self, keys_to_update, cur=None): cur = self._cursor(cur) cur.execute( """select swh_content_update(ARRAY[%s] :: text[])""" % keys_to_update ) content_get_metadata_keys = [ "sha1", "sha1_git", "sha256", "blake2s256", "length", "status", ] content_add_keys = content_get_metadata_keys + ["ctime"] skipped_content_keys = [ "sha1", "sha1_git", "sha256", "blake2s256", "length", "reason", "status", "origin", ] def content_get_metadata_from_sha1s(self, sha1s, cur=None): cur = self._cursor(cur) yield from execute_values_generator( cur, """ select t.sha1, %s from (values %%s) as t (sha1) inner join content using (sha1) """ % ", ".join(self.content_get_metadata_keys[1:]), ((sha1,) for sha1 in sha1s), ) def content_get_range(self, start, end, limit=None, cur=None): """Retrieve contents within range [start, end]. """ cur = self._cursor(cur) query = """select %s from content where %%s <= sha1 and sha1 <= %%s order by sha1 limit %%s""" % ", ".join( self.content_get_metadata_keys ) cur.execute(query, (start, end, limit)) yield from cur content_hash_keys = ["sha1", "sha1_git", "sha256", "blake2s256"] def content_missing_from_list(self, contents, cur=None): cur = self._cursor(cur) keys = ", ".join(self.content_hash_keys) equality = " AND ".join( ("t.%s = c.%s" % (key, key)) for key in self.content_hash_keys ) yield from execute_values_generator( cur, """ SELECT %s FROM (VALUES %%s) as t(%s) WHERE NOT EXISTS ( SELECT 1 FROM content c WHERE %s ) """ % (keys, keys, equality), (tuple(c[key] for key in self.content_hash_keys) for c in contents), ) def content_missing_per_sha1(self, sha1s, cur=None): cur = self._cursor(cur) yield from execute_values_generator( cur, """ SELECT t.sha1 FROM (VALUES %s) AS t(sha1) WHERE NOT EXISTS ( SELECT 1 FROM content c WHERE c.sha1 = t.sha1 )""", ((sha1,) for sha1 in sha1s), ) def content_missing_per_sha1_git(self, contents, cur=None): cur = self._cursor(cur) yield from execute_values_generator( cur, """ SELECT t.sha1_git FROM (VALUES %s) AS t(sha1_git) WHERE NOT EXISTS ( SELECT 1 FROM content c WHERE c.sha1_git = t.sha1_git )""", ((sha1,) for sha1 in contents), ) def skipped_content_missing(self, contents, cur=None): if not contents: return [] cur = self._cursor(cur) query = """SELECT * FROM (VALUES %s) AS t (%s) WHERE not exists (SELECT 1 FROM skipped_content s WHERE s.sha1 is not distinct from t.sha1::sha1 and s.sha1_git is not distinct from t.sha1_git::sha1 and s.sha256 is not distinct from t.sha256::bytea);""" % ( (", ".join("%s" for _ in contents)), ", ".join(self.content_hash_keys), ) cur.execute( query, [tuple(cont[key] for key in self.content_hash_keys) for cont in contents], ) yield from cur def snapshot_exists(self, snapshot_id, cur=None): """Check whether a snapshot with the given id exists""" cur = self._cursor(cur) cur.execute("""SELECT 1 FROM snapshot where id=%s""", (snapshot_id,)) return bool(cur.fetchone()) def snapshot_missing_from_list(self, snapshots, cur=None): cur = self._cursor(cur) yield from execute_values_generator( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( SELECT 1 FROM snapshot d WHERE d.id = t.id ) """, ((id,) for id in snapshots), ) def snapshot_add(self, snapshot_id, cur=None): """Add a snapshot from the temporary table""" cur = self._cursor(cur) cur.execute("""SELECT swh_snapshot_add(%s)""", (snapshot_id,)) snapshot_count_cols = ["target_type", "count"] def snapshot_count_branches(self, snapshot_id, cur=None): cur = self._cursor(cur) query = """\ SELECT %s FROM swh_snapshot_count_branches(%%s) """ % ", ".join( self.snapshot_count_cols ) cur.execute(query, (snapshot_id,)) yield from cur snapshot_get_cols = ["snapshot_id", "name", "target", "target_type"] def snapshot_get_by_id( self, snapshot_id, branches_from=b"", branches_count=None, target_types=None, cur=None, ): cur = self._cursor(cur) query = """\ SELECT %s FROM swh_snapshot_get_by_id(%%s, %%s, %%s, %%s :: snapshot_target[]) """ % ", ".join( self.snapshot_get_cols ) cur.execute(query, (snapshot_id, branches_from, branches_count, target_types)) yield from cur def snapshot_get_by_origin_visit(self, origin_url, visit_id, cur=None): cur = self._cursor(cur) query = """\ SELECT ovs.snapshot FROM origin_visit ov INNER JOIN origin o ON o.id = ov.origin INNER JOIN origin_visit_status ovs ON ov.origin = ovs.origin AND ov.visit = ovs.visit WHERE o.url=%s AND ov.visit=%s ORDER BY ovs.date DESC LIMIT 1 """ cur.execute(query, (origin_url, visit_id)) ret = cur.fetchone() if ret: return ret[0] def snapshot_get_random(self, cur=None): return self._get_random_row_from_table("snapshot", ["id"], "id", cur) content_find_cols = [ "sha1", "sha1_git", "sha256", "blake2s256", "length", "ctime", "status", ] def content_find( self, sha1=None, sha1_git=None, sha256=None, blake2s256=None, cur=None ): """Find the content optionally on a combination of the following checksums sha1, sha1_git, sha256 or blake2s256. Args: sha1: sha1 content git_sha1: the sha1 computed `a la git` sha1 of the content sha256: sha256 content blake2s256: blake2s256 content Returns: The tuple (sha1, sha1_git, sha256, blake2s256) if found or None. """ cur = self._cursor(cur) checksum_dict = { "sha1": sha1, "sha1_git": sha1_git, "sha256": sha256, "blake2s256": blake2s256, } where_parts = [] args = [] # Adds only those keys which have value other than None for algorithm in checksum_dict: if checksum_dict[algorithm] is not None: args.append(checksum_dict[algorithm]) where_parts.append(algorithm + "= %s") query = " AND ".join(where_parts) cur.execute( """SELECT %s FROM content WHERE %s """ % (",".join(self.content_find_cols), query), args, ) content = cur.fetchall() return content def content_get_random(self, cur=None): return self._get_random_row_from_table("content", ["sha1_git"], "sha1_git", cur) def directory_missing_from_list(self, directories, cur=None): cur = self._cursor(cur) yield from execute_values_generator( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( SELECT 1 FROM directory d WHERE d.id = t.id ) """, ((id,) for id in directories), ) directory_ls_cols = [ "dir_id", "type", "target", "name", "perms", "status", "sha1", "sha1_git", "sha256", "length", ] def directory_walk_one(self, directory, cur=None): cur = self._cursor(cur) cols = ", ".join(self.directory_ls_cols) query = "SELECT %s FROM swh_directory_walk_one(%%s)" % cols cur.execute(query, (directory,)) yield from cur def directory_walk(self, directory, cur=None): cur = self._cursor(cur) cols = ", ".join(self.directory_ls_cols) query = "SELECT %s FROM swh_directory_walk(%%s)" % cols cur.execute(query, (directory,)) yield from cur def directory_entry_get_by_path(self, directory, paths, cur=None): """Retrieve a directory entry by path. """ cur = self._cursor(cur) cols = ", ".join(self.directory_ls_cols) query = "SELECT %s FROM swh_find_directory_entry_by_path(%%s, %%s)" % cols cur.execute(query, (directory, paths)) data = cur.fetchone() if set(data) == {None}: return None return data def directory_get_random(self, cur=None): return self._get_random_row_from_table("directory", ["id"], "id", cur) def revision_missing_from_list(self, revisions, cur=None): cur = self._cursor(cur) yield from execute_values_generator( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( SELECT 1 FROM revision r WHERE r.id = t.id ) """, ((id,) for id in revisions), ) revision_add_cols = [ "id", "date", "date_offset", "date_neg_utc_offset", "committer_date", "committer_date_offset", "committer_date_neg_utc_offset", "type", "directory", "message", "author_fullname", "author_name", "author_email", "committer_fullname", "committer_name", "committer_email", "metadata", "synthetic", "extra_headers", ] revision_get_cols = revision_add_cols + ["parents"] def origin_visit_add(self, origin, ts, type, cur=None): """Add a new origin_visit for origin origin at timestamp ts. Args: origin: origin concerned by the visit ts: the date of the visit type: type of loader for the visit Returns: The new visit index step for that origin """ cur = self._cursor(cur) self._cursor(cur).execute( "SELECT swh_origin_visit_add(%s, %s, %s)", (origin, ts, type) ) return cur.fetchone()[0] origin_visit_status_cols = [ "origin", "visit", "date", "status", "snapshot", "metadata", ] def origin_visit_status_add( self, visit_status: OriginVisitStatus, cur=None ) -> None: """Add new origin visit status """ assert self.origin_visit_status_cols[0] == "origin" assert self.origin_visit_status_cols[-1] == "metadata" cols = self.origin_visit_status_cols[1:-1] cur = self._cursor(cur) cur.execute( f"WITH origin_id as (select id from origin where url=%s) " f"INSERT INTO origin_visit_status " f"(origin, {', '.join(cols)}, metadata) " f"VALUES ((select id from origin_id), " f"{', '.join(['%s']*len(cols))}, %s) " f"ON CONFLICT (origin, visit, date) do nothing", [visit_status.origin] + [getattr(visit_status, key) for key in cols] + [jsonize(visit_status.metadata)], ) def origin_visit_add_with_id(self, origin_visit: OriginVisit, cur=None) -> None: """Insert origin visit when id are already set """ ov = origin_visit assert ov.visit is not None cur = self._cursor(cur) origin_visit_cols = ["origin", "visit", "date", "type"] query = """INSERT INTO origin_visit ({cols}) VALUES ((select id from origin where url=%s), {values}) ON CONFLICT (origin, visit) DO NOTHING""".format( cols=", ".join(origin_visit_cols), values=", ".join("%s" for col in origin_visit_cols[1:]), ) cur.execute(query, (ov.origin, ov.visit, ov.date, ov.type)) origin_visit_get_cols = [ "origin", "visit", "date", "type", "status", "metadata", "snapshot", ] origin_visit_select_cols = [ "o.url AS origin", "ov.visit", "ov.date", "ov.type AS type", "ovs.status", "ovs.metadata", "ovs.snapshot", ] origin_visit_status_select_cols = [ "o.url AS origin", "ovs.visit", "ovs.date", "ovs.status", "ovs.snapshot", "ovs.metadata", ] def _make_origin_visit_status( self, row: Optional[Tuple[Any]] ) -> Optional[Dict[str, Any]]: """Make an origin_visit_status dict out of a row """ if not row: return None return dict(zip(self.origin_visit_status_cols, row)) def origin_visit_status_get_latest( self, origin_url: str, visit: int, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, cur=None, ) -> Optional[Dict[str, Any]]: """Given an origin visit id, return its latest origin_visit_status """ cur = self._cursor(cur) query_parts = [ "SELECT %s" % ", ".join(self.origin_visit_status_select_cols), "FROM origin_visit_status ovs ", "INNER JOIN origin o ON o.id = ovs.origin", ] query_parts.append("WHERE o.url = %s") query_params: List[Any] = [origin_url] query_parts.append("AND ovs.visit = %s") query_params.append(visit) if require_snapshot: query_parts.append("AND ovs.snapshot is not null") if allowed_statuses: query_parts.append("AND ovs.status IN %s") query_params.append(tuple(allowed_statuses)) query_parts.append("ORDER BY ovs.date DESC LIMIT 1") query = "\n".join(query_parts) cur.execute(query, tuple(query_params)) row = cur.fetchone() return self._make_origin_visit_status(row) def origin_visit_get_all( self, origin_id, last_visit=None, order="asc", limit=None, cur=None ): """Retrieve all visits for origin with id origin_id. Args: origin_id: The occurrence's origin Yields: The visits for that origin """ cur = self._cursor(cur) assert order.lower() in ["asc", "desc"] query_parts = [ "SELECT DISTINCT ON (ov.visit) %s " % ", ".join(self.origin_visit_select_cols), "FROM origin_visit ov", "INNER JOIN origin o ON o.id = ov.origin", "INNER JOIN origin_visit_status ovs", "ON ov.origin = ovs.origin AND ov.visit = ovs.visit", ] query_parts.append("WHERE o.url = %s") query_params: List[Any] = [origin_id] if last_visit is not None: op_comparison = ">" if order == "asc" else "<" query_parts.append(f"and ov.visit {op_comparison} %s") query_params.append(last_visit) if order == "asc": query_parts.append("ORDER BY ov.visit ASC, ovs.date DESC") elif order == "desc": query_parts.append("ORDER BY ov.visit DESC, ovs.date DESC") else: assert False if limit is not None: query_parts.append("LIMIT %s") query_params.append(limit) query = "\n".join(query_parts) cur.execute(query, tuple(query_params)) yield from cur def origin_visit_get(self, origin_id, visit_id, cur=None): """Retrieve information on visit visit_id of origin origin_id. Args: origin_id: the origin concerned visit_id: The visit step for that origin Returns: The origin_visit information """ cur = self._cursor(cur) query = """\ SELECT %s FROM origin_visit ov INNER JOIN origin o ON o.id = ov.origin INNER JOIN origin_visit_status ovs ON ov.origin = ovs.origin AND ov.visit = ovs.visit WHERE o.url = %%s AND ov.visit = %%s ORDER BY ovs.date DESC LIMIT 1 """ % ( ", ".join(self.origin_visit_select_cols) ) cur.execute(query, (origin_id, visit_id)) r = cur.fetchall() if not r: return None return r[0] def origin_visit_find_by_date(self, origin, visit_date, cur=None): cur = self._cursor(cur) cur.execute( "SELECT * FROM swh_visit_find_by_date(%s, %s)", (origin, visit_date) ) rows = cur.fetchall() if rows: visit = dict(zip(self.origin_visit_get_cols, rows[0])) visit["origin"] = origin return visit def origin_visit_exists(self, origin_id, visit_id, cur=None): """Check whether an origin visit with the given ids exists""" cur = self._cursor(cur) query = "SELECT 1 FROM origin_visit where origin = %s AND visit = %s" cur.execute(query, (origin_id, visit_id)) return bool(cur.fetchone()) def origin_visit_get_latest( self, origin_id: str, type: Optional[str], allowed_statuses: Optional[Iterable[str]], require_snapshot: bool, cur=None, ): """Retrieve the most recent origin_visit of the given origin, with optional filters. Args: origin_id: the origin concerned type: Optional visit type to filter on allowed_statuses: the visit statuses allowed for the returned visit require_snapshot (bool): If True, only a visit with a known snapshot will be returned. Returns: The origin_visit information, or None if no visit matches. """ cur = self._cursor(cur) query_parts = [ "SELECT %s" % ", ".join(self.origin_visit_select_cols), "FROM origin_visit ov ", "INNER JOIN origin o ON o.id = ov.origin", "INNER JOIN origin_visit_status ovs ", "ON o.id = ovs.origin AND ov.visit = ovs.visit ", ] query_parts.append("WHERE o.url = %s") query_params: List[Any] = [origin_id] if type is not None: query_parts.append("AND ov.type = %s") query_params.append(type) if require_snapshot: query_parts.append("AND ovs.snapshot is not null") if allowed_statuses: query_parts.append("AND ovs.status IN %s") query_params.append(tuple(allowed_statuses)) query_parts.append( "ORDER BY ov.date DESC, ov.visit DESC, ovs.date DESC LIMIT 1" ) query = "\n".join(query_parts) cur.execute(query, tuple(query_params)) r = cur.fetchone() if not r: return None return r def origin_visit_get_random(self, type, cur=None): """Randomly select one origin visit that was full and in the last 3 months """ cur = self._cursor(cur) columns = ",".join(self.origin_visit_select_cols) query = f"""select {columns} from origin_visit ov inner join origin o on ov.origin=o.id inner join origin_visit_status ovs on ov.origin = ovs.origin and ov.visit = ovs.visit where ovs.status='full' and ov.type=%s and ov.date > now() - '3 months'::interval and random() < 0.1 limit 1 """ cur.execute(query, (type,)) return cur.fetchone() @staticmethod def mangle_query_key(key, main_table): if key == "id": return "t.id" if key == "parents": return """ ARRAY( SELECT rh.parent_id::bytea FROM revision_history rh WHERE rh.id = t.id ORDER BY rh.parent_rank )""" if "_" not in key: return "%s.%s" % (main_table, key) head, tail = key.split("_", 1) if head in ("author", "committer") and tail in ( "name", "email", "id", "fullname", ): return "%s.%s" % (head, tail) return "%s.%s" % (main_table, key) def revision_get_from_list(self, revisions, cur=None): cur = self._cursor(cur) query_keys = ", ".join( self.mangle_query_key(k, "revision") for k in self.revision_get_cols ) yield from execute_values_generator( cur, """ SELECT %s FROM (VALUES %%s) as t(sortkey, id) LEFT JOIN revision ON t.id = revision.id LEFT JOIN person author ON revision.author = author.id LEFT JOIN person committer ON revision.committer = committer.id ORDER BY sortkey """ % query_keys, ((sortkey, id) for sortkey, id in enumerate(revisions)), ) def revision_log(self, root_revisions, limit=None, cur=None): cur = self._cursor(cur) query = """SELECT %s FROM swh_revision_log(%%s, %%s) """ % ", ".join( self.revision_get_cols ) cur.execute(query, (root_revisions, limit)) yield from cur revision_shortlog_cols = ["id", "parents"] def revision_shortlog(self, root_revisions, limit=None, cur=None): cur = self._cursor(cur) query = """SELECT %s FROM swh_revision_list(%%s, %%s) """ % ", ".join( self.revision_shortlog_cols ) cur.execute(query, (root_revisions, limit)) yield from cur def revision_get_random(self, cur=None): return self._get_random_row_from_table("revision", ["id"], "id", cur) def release_missing_from_list(self, releases, cur=None): cur = self._cursor(cur) yield from execute_values_generator( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( SELECT 1 FROM release r WHERE r.id = t.id ) """, ((id,) for id in releases), ) object_find_by_sha1_git_cols = ["sha1_git", "type"] def object_find_by_sha1_git(self, ids, cur=None): cur = self._cursor(cur) yield from execute_values_generator( cur, """ WITH t (sha1_git) AS (VALUES %s), known_objects as (( select id as sha1_git, 'release'::object_type as type, object_id from release r where exists (select 1 from t where t.sha1_git = r.id) ) union all ( select id as sha1_git, 'revision'::object_type as type, object_id from revision r where exists (select 1 from t where t.sha1_git = r.id) ) union all ( select id as sha1_git, 'directory'::object_type as type, object_id from directory d where exists (select 1 from t where t.sha1_git = d.id) ) union all ( select sha1_git as sha1_git, 'content'::object_type as type, object_id from content c where exists (select 1 from t where t.sha1_git = c.sha1_git) )) select t.sha1_git as sha1_git, k.type from t left join known_objects k on t.sha1_git = k.sha1_git """, ((id,) for id in ids), ) def stat_counters(self, cur=None): cur = self._cursor(cur) cur.execute("SELECT * FROM swh_stat_counters()") yield from cur def origin_add(self, url, cur=None): """Insert a new origin and return the new identifier.""" insert = """INSERT INTO origin (url) values (%s) RETURNING url""" cur.execute(insert, (url,)) return cur.fetchone()[0] origin_cols = ["url"] def origin_get_by_url(self, origins, cur=None): """Retrieve origin `(type, url)` from urls if found.""" cur = self._cursor(cur) query = """SELECT %s FROM (VALUES %%s) as t(url) LEFT JOIN origin ON t.url = origin.url """ % ",".join( "origin." + col for col in self.origin_cols ) yield from execute_values_generator(cur, query, ((url,) for url in origins)) def origin_get_by_sha1(self, sha1s, cur=None): """Retrieve origin urls from sha1s if found.""" cur = self._cursor(cur) query = """SELECT %s FROM (VALUES %%s) as t(sha1) LEFT JOIN origin ON t.sha1 = digest(origin.url, 'sha1') """ % ",".join( "origin." + col for col in self.origin_cols ) yield from execute_values_generator(cur, query, ((sha1,) for sha1 in sha1s)) def origin_id_get_by_url(self, origins, cur=None): """Retrieve origin `(type, url)` from urls if found.""" cur = self._cursor(cur) query = """SELECT id FROM (VALUES %s) as t(url) LEFT JOIN origin ON t.url = origin.url """ for row in execute_values_generator(cur, query, ((url,) for url in origins)): yield row[0] origin_get_range_cols = ["id", "url"] def origin_get_range(self, origin_from=1, origin_count=100, cur=None): """Retrieve ``origin_count`` origins whose ids are greater or equal than ``origin_from``. Origins are sorted by id before retrieving them. Args: origin_from (int): the minimum id of origins to retrieve origin_count (int): the maximum number of origins to retrieve """ cur = self._cursor(cur) query = """SELECT %s FROM origin WHERE id >= %%s ORDER BY id LIMIT %%s """ % ",".join( self.origin_get_range_cols ) cur.execute(query, (origin_from, origin_count)) yield from cur def _origin_query( self, url_pattern, count=False, offset=0, limit=50, regexp=False, with_visit=False, cur=None, ): """ Method factorizing query creation for searching and counting origins. """ cur = self._cursor(cur) if count: origin_cols = "COUNT(*)" else: origin_cols = ",".join(self.origin_cols) query = """SELECT %s FROM origin o WHERE """ if with_visit: query += """ EXISTS ( SELECT 1 FROM origin_visit ov INNER JOIN origin_visit_status ovs ON ov.origin = ovs.origin AND ov.visit = ovs.visit INNER JOIN snapshot ON ovs.snapshot=snapshot.id WHERE ov.origin=o.id ) AND """ query += "url %s %%s " if not count: query += "ORDER BY id OFFSET %%s LIMIT %%s" if not regexp: query = query % (origin_cols, "ILIKE") query_params = ("%" + url_pattern + "%", offset, limit) else: query = query % (origin_cols, "~*") query_params = (url_pattern, offset, limit) if count: query_params = (query_params[0],) cur.execute(query, query_params) def origin_search( self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, cur=None ): """Search for origins whose urls contain a provided string pattern or match a provided regular expression. The search is performed in a case insensitive way. Args: url_pattern (str): the string pattern to search for in origin urls offset (int): number of found origins to skip before returning results limit (int): the maximum number of found origins to return regexp (bool): if True, consider the provided pattern as a regular expression and returns origins whose urls match it with_visit (bool): if True, filter out origins with no visit """ self._origin_query( url_pattern, offset=offset, limit=limit, regexp=regexp, with_visit=with_visit, cur=cur, ) yield from cur def origin_count(self, url_pattern, regexp=False, with_visit=False, cur=None): """Count origins whose urls contain a provided string pattern or match a provided regular expression. The pattern search in origin urls is performed in a case insensitive way. Args: url_pattern (str): the string pattern to search for in origin urls regexp (bool): if True, consider the provided pattern as a regular expression and returns origins whose urls match it with_visit (bool): if True, filter out origins with no visit """ self._origin_query( url_pattern, count=True, regexp=regexp, with_visit=with_visit, cur=cur ) return cur.fetchone()[0] release_add_cols = [ "id", "target", "target_type", "date", "date_offset", "date_neg_utc_offset", "name", "comment", "synthetic", "author_fullname", "author_name", "author_email", ] release_get_cols = release_add_cols def release_get_from_list(self, releases, cur=None): cur = self._cursor(cur) query_keys = ", ".join( self.mangle_query_key(k, "release") for k in self.release_get_cols ) yield from execute_values_generator( cur, """ SELECT %s FROM (VALUES %%s) as t(sortkey, id) LEFT JOIN release ON t.id = release.id LEFT JOIN person author ON release.author = author.id ORDER BY sortkey """ % query_keys, ((sortkey, id) for sortkey, id in enumerate(releases)), ) def release_get_random(self, cur=None): return self._get_random_row_from_table("release", ["id"], "id", cur) _object_metadata_context_cols = [ "origin", "visit", "snapshot", "release", "revision", "path", "directory", ] """The list of context columns for all artifact types.""" _object_metadata_insert_cols = [ "type", "id", "authority_id", "fetcher_id", "discovery_date", "format", "metadata", *_object_metadata_context_cols, ] """List of columns of the object_metadata table, used when writing metadata.""" _object_metadata_insert_query = f""" INSERT INTO object_metadata ({', '.join(_object_metadata_insert_cols)}) VALUES ({', '.join('%s' for _ in _object_metadata_insert_cols)}) ON CONFLICT (id, authority_id, discovery_date, fetcher_id) DO NOTHING """ object_metadata_get_cols = [ "id", "discovery_date", "metadata_authority.type", "metadata_authority.url", "metadata_fetcher.id", "metadata_fetcher.name", "metadata_fetcher.version", *_object_metadata_context_cols, "format", "metadata", ] """List of columns of the object_metadata, metadata_authority, and metadata_fetcher tables, used when reading object metadata.""" _object_metadata_select_query = f""" SELECT object_metadata.id AS id, {', '.join(object_metadata_get_cols[1:-1])}, object_metadata.metadata AS metadata FROM object_metadata INNER JOIN metadata_authority ON (metadata_authority.id=authority_id) INNER JOIN metadata_fetcher ON (metadata_fetcher.id=fetcher_id) WHERE object_metadata.id=%s AND authority_id=%s """ def object_metadata_add( self, object_type: str, id: str, context: Dict[str, Union[str, bytes, int]], discovery_date: datetime.datetime, authority_id: int, fetcher_id: int, format: str, metadata: bytes, cur, ): query = self._object_metadata_insert_query args: Dict[str, Any] = dict( type=object_type, id=id, authority_id=authority_id, fetcher_id=fetcher_id, discovery_date=discovery_date, format=format, metadata=metadata, ) for col in self._object_metadata_context_cols: args[col] = context.get(col) params = [args[col] for col in self._object_metadata_insert_cols] cur.execute(query, params) def object_metadata_get( self, object_type: str, id: str, authority_id: int, after_time: Optional[datetime.datetime], after_fetcher: Optional[int], limit: int, cur, ): query_parts = [self._object_metadata_select_query] args = [id, authority_id] if after_fetcher is not None: assert after_time query_parts.append("AND (discovery_date, fetcher_id) > (%s, %s)") args.extend([after_time, after_fetcher]) elif after_time is not None: query_parts.append("AND discovery_date > %s") args.append(after_time) query_parts.append("ORDER BY discovery_date, fetcher_id") if limit: query_parts.append("LIMIT %s") args.append(limit) cur.execute(" ".join(query_parts), args) yield from cur metadata_fetcher_cols = ["name", "version", "metadata"] def metadata_fetcher_add( self, name: str, version: str, metadata: bytes, cur=None ) -> None: cur = self._cursor(cur) cur.execute( "INSERT INTO metadata_fetcher (name, version, metadata) " "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING", (name, version, jsonize(metadata)), ) def metadata_fetcher_get(self, name: str, version: str, cur=None): cur = self._cursor(cur) cur.execute( f"SELECT {', '.join(self.metadata_fetcher_cols)} " f"FROM metadata_fetcher " f"WHERE name=%s AND version=%s", (name, version), ) return cur.fetchone() def metadata_fetcher_get_id( self, name: str, version: str, cur=None ) -> Optional[int]: cur = self._cursor(cur) cur.execute( "SELECT id FROM metadata_fetcher WHERE name=%s AND version=%s", (name, version), ) row = cur.fetchone() if row: return row[0] else: return None metadata_authority_cols = ["type", "url", "metadata"] def metadata_authority_add( self, type: str, url: str, metadata: bytes, cur=None ) -> None: cur = self._cursor(cur) cur.execute( "INSERT INTO metadata_authority (type, url, metadata) " "VALUES (%s, %s, %s) ON CONFLICT DO NOTHING", (type, url, jsonize(metadata)), ) def metadata_authority_get(self, type: str, url: str, cur=None): cur = self._cursor(cur) cur.execute( f"SELECT {', '.join(self.metadata_authority_cols)} " f"FROM metadata_authority " f"WHERE type=%s AND url=%s", (type, url), ) return cur.fetchone() def metadata_authority_get_id(self, type: str, url: str, cur=None) -> Optional[int]: cur = self._cursor(cur) cur.execute( "SELECT id FROM metadata_authority WHERE type=%s AND url=%s", (type, url) ) row = cur.fetchone() if row: return row[0] else: return None def _get_random_row_from_table(self, table_name, cols, id_col, cur=None): random_sha1 = bytes(random.randint(0, 255) for _ in range(SHA1_SIZE)) cur = self._cursor(cur) query = """ (SELECT {cols} FROM {table} WHERE {id_col} >= %s ORDER BY {id_col} LIMIT 1) UNION (SELECT {cols} FROM {table} WHERE {id_col} < %s ORDER BY {id_col} DESC LIMIT 1) LIMIT 1 """.format( cols=", ".join(cols), table=table_name, id_col=id_col ) cur.execute(query, (random_sha1, random_sha1)) row = cur.fetchone() if row: return row[0]