diff --git a/requirements.txt b/requirements.txt
index c712cdd2..619b1325 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,10 @@
click
flask
psycopg2
vcversioner
aiohttp
tenacity
cassandra-driver >= 3.19.0, != 3.21.0
deprecated
typing-extensions
+mypy_extensions
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
index 5f216304..e10e19c7 100644
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -1,1017 +1,1034 @@
# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import dataclasses
import datetime
import functools
import logging
import random
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
)
from cassandra import CoordinationFailure
from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
from cassandra.query import PreparedStatement, BoundStatement, dict_factory
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
retry_if_exception_type,
)
+from mypy_extensions import NamedArg
from swh.model.model import (
Content,
SkippedContent,
Sha1Git,
TimestampWithTimezone,
Timestamp,
Person,
)
from swh.storage.interface import ListOrder
from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys
from .model import (
BaseRow,
ContentRow,
DirectoryEntryRow,
DirectoryRow,
MetadataAuthorityRow,
MetadataFetcherRow,
ObjectCountRow,
OriginRow,
OriginVisitRow,
OriginVisitStatusRow,
RawExtrinsicMetadataRow,
ReleaseRow,
RevisionParentRow,
RevisionRow,
SkippedContentRow,
SnapshotBranchRow,
SnapshotRow,
)
from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS
logger = logging.getLogger(__name__)
_execution_profiles = {
EXEC_PROFILE_DEFAULT: ExecutionProfile(
load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
row_factory=dict_factory,
),
}
# Configuration for cassandra-driver's access to servers:
# * hit the right server directly when sending a query (TokenAwarePolicy),
# * if there's more than one, then pick one at random that's in the same
# datacenter as the client (DCAwareRoundRobinPolicy)
def create_keyspace(
hosts: List[str], keyspace: str, port: int = 9042, *, durable_writes=True
):
cluster = Cluster(hosts, port=port, execution_profiles=_execution_profiles)
session = cluster.connect()
extra_params = ""
if not durable_writes:
extra_params = "AND durable_writes = false"
session.execute(
"""CREATE KEYSPACE IF NOT EXISTS "%s"
WITH REPLICATION = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
} %s;
"""
% (keyspace, extra_params)
)
session.execute('USE "%s"' % keyspace)
for query in CREATE_TABLES_QUERIES:
session.execute(query)
-T = TypeVar("T")
+TRet = TypeVar("TRet")
-def _prepared_statement(query: str) -> Callable[[Callable[..., T]], Callable[..., T]]:
+def _prepared_statement(
+ query: str,
+) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]:
"""Returns a decorator usable on methods of CqlRunner, to
inject them with a 'statement' argument, that is a prepared
statement corresponding to the query.
This only works on methods of CqlRunner, as preparing a
statement requires a connection to a Cassandra server."""
def decorator(f):
@functools.wraps(f)
- def newf(self, *args, **kwargs) -> T:
+ def newf(self, *args, **kwargs) -> TRet:
if f.__name__ not in self._prepared_statements:
statement: PreparedStatement = self._session.prepare(query)
self._prepared_statements[f.__name__] = statement
return f(
self, *args, **kwargs, statement=self._prepared_statements[f.__name__]
)
return newf
return decorator
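# Illustrative sketch, not part of this module's API: how a CqlRunner method
# uses the decorator above. The prepared `statement` is injected as a keyword
# argument by `newf`, so callers never pass it themselves. `example_get_ids`
# is a hypothetical name used only for illustration.
#
#     @_prepared_statement("SELECT id FROM revision WHERE id IN ?")
#     def example_get_ids(self, ids, *, statement) -> Iterable[bytes]:
#         return (row["id"] for row in self._execute_with_retries(statement, [ids]))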
-def _prepared_insert_statement(table_name: str, columns: List[str]):
+TArg = TypeVar("TArg")
+TSelf = TypeVar("TSelf")
+
+
+def _prepared_insert_statement(
+ table_name: str, columns: List[str]
+) -> Callable[
+ [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa
+ Callable[[TSelf, TArg], TRet],
+]:
"""Shorthand for using `_prepared_statement` for `INSERT INTO`
statements."""
return _prepared_statement(
"INSERT INTO %s (%s) VALUES (%s)"
% (table_name, ", ".join(columns), ", ".join("?" for _ in columns),)
)
-def _prepared_exists_statement(table_name: str):
+def _prepared_exists_statement(
+ table_name: str,
+) -> Callable[
+ [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa
+ Callable[[TSelf, TArg], TRet],
+]:
"""Shorthand for using `_prepared_statement` for queries that only
check which ids in a list exist in the table."""
return _prepared_statement(f"SELECT id FROM {table_name} WHERE id IN ?")
class CqlRunner:
"""Class managing prepared statements and building queries to be sent
to Cassandra."""
def __init__(self, hosts: List[str], keyspace: str, port: int):
self._cluster = Cluster(
hosts, port=port, execution_profiles=_execution_profiles
)
self._session = self._cluster.connect(keyspace)
self._cluster.register_user_type(
keyspace, "microtimestamp_with_timezone", TimestampWithTimezone
)
self._cluster.register_user_type(keyspace, "microtimestamp", Timestamp)
self._cluster.register_user_type(keyspace, "person", Person)
self._prepared_statements: Dict[str, PreparedStatement] = {}
##########################
# Common utility functions
##########################
MAX_RETRIES = 3
@retry(
wait=wait_random_exponential(multiplier=1, max=10),
stop=stop_after_attempt(MAX_RETRIES),
retry=retry_if_exception_type(CoordinationFailure),
)
def _execute_with_retries(self, statement, args) -> ResultSet:
return self._session.execute(statement, args, timeout=1000.0)
@_prepared_statement(
"UPDATE object_count SET count = count + ? "
"WHERE partition_key = 0 AND object_type = ?"
)
def _increment_counter(
self, object_type: str, nb: int, *, statement: PreparedStatement
) -> None:
self._execute_with_retries(statement, [nb, object_type])
def _add_one(self, statement, object_type: Optional[str], obj: BaseRow) -> None:
if object_type:
self._increment_counter(object_type, 1)
self._execute_with_retries(statement, dataclasses.astuple(obj))
_T = TypeVar("_T", bound=BaseRow)
def _get_random_row(self, row_class: Type[_T], statement) -> Optional[_T]: # noqa
"""Takes a prepared statement of the form
"SELECT * FROM
WHERE token() > ? LIMIT 1"
and uses it to return a random row"""
token = random.randint(TOKEN_BEGIN, TOKEN_END)
rows = self._execute_with_retries(statement, [token])
if not rows:
# There is no row with a greater token; wrap around to get
# the row with the smallest token
rows = self._execute_with_retries(statement, [TOKEN_BEGIN])
if rows:
return row_class.from_dict(rows.one()) # type: ignore
else:
return None
def _missing(self, statement, ids):
rows = self._execute_with_retries(statement, [ids])
found_ids = {row["id"] for row in rows}
return [id_ for id_ in ids if id_ not in found_ids]
##########################
# 'content' table
##########################
_content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"]
def _content_add_finalize(self, statement: BoundStatement) -> None:
"""Returned currified by content_add_prepare, to be called when the
content row should be added to the primary table."""
self._execute_with_retries(statement, None)
self._increment_counter("content", 1)
@_prepared_insert_statement("content", ContentRow.cols())
def content_add_prepare(
self, content: ContentRow, *, statement
) -> Tuple[int, Callable[[], None]]:
"""Prepares insertion of a Content to the main 'content' table.
Returns a token (to be used in secondary tables), and a function to be
called to perform the insertion in the main table."""
statement = statement.bind(dataclasses.astuple(content))
# Type used for hashing keys (usually, it will be
# cassandra.metadata.Murmur3Token)
token_class = self._cluster.metadata.token_map.token_class
# Token of the row when it will be inserted. This is equivalent to
# "SELECT token({', '.join(self._content_pk)}) FROM content WHERE ..."
# after the row is inserted; but we need the token to insert in the
# index tables *before* inserting to the main 'content' table
token = token_class.from_key(statement.routing_key).value
assert TOKEN_BEGIN <= token <= TOKEN_END
# Function to be called after the indexes contain their respective
# row
finalizer = functools.partial(self._content_add_finalize, statement)
return (token, finalizer)
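# Illustrative sketch (hypothetical caller code; the real caller is
# CassandraStorage._content_add): the two-phase protocol returned above is
# meant to be used as
#
#     token, finalize = cql_runner.content_add_prepare(content_row)
#     for algo in HASH_ALGORITHMS:
#         cql_runner.content_index_add_one(algo, content, token)
#     finalize()  # only now is the row committed to the main 'content' table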
@_prepared_statement(
"SELECT * FROM content WHERE "
+ " AND ".join(map("%s = ?".__mod__, HASH_ALGORITHMS))
)
def content_get_from_pk(
self, content_hashes: Dict[str, bytes], *, statement
) -> Optional[ContentRow]:
rows = list(
self._execute_with_retries(
statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]
)
)
assert len(rows) <= 1
if rows:
return ContentRow(**rows[0])
else:
return None
@_prepared_statement(
"SELECT * FROM content WHERE token(" + ", ".join(_content_pk) + ") = ?"
)
def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]:
return map(ContentRow.from_dict, self._execute_with_retries(statement, [token]))
@_prepared_statement(
"SELECT * FROM content WHERE token(%s) > ? LIMIT 1" % ", ".join(_content_pk)
)
def content_get_random(self, *, statement) -> Optional[ContentRow]:
return self._get_random_row(ContentRow, statement)
@_prepared_statement(
(
"SELECT token({0}) AS tok, {1} FROM content "
"WHERE token({0}) >= ? AND token({0}) <= ? LIMIT ?"
).format(", ".join(_content_pk), ", ".join(ContentRow.cols()))
)
def content_get_token_range(
self, start: int, end: int, limit: int, *, statement
) -> Iterable[Tuple[int, ContentRow]]:
"""Returns an iterable of (token, row)"""
return (
(row["tok"], ContentRow.from_dict(remove_keys(row, ("tok",))))
for row in self._execute_with_retries(statement, [start, end, limit])
)
##########################
# 'content_by_*' tables
##########################
@_prepared_statement(
"SELECT sha1_git AS id FROM content_by_sha1_git WHERE sha1_git IN ?"
)
def content_missing_by_sha1_git(
self, ids: List[bytes], *, statement
) -> List[bytes]:
return self._missing(statement, ids)
def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
"""Adds a row mapping content[algo] to the token of the Content in
the main 'content' table."""
query = (
f"INSERT INTO content_by_{algo} ({algo}, target_token) " f"VALUES (%s, %s)"
)
self._execute_with_retries(query, [content.get_hash(algo), token])
def content_get_tokens_from_single_hash(
self, algo: str, hash_: bytes
) -> Iterable[int]:
assert algo in HASH_ALGORITHMS
query = f"SELECT target_token FROM content_by_{algo} WHERE {algo} = %s"
return (
row["target_token"] for row in self._execute_with_retries(query, [hash_])
)
##########################
# 'skipped_content' table
##########################
_skipped_content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"]
_magic_null_pk = b""
"""
NULLs (or all-empty blobs) are not allowed in primary keys; instead use a
special value that can't possibly be a valid hash.
"""
def _skipped_content_add_finalize(self, statement: BoundStatement) -> None:
"""Returned currified by skipped_content_add_prepare, to be called
when the content row should be added to the primary table."""
self._execute_with_retries(statement, None)
self._increment_counter("skipped_content", 1)
@_prepared_insert_statement("skipped_content", SkippedContentRow.cols())
def skipped_content_add_prepare(
self, content, *, statement
) -> Tuple[int, Callable[[], None]]:
"""Prepares insertion of a Content to the main 'skipped_content' table.
Returns a token (to be used in secondary tables), and a function to be
called to perform the insertion in the main table."""
# Replace NULLs (which are not allowed in the partition key) with
# an empty byte string
for key in self._skipped_content_pk:
if getattr(content, key) is None:
setattr(content, key, self._magic_null_pk)
statement = statement.bind(dataclasses.astuple(content))
# Type used for hashing keys (usually, it will be
# cassandra.metadata.Murmur3Token)
token_class = self._cluster.metadata.token_map.token_class
# Token of the row when it will be inserted. This is equivalent to
# "SELECT token({', '.join(self._content_pk)})
# FROM skipped_content WHERE ..."
# after the row is inserted; but we need the token to insert in the
# index tables *before* inserting to the main 'skipped_content' table
token = token_class.from_key(statement.routing_key).value
assert TOKEN_BEGIN <= token <= TOKEN_END
# Function to be called after the indexes contain their respective
# row
finalizer = functools.partial(self._skipped_content_add_finalize, statement)
return (token, finalizer)
@_prepared_statement(
"SELECT * FROM skipped_content WHERE "
+ " AND ".join(map("%s = ?".__mod__, HASH_ALGORITHMS))
)
def skipped_content_get_from_pk(
self, content_hashes: Dict[str, bytes], *, statement
) -> Optional[SkippedContentRow]:
rows = list(
self._execute_with_retries(
statement,
[
content_hashes[algo] or self._magic_null_pk
for algo in HASH_ALGORITHMS
],
)
)
assert len(rows) <= 1
if rows:
# TODO: convert _magic_null_pk back to None?
return SkippedContentRow.from_dict(rows[0])
else:
return None
##########################
# 'skipped_content_by_*' tables
##########################
def skipped_content_index_add_one(
self, algo: str, content: SkippedContent, token: int
) -> None:
"""Adds a row mapping content[algo] to the token of the SkippedContent
in the main 'skipped_content' table."""
query = (
f"INSERT INTO skipped_content_by_{algo} ({algo}, target_token) "
f"VALUES (%s, %s)"
)
self._execute_with_retries(
query, [content.get_hash(algo) or self._magic_null_pk, token]
)
##########################
# 'revision' table
##########################
@_prepared_exists_statement("revision")
def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
@_prepared_insert_statement("revision", RevisionRow.cols())
def revision_add_one(self, revision: RevisionRow, *, statement) -> None:
self._add_one(statement, "revision", revision)
@_prepared_statement("SELECT id FROM revision WHERE id IN ?")
def revision_get_ids(self, revision_ids, *, statement) -> Iterable[bytes]:
return (
row["id"] for row in self._execute_with_retries(statement, [revision_ids])
)
@_prepared_statement("SELECT * FROM revision WHERE id IN ?")
def revision_get(self, revision_ids, *, statement) -> Iterable[RevisionRow]:
return map(
RevisionRow.from_dict, self._execute_with_retries(statement, [revision_ids])
)
@_prepared_statement("SELECT * FROM revision WHERE token(id) > ? LIMIT 1")
def revision_get_random(self, *, statement) -> Optional[RevisionRow]:
return self._get_random_row(RevisionRow, statement)
##########################
# 'revision_parent' table
##########################
@_prepared_insert_statement("revision_parent", RevisionParentRow.cols())
def revision_parent_add_one(
self, revision_parent: RevisionParentRow, *, statement
) -> None:
self._add_one(statement, None, revision_parent)
@_prepared_statement("SELECT parent_id FROM revision_parent WHERE id = ?")
def revision_parent_get(
self, revision_id: Sha1Git, *, statement
) -> Iterable[bytes]:
return (
row["parent_id"]
for row in self._execute_with_retries(statement, [revision_id])
)
##########################
# 'release' table
##########################
@_prepared_exists_statement("release")
def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
@_prepared_insert_statement("release", ReleaseRow.cols())
def release_add_one(self, release: ReleaseRow, *, statement) -> None:
self._add_one(statement, "release", release)
@_prepared_statement("SELECT * FROM release WHERE id in ?")
def release_get(self, release_ids: List[str], *, statement) -> Iterable[ReleaseRow]:
return map(
ReleaseRow.from_dict, self._execute_with_retries(statement, [release_ids])
)
@_prepared_statement("SELECT * FROM release WHERE token(id) > ? LIMIT 1")
def release_get_random(self, *, statement) -> Optional[ReleaseRow]:
return self._get_random_row(ReleaseRow, statement)
##########################
# 'directory' table
##########################
@_prepared_exists_statement("directory")
def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
@_prepared_insert_statement("directory", DirectoryRow.cols())
def directory_add_one(self, directory: DirectoryRow, *, statement) -> None:
"""Called after all calls to directory_entry_add_one, to
commit/finalize the directory."""
self._add_one(statement, "directory", directory)
@_prepared_statement("SELECT * FROM directory WHERE token(id) > ? LIMIT 1")
def directory_get_random(self, *, statement) -> Optional[DirectoryRow]:
return self._get_random_row(DirectoryRow, statement)
##########################
# 'directory_entry' table
##########################
@_prepared_insert_statement("directory_entry", DirectoryEntryRow.cols())
def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None:
self._add_one(statement, None, entry)
@_prepared_statement("SELECT * FROM directory_entry WHERE directory_id IN ?")
def directory_entry_get(
self, directory_ids, *, statement
) -> Iterable[DirectoryEntryRow]:
return map(
DirectoryEntryRow.from_dict,
self._execute_with_retries(statement, [directory_ids]),
)
##########################
# 'snapshot' table
##########################
@_prepared_exists_statement("snapshot")
def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
@_prepared_insert_statement("snapshot", SnapshotRow.cols())
def snapshot_add_one(self, snapshot: SnapshotRow, *, statement) -> None:
self._add_one(statement, "snapshot", snapshot)
@_prepared_statement("SELECT * FROM snapshot WHERE id = ?")
def snapshot_get(self, snapshot_id: Sha1Git, *, statement) -> Iterable[SnapshotRow]:
return map(
SnapshotRow.from_dict, self._execute_with_retries(statement, [snapshot_id])
)
@_prepared_statement("SELECT * FROM snapshot WHERE token(id) > ? LIMIT 1")
def snapshot_get_random(self, *, statement) -> Optional[SnapshotRow]:
return self._get_random_row(SnapshotRow, statement)
##########################
# 'snapshot_branch' table
##########################
@_prepared_insert_statement("snapshot_branch", SnapshotBranchRow.cols())
def snapshot_branch_add_one(self, branch: SnapshotBranchRow, *, statement) -> None:
self._add_one(statement, None, branch)
@_prepared_statement(
"SELECT ascii_bins_count(target_type) AS counts "
"FROM snapshot_branch "
"WHERE snapshot_id = ? "
)
def snapshot_count_branches(
self, snapshot_id: Sha1Git, *, statement
) -> Dict[Optional[str], int]:
"""Returns a dictionary from type names to the number of branches
of that type."""
row = self._execute_with_retries(statement, [snapshot_id]).one()
(nb_none, counts) = row["counts"]
return {None: nb_none, **counts}
@_prepared_statement(
"SELECT * FROM snapshot_branch WHERE snapshot_id = ? AND name >= ? LIMIT ?"
)
def snapshot_branch_get(
self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement
) -> Iterable[SnapshotBranchRow]:
return map(
SnapshotBranchRow.from_dict,
self._execute_with_retries(statement, [snapshot_id, from_, limit]),
)
##########################
# 'origin' table
##########################
@_prepared_insert_statement("origin", OriginRow.cols())
def origin_add_one(self, origin: OriginRow, *, statement) -> None:
self._add_one(statement, "origin", origin)
@_prepared_statement("SELECT * FROM origin WHERE sha1 = ?")
def origin_get_by_sha1(self, sha1: bytes, *, statement) -> Iterable[OriginRow]:
return map(OriginRow.from_dict, self._execute_with_retries(statement, [sha1]))
def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
return self.origin_get_by_sha1(hash_url(url))
@_prepared_statement(
f'SELECT token(sha1) AS tok, {", ".join(OriginRow.cols())} '
f"FROM origin WHERE token(sha1) >= ? LIMIT ?"
)
def origin_list(
self, start_token: int, limit: int, *, statement
) -> Iterable[Tuple[int, OriginRow]]:
"""Returns an iterable of (token, origin)"""
return (
(row["tok"], OriginRow.from_dict(remove_keys(row, ("tok",))))
for row in self._execute_with_retries(statement, [start_token, limit])
)
@_prepared_statement("SELECT * FROM origin")
def origin_iter_all(self, *, statement) -> Iterable[OriginRow]:
return map(OriginRow.from_dict, self._execute_with_retries(statement, []))
@_prepared_statement("SELECT next_visit_id FROM origin WHERE sha1 = ?")
def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int:
rows = list(self._execute_with_retries(statement, [origin_sha1]))
assert len(rows) == 1 # TODO: error handling
return rows[0]["next_visit_id"]
@_prepared_statement(
"UPDATE origin SET next_visit_id=? WHERE sha1 = ? IF next_visit_id=?"
)
def origin_generate_unique_visit_id(self, origin_url: str, *, statement) -> int:
origin_sha1 = hash_url(origin_url)
next_id = self._origin_get_next_visit_id(origin_sha1)
while True:
res = list(
self._execute_with_retries(
statement, [next_id + 1, origin_sha1, next_id]
)
)
assert len(res) == 1
if res[0]["[applied]"]:
# No data race
return next_id
else:
# Someone else updated it before we did, let's try again
next_id = res[0]["next_visit_id"]
# TODO: abort after too many attempts
return next_id
##########################
# 'origin_visit' table
##########################
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
"ORDER BY visit ASC"
)
def _origin_visit_get_pagination_asc_no_limit(
self, origin_url: str, last_visit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, last_visit])
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
"ORDER BY visit ASC "
"LIMIT ?"
)
def _origin_visit_get_pagination_asc_limit(
self, origin_url: str, last_visit: int, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, last_visit, limit])
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? AND visit < ? "
"ORDER BY visit DESC"
)
def _origin_visit_get_pagination_desc_no_limit(
self, origin_url: str, last_visit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, last_visit])
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? AND visit < ? "
"ORDER BY visit DESC "
"LIMIT ?"
)
def _origin_visit_get_pagination_desc_limit(
self, origin_url: str, last_visit: int, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, last_visit, limit])
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit ASC LIMIT ?"
)
def _origin_visit_get_no_pagination_asc_limit(
self, origin_url: str, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, limit])
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit ASC "
)
def _origin_visit_get_no_pagination_asc_no_limit(
self, origin_url: str, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url])
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit DESC"
)
def _origin_visit_get_no_pagination_desc_no_limit(
self, origin_url: str, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url])
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit DESC LIMIT ?"
)
def _origin_visit_get_no_pagination_desc_limit(
self, origin_url: str, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, limit])
def origin_visit_get(
self,
origin_url: str,
last_visit: Optional[int],
limit: Optional[int],
order: ListOrder,
) -> Iterable[OriginVisitRow]:
args: List[Any] = [origin_url]
if last_visit is not None:
page_name = "pagination"
args.append(last_visit)
else:
page_name = "no_pagination"
if limit is not None:
limit_name = "limit"
args.append(limit)
else:
limit_name = "no_limit"
method_name = f"_origin_visit_get_{page_name}_{order.value}_{limit_name}"
origin_visit_get_method = getattr(self, method_name)
return map(OriginVisitRow.from_dict, origin_visit_get_method(*args))
@_prepared_statement(
"SELECT * FROM origin_visit_status WHERE origin = ? "
"AND visit = ? AND date >= ? "
"ORDER BY date ASC "
"LIMIT ?"
)
def _origin_visit_status_get_with_date_asc_limit(
self,
origin: str,
visit: int,
date_from: datetime.datetime,
limit: int,
*,
statement,
) -> ResultSet:
return self._execute_with_retries(statement, [origin, visit, date_from, limit])
@_prepared_statement(
"SELECT * FROM origin_visit_status WHERE origin = ? "
"AND visit = ? AND date <= ? "
"ORDER BY visit DESC "
"LIMIT ?"
)
def _origin_visit_status_get_with_date_desc_limit(
self,
origin: str,
visit: int,
date_from: datetime.datetime,
limit: int,
*,
statement,
) -> ResultSet:
return self._execute_with_retries(statement, [origin, visit, date_from, limit])
@_prepared_statement(
"SELECT * FROM origin_visit_status WHERE origin = ? AND visit = ? "
"ORDER BY visit ASC "
"LIMIT ?"
)
def _origin_visit_status_get_with_no_date_asc_limit(
self, origin: str, visit: int, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin, visit, limit])
@_prepared_statement(
"SELECT * FROM origin_visit_status WHERE origin = ? AND visit = ? "
"ORDER BY visit DESC "
"LIMIT ?"
)
def _origin_visit_status_get_with_no_date_desc_limit(
self, origin: str, visit: int, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin, visit, limit])
def origin_visit_status_get_range(
self,
origin: str,
visit: int,
date_from: Optional[datetime.datetime],
limit: int,
order: ListOrder,
) -> Iterable[OriginVisitStatusRow]:
args: List[Any] = [origin, visit]
if date_from is not None:
date_name = "date"
args.append(date_from)
else:
date_name = "no_date"
args.append(limit)
method_name = f"_origin_visit_status_get_with_{date_name}_{order.value}_limit"
origin_visit_status_get_method = getattr(self, method_name)
return map(
OriginVisitStatusRow.from_dict, origin_visit_status_get_method(*args)
)
@_prepared_insert_statement("origin_visit", OriginVisitRow.cols())
def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None:
self._add_one(statement, "origin_visit", visit)
@_prepared_insert_statement("origin_visit_status", OriginVisitStatusRow.cols())
def origin_visit_status_add_one(
self, visit_update: OriginVisitStatusRow, *, statement
) -> None:
self._add_one(statement, None, visit_update)
def origin_visit_status_get_latest(
self, origin: str, visit: int,
) -> Optional[OriginVisitStatusRow]:
"""Given an origin visit id, return its latest origin_visit_status
"""
return next(self.origin_visit_status_get(origin, visit), None)
@_prepared_statement(
"SELECT * FROM origin_visit_status "
"WHERE origin = ? AND visit = ? "
"ORDER BY date DESC"
)
def origin_visit_status_get(
self,
origin: str,
visit: int,
allowed_statuses: Optional[List[str]] = None,
require_snapshot: bool = False,
*,
statement,
) -> Iterator[OriginVisitStatusRow]:
"""Return all origin visit statuses for a given visit
"""
return map(
OriginVisitStatusRow.from_dict,
self._execute_with_retries(statement, [origin, visit]),
)
@_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? AND visit = ?")
def origin_visit_get_one(
self, origin_url: str, visit_id: int, *, statement
) -> Optional[OriginVisitRow]:
# TODO: error handling
rows = list(self._execute_with_retries(statement, [origin_url, visit_id]))
if rows:
return OriginVisitRow.from_dict(rows[0])
else:
return None
@_prepared_statement("SELECT * FROM origin_visit WHERE origin = ?")
def origin_visit_get_all(
self, origin_url: str, *, statement
) -> Iterable[OriginVisitRow]:
return map(
OriginVisitRow.from_dict,
self._execute_with_retries(statement, [origin_url]),
)
@_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) >= ?")
def _origin_visit_iter_from(
self, min_token: int, *, statement
) -> Iterable[OriginVisitRow]:
return map(
OriginVisitRow.from_dict, self._execute_with_retries(statement, [min_token])
)
@_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) < ?")
def _origin_visit_iter_to(
self, max_token: int, *, statement
) -> Iterable[OriginVisitRow]:
return map(
OriginVisitRow.from_dict, self._execute_with_retries(statement, [max_token])
)
def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]:
"""Returns all origin visits in order from this token,
and wraps around the token space."""
yield from self._origin_visit_iter_from(start_token)
yield from self._origin_visit_iter_to(start_token)
##########################
# 'metadata_authority' table
##########################
@_prepared_insert_statement("metadata_authority", MetadataAuthorityRow.cols())
def metadata_authority_add(self, authority: MetadataAuthorityRow, *, statement):
self._add_one(statement, None, authority)
@_prepared_statement("SELECT * from metadata_authority WHERE type = ? AND url = ?")
def metadata_authority_get(
self, type, url, *, statement
) -> Optional[MetadataAuthorityRow]:
rows = list(self._execute_with_retries(statement, [type, url]))
if rows:
return MetadataAuthorityRow.from_dict(rows[0])
else:
return None
##########################
# 'metadata_fetcher' table
##########################
@_prepared_insert_statement("metadata_fetcher", MetadataFetcherRow.cols())
def metadata_fetcher_add(self, fetcher, *, statement):
self._add_one(statement, None, fetcher)
@_prepared_statement(
"SELECT * from metadata_fetcher WHERE name = ? AND version = ?"
)
def metadata_fetcher_get(
self, name, version, *, statement
) -> Optional[MetadataFetcherRow]:
rows = list(self._execute_with_retries(statement, [name, version]))
if rows:
return MetadataFetcherRow.from_dict(rows[0])
else:
return None
#########################
# 'raw_extrinsic_metadata' table
#########################
@_prepared_insert_statement(
"raw_extrinsic_metadata", RawExtrinsicMetadataRow.cols()
)
def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata, *, statement):
self._add_one(statement, None, raw_extrinsic_metadata)
@_prepared_statement(
"SELECT * from raw_extrinsic_metadata "
"WHERE id=? AND authority_url=? AND discovery_date>? AND authority_type=?"
)
def raw_extrinsic_metadata_get_after_date(
self,
id: str,
authority_type: str,
authority_url: str,
after: datetime.datetime,
*,
statement,
) -> Iterable[RawExtrinsicMetadataRow]:
return map(
RawExtrinsicMetadataRow.from_dict,
self._execute_with_retries(
statement, [id, authority_url, after, authority_type]
),
)
@_prepared_statement(
"SELECT * from raw_extrinsic_metadata "
"WHERE id=? AND authority_type=? AND authority_url=? "
"AND (discovery_date, fetcher_name, fetcher_version) > (?, ?, ?)"
)
def raw_extrinsic_metadata_get_after_date_and_fetcher(
self,
id: str,
authority_type: str,
authority_url: str,
after_date: datetime.datetime,
after_fetcher_name: str,
after_fetcher_version: str,
*,
statement,
) -> Iterable[RawExtrinsicMetadataRow]:
return map(
RawExtrinsicMetadataRow.from_dict,
self._execute_with_retries(
statement,
[
id,
authority_type,
authority_url,
after_date,
after_fetcher_name,
after_fetcher_version,
],
),
)
@_prepared_statement(
"SELECT * from raw_extrinsic_metadata "
"WHERE id=? AND authority_url=? AND authority_type=?"
)
def raw_extrinsic_metadata_get(
self, id: str, authority_type: str, authority_url: str, *, statement
) -> Iterable[RawExtrinsicMetadataRow]:
return map(
RawExtrinsicMetadataRow.from_dict,
self._execute_with_retries(statement, [id, authority_url, authority_type]),
)
##########################
# Miscellaneous
##########################
@_prepared_statement("SELECT uuid() FROM revision LIMIT 1;")
def check_read(self, *, statement):
self._execute_with_retries(statement, [])
@_prepared_statement("SELECT * FROM object_count WHERE partition_key=0")
def stat_counters(self, *, statement) -> Iterable[ObjectCountRow]:
return map(ObjectCountRow.from_dict, self._execute_with_retries(statement, []))
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
index f074a62b..ce2ab783 100644
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -1,1299 +1,1311 @@
# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import base64
import datetime
import itertools
import json
import random
import re
-from typing import Any, Dict, List, Iterable, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple, Union
import attr
from swh.core.api.serializers import msgpack_loads, msgpack_dumps
from swh.model.identifiers import parse_swhid, SWHID
from swh.model.hashutil import DEFAULT_ALGORITHMS
from swh.model.model import (
Revision,
Release,
Directory,
DirectoryEntry,
Content,
SkippedContent,
OriginVisit,
OriginVisitStatus,
Snapshot,
SnapshotBranch,
TargetType,
Origin,
MetadataAuthority,
MetadataAuthorityType,
MetadataFetcher,
MetadataTargetType,
RawExtrinsicMetadata,
Sha1Git,
)
from swh.storage.interface import (
ListOrder,
PagedResult,
PartialBranches,
Sha1,
VISIT_STATUSES,
)
from swh.storage.objstorage import ObjStorage
from swh.storage.writer import JournalWriter
from swh.storage.utils import map_optional, now
from ..exc import StorageArgumentException, HashCollision
from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys
from . import converters
from .cql import CqlRunner
from .schema import HASH_ALGORITHMS
from .model import (
ContentRow,
DirectoryEntryRow,
DirectoryRow,
MetadataAuthorityRow,
MetadataFetcherRow,
OriginRow,
OriginVisitRow,
RawExtrinsicMetadataRow,
RevisionParentRow,
SkippedContentRow,
SnapshotBranchRow,
SnapshotRow,
)
# Max block size of contents to return
BULK_BLOCK_CONTENT_LEN_MAX = 10000
class CassandraStorage:
def __init__(self, hosts, keyspace, objstorage, port=9042, journal_writer=None):
- self._cql_runner = CqlRunner(hosts, keyspace, port)
- self.journal_writer = JournalWriter(journal_writer)
- self.objstorage = ObjStorage(objstorage)
+ self._cql_runner: CqlRunner = CqlRunner(hosts, keyspace, port)
+ self.journal_writer: JournalWriter = JournalWriter(journal_writer)
+ self.objstorage: ObjStorage = ObjStorage(objstorage)
def check_config(self, *, check_write: bool) -> bool:
self._cql_runner.check_read()
return True
def _content_get_from_hash(self, algo, hash_) -> Iterable:
"""From the name of a hash algorithm and a value of that hash,
looks up the "hash -> token" secondary table (content_by_{algo})
to get tokens.
Then, looks up the main table (content) to get all contents with
that token, and filters out contents whose hash doesn't match."""
found_tokens = self._cql_runner.content_get_tokens_from_single_hash(algo, hash_)
for token in found_tokens:
+ assert isinstance(token, int), found_tokens
# Query the main table ('content').
res = self._cql_runner.content_get_from_token(token)
for row in res:
# re-check the hash (in case of a murmur3 collision)
if getattr(row, algo) == hash_:
yield row
def _content_add(self, contents: List[Content], with_data: bool) -> Dict:
# Filter-out content already in the database.
contents = [
c for c in contents if not self._cql_runner.content_get_from_pk(c.to_dict())
]
self.journal_writer.content_add(contents)
if with_data:
# First insert to the objstorage, if the endpoint is
# `content_add` (as opposed to `content_add_metadata`).
# TODO: this should probably be done concurrently with inserting
# in index tables (but still before the main table, so an entry is
# only added to the main table after everything else was
# successfully inserted).
summary = self.objstorage.content_add(
c for c in contents if c.status != "absent"
)
content_add_bytes = summary["content:add:bytes"]
content_add = 0
for content in contents:
content_add += 1
# Check for sha1 or sha1_git collisions. This test is not atomic
# with the insertion, so it won't detect a collision if both
# contents are inserted at the same time, but it's good enough.
#
# The proper way to do it would probably be a BATCH, but this
# would be inefficient because of the number of partitions we
# need to affect (len(HASH_ALGORITHMS)+1, which is currently 5)
for algo in {"sha1", "sha1_git"}:
collisions = []
# Get tokens of 'content' rows with the same value for
# sha1/sha1_git
rows = self._content_get_from_hash(algo, content.get_hash(algo))
for row in rows:
if getattr(row, algo) != content.get_hash(algo):
# collision of token(partition key), ignore this
# row
continue
for algo in HASH_ALGORITHMS:
if getattr(row, algo) != content.get_hash(algo):
# This hash didn't match; discard the row.
collisions.append(
{algo: getattr(row, algo) for algo in HASH_ALGORITHMS}
)
if collisions:
collisions.append(content.hashes())
raise HashCollision(algo, content.get_hash(algo), collisions)
(token, insertion_finalizer) = self._cql_runner.content_add_prepare(
ContentRow(**remove_keys(content.to_dict(), ("data",)))
)
# Then add to index tables
for algo in HASH_ALGORITHMS:
self._cql_runner.content_index_add_one(algo, content, token)
# Then to the main table
insertion_finalizer()
summary = {
"content:add": content_add,
}
if with_data:
summary["content:add:bytes"] = content_add_bytes
return summary
def content_add(self, content: List[Content]) -> Dict:
contents = [attr.evolve(c, ctime=now()) for c in content]
return self._content_add(list(contents), with_data=True)
def content_update(
self, contents: List[Dict[str, Any]], keys: List[str] = []
) -> None:
raise NotImplementedError(
"content_update is not supported by the Cassandra backend"
)
def content_add_metadata(self, content: List[Content]) -> Dict:
return self._content_add(content, with_data=False)
def content_get_data(self, content: Sha1) -> Optional[bytes]:
# FIXME: Make this method support slicing the `data`
return self.objstorage.content_get(content)
def content_get_partition(
self,
partition_id: int,
nb_partitions: int,
page_token: Optional[str] = None,
limit: int = 1000,
) -> PagedResult[Content]:
if limit is None:
raise StorageArgumentException("limit should not be None")
# Compute start and end of the range of tokens covered by the
# requested partition
partition_size = (TOKEN_END - TOKEN_BEGIN) // nb_partitions
range_start = TOKEN_BEGIN + partition_id * partition_size
range_end = TOKEN_BEGIN + (partition_id + 1) * partition_size
# offset the range start according to the `page_token`.
if page_token is not None:
if not (range_start <= int(page_token) <= range_end):
raise StorageArgumentException("Invalid page_token.")
range_start = int(page_token)
next_page_token: Optional[str] = None
rows = self._cql_runner.content_get_token_range(
range_start, range_end, limit + 1
)
contents = []
for counter, (tok, row) in enumerate(rows):
if row.status == "absent":
continue
row_d = row.to_dict()
if counter >= limit:
next_page_token = str(tok)
break
contents.append(Content(**row_d))
assert len(contents) <= limit
return PagedResult(results=contents, next_page_token=next_page_token)
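# Illustrative sketch (hypothetical client code, not part of this class):
# walking one partition of the content table with the page_token mechanism
# implemented above. `storage` and `process` are assumed names.
#
#     page_token = None
#     while True:
#         page = storage.content_get_partition(0, 16, page_token=page_token)
#         for content in page.results:
#             process(content)  # hypothetical consumer
#         if page.next_page_token is None:
#             break
#         page_token = page.next_page_token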
def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]:
contents_by_sha1: Dict[Sha1, Optional[Content]] = {}
for sha1 in contents:
# Get all (sha1, sha1_git, sha256, blake2s256) whose sha1
# matches the argument, from the index table ('content_by_sha1')
for row in self._content_get_from_hash("sha1", sha1):
row_d = row.to_dict()
row_d.pop("ctime")
content = Content(**row_d)
contents_by_sha1[content.sha1] = content
return [contents_by_sha1.get(sha1) for sha1 in contents]
def content_find(self, content: Dict[str, Any]) -> List[Content]:
# Find an algorithm that is common to all the requested contents.
# It will be used to do an initial filtering efficiently.
filter_algos = list(set(content).intersection(HASH_ALGORITHMS))
if not filter_algos:
raise StorageArgumentException(
"content keys must contain at least one "
f"of: {', '.join(sorted(HASH_ALGORITHMS))}"
)
common_algo = filter_algos[0]
results = []
rows = self._content_get_from_hash(common_algo, content[common_algo])
for row in rows:
# Re-check all the hashes, in case of collisions (either of the
# hash of the partition key, or the hashes in it)
for algo in HASH_ALGORITHMS:
if content.get(algo) and getattr(row, algo) != content[algo]:
# This hash didn't match; discard the row.
break
else:
# All hashes match, keep this row.
row_d = row.to_dict()
row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc)
results.append(Content(**row_d))
return results
def content_missing(
self, contents: List[Dict[str, Any]], key_hash: str = "sha1"
) -> Iterable[bytes]:
if key_hash not in DEFAULT_ALGORITHMS:
raise StorageArgumentException(
"key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}"
)
for content in contents:
res = self.content_find(content)
if not res:
yield content[key_hash]
def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]:
return self.content_missing([{"sha1": c for c in contents}])
def content_missing_per_sha1_git(
self, contents: List[Sha1Git]
) -> Iterable[Sha1Git]:
return self.content_missing(
[{"sha1_git": c for c in contents}], key_hash="sha1_git"
)
def content_get_random(self) -> Sha1Git:
- return self._cql_runner.content_get_random().sha1_git
+ content = self._cql_runner.content_get_random()
+ assert content, "Could not find any content"
+ return content.sha1_git
def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict:
# Filter-out content already in the database.
contents = [
c
for c in contents
if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())
]
self.journal_writer.skipped_content_add(contents)
for content in contents:
# Compute token of the row in the main table
(token, insertion_finalizer) = self._cql_runner.skipped_content_add_prepare(
SkippedContentRow.from_dict({"origin": None, **content.to_dict()})
)
# Then add to index tables
for algo in HASH_ALGORITHMS:
self._cql_runner.skipped_content_index_add_one(algo, content, token)
# Then to the main table
insertion_finalizer()
return {"skipped_content:add": len(contents)}
def skipped_content_add(self, content: List[SkippedContent]) -> Dict:
contents = [attr.evolve(c, ctime=now()) for c in content]
return self._skipped_content_add(contents)
def skipped_content_missing(
self, contents: List[Dict[str, Any]]
) -> Iterable[Dict[str, Any]]:
for content in contents:
if not self._cql_runner.skipped_content_get_from_pk(content):
yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS}
def directory_add(self, directories: List[Directory]) -> Dict:
# Filter out directories that are already inserted.
missing = self.directory_missing([dir_.id for dir_ in directories])
directories = [dir_ for dir_ in directories if dir_.id in missing]
self.journal_writer.directory_add(directories)
for directory in directories:
# Add directory entries to the 'directory_entry' table
for entry in directory.entries:
self._cql_runner.directory_entry_add_one(
DirectoryEntryRow(directory_id=directory.id, **entry.to_dict())
)
# Add the directory *after* adding all the entries, so someone
# listing the directory in the meantime won't end up
# with half the entries.
self._cql_runner.directory_add_one(DirectoryRow(id=directory.id))
return {"directory:add": len(directories)}
def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]:
return self._cql_runner.directory_missing(directories)
def _join_dentry_to_content(self, dentry: DirectoryEntry) -> Dict[str, Any]:
keys = (
"status",
"sha1",
"sha1_git",
"sha256",
"length",
)
ret = dict.fromkeys(keys)
ret.update(dentry.to_dict())
if ret["type"] == "file":
contents = self.content_find({"sha1_git": ret["target"]})
if contents:
content = contents[0]
for key in keys:
ret[key] = getattr(content, key)
return ret
def _directory_ls(
self, directory_id: Sha1Git, recursive: bool, prefix: bytes = b""
) -> Iterable[Dict[str, Any]]:
if self.directory_missing([directory_id]):
return
rows = list(self._cql_runner.directory_entry_get([directory_id]))
for row in rows:
entry_d = row.to_dict()
# Build and yield the directory entry dict
del entry_d["directory_id"]
entry = DirectoryEntry.from_dict(entry_d)
ret = self._join_dentry_to_content(entry)
ret["name"] = prefix + ret["name"]
ret["dir_id"] = directory_id
yield ret
if recursive and ret["type"] == "dir":
yield from self._directory_ls(
ret["target"], True, prefix + ret["name"] + b"/"
)
def directory_entry_get_by_path(
self, directory: Sha1Git, paths: List[bytes]
) -> Optional[Dict[str, Any]]:
return self._directory_entry_get_by_path(directory, paths, b"")
def _directory_entry_get_by_path(
self, directory: Sha1Git, paths: List[bytes], prefix: bytes
) -> Optional[Dict[str, Any]]:
if not paths:
return None
contents = list(self.directory_ls(directory))
if not contents:
return None
def _get_entry(entries, name):
"""Finds the entry with the requested name, prepends the
prefix (to get its full path), and returns it.
If no entry has that name, returns None."""
for entry in entries:
if entry["name"] == name:
entry = entry.copy()
entry["name"] = prefix + entry["name"]
return entry
first_item = _get_entry(contents, paths[0])
if len(paths) == 1:
return first_item
if not first_item or first_item["type"] != "dir":
return None
return self._directory_entry_get_by_path(
first_item["target"], paths[1:], prefix + paths[0] + b"/"
)
def directory_ls(
self, directory: Sha1Git, recursive: bool = False
) -> Iterable[Dict[str, Any]]:
yield from self._directory_ls(directory, recursive)
def directory_get_random(self) -> Sha1Git:
- return self._cql_runner.directory_get_random().id
+ directory = self._cql_runner.directory_get_random()
+ assert directory, "Could not find any directory"
+ return directory.id
def revision_add(self, revisions: List[Revision]) -> Dict:
# Filter-out revisions already in the database
missing = self.revision_missing([rev.id for rev in revisions])
revisions = [rev for rev in revisions if rev.id in missing]
self.journal_writer.revision_add(revisions)
for revision in revisions:
revobject = converters.revision_to_db(revision)
if revobject:
# Add parents first
for (rank, parent) in enumerate(revision.parents):
self._cql_runner.revision_parent_add_one(
RevisionParentRow(
id=revobject.id, parent_rank=rank, parent_id=parent
)
)
# Then write the main revision row.
# Writing this after all parents were written ensures that
# read endpoints don't return a partial view while writing
# the parents
self._cql_runner.revision_add_one(revobject)
return {"revision:add": len(revisions)}
def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]:
return self._cql_runner.revision_missing(revisions)
def revision_get(
self, revisions: List[Sha1Git]
) -> Iterable[Optional[Dict[str, Any]]]:
rows = self._cql_runner.revision_get(revisions)
revs = {}
for row in rows:
# TODO: use a single query to get all parents?
# (it might have lower latency, but requires more code and more
# bandwidth, because revision id would be part of each returned
# row)
parents = tuple(self._cql_runner.revision_parent_get(row.id))
# parent_rank is the clustering key, so results are already
# sorted by rank.
rev = converters.revision_from_db(row, parents=parents)
revs[rev.id] = rev.to_dict()
for rev_id in revisions:
yield revs.get(rev_id)
def _get_parent_revs(
self,
rev_ids: Iterable[Sha1Git],
seen: Set[Sha1Git],
limit: Optional[int],
short: bool,
) -> Union[
Iterable[Dict[str, Any]], Iterable[Tuple[Sha1Git, Tuple[Sha1Git, ...]]],
]:
if limit and len(seen) >= limit:
return
rev_ids = [id_ for id_ in rev_ids if id_ not in seen]
if not rev_ids:
return
seen |= set(rev_ids)
# We need this query, even if short=True, to return consistent
# results (ie. not return only a subset of a revision's parents
# if it is being written)
if short:
ids = self._cql_runner.revision_get_ids(rev_ids)
for id_ in ids:
# TODO: use a single query to get all parents?
# (it might have lower latency, but requires more code and more
# bandwidth, because revision id would be part of each returned
# row)
parents = tuple(self._cql_runner.revision_parent_get(id_))
# parent_rank is the clustering key, so results are already
# sorted by rank.
yield (id_, parents)
yield from self._get_parent_revs(parents, seen, limit, short)
else:
rows = self._cql_runner.revision_get(rev_ids)
for row in rows:
# TODO: use a single query to get all parents?
# (it might have lower latency, but requires more code and more
# bandwidth, because revision id would be part of each returned
# row)
parents = tuple(self._cql_runner.revision_parent_get(row.id))
# parent_rank is the clustering key, so results are already
# sorted by rank.
rev = converters.revision_from_db(row, parents=parents)
yield rev.to_dict()
yield from self._get_parent_revs(parents, seen, limit, short)
def revision_log(
self, revisions: List[Sha1Git], limit: Optional[int] = None
) -> Iterable[Optional[Dict[str, Any]]]:
seen: Set[Sha1Git] = set()
yield from self._get_parent_revs(revisions, seen, limit, False)
def revision_shortlog(
self, revisions: List[Sha1Git], limit: Optional[int] = None
) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]:
seen: Set[Sha1Git] = set()
yield from self._get_parent_revs(revisions, seen, limit, True)
def revision_get_random(self) -> Sha1Git:
- return self._cql_runner.revision_get_random().id
+ revision = self._cql_runner.revision_get_random()
+ assert revision, "Could not find any revision"
+ return revision.id
def release_add(self, releases: List[Release]) -> Dict:
to_add = []
for rel in releases:
if rel not in to_add:
to_add.append(rel)
missing = set(self.release_missing([rel.id for rel in to_add]))
to_add = [rel for rel in to_add if rel.id in missing]
self.journal_writer.release_add(to_add)
for release in to_add:
if release:
self._cql_runner.release_add_one(converters.release_to_db(release))
return {"release:add": len(to_add)}
def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]:
return self._cql_runner.release_missing(releases)
def release_get(
self, releases: List[Sha1Git]
) -> Iterable[Optional[Dict[str, Any]]]:
rows = self._cql_runner.release_get(releases)
rels = {}
for row in rows:
release = converters.release_from_db(row)
rels[row.id] = release.to_dict()
for rel_id in releases:
yield rels.get(rel_id)
def release_get_random(self) -> Sha1Git:
- return self._cql_runner.release_get_random().id
+ release = self._cql_runner.release_get_random()
+ assert release, "Could not find any release"
+ return release.id
def snapshot_add(self, snapshots: List[Snapshot]) -> Dict:
missing = self._cql_runner.snapshot_missing([snp.id for snp in snapshots])
snapshots = [snp for snp in snapshots if snp.id in missing]
for snapshot in snapshots:
self.journal_writer.snapshot_add([snapshot])
# Add branches
for (branch_name, branch) in snapshot.branches.items():
if branch is None:
target_type: Optional[str] = None
target: Optional[bytes] = None
else:
target_type = branch.target_type.value
target = branch.target
self._cql_runner.snapshot_branch_add_one(
SnapshotBranchRow(
snapshot_id=snapshot.id,
name=branch_name,
target_type=target_type,
target=target,
)
)
# Add the snapshot *after* adding all the branches, so someone
# calling snapshot_get_branch in the meantime won't end up
# with half the branches.
self._cql_runner.snapshot_add_one(SnapshotRow(id=snapshot.id))
return {"snapshot:add": len(snapshots)}
def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]:
return self._cql_runner.snapshot_missing(snapshots)
def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]:
d = self.snapshot_get_branches(snapshot_id)
if d is None:
return None
return {
"id": d["id"],
"branches": {
name: branch.to_dict() if branch else None
for (name, branch) in d["branches"].items()
},
"next_branch": d["next_branch"],
}
def snapshot_get_by_origin_visit(
self, origin: str, visit: int
) -> Optional[Dict[str, Any]]:
visit_status = self.origin_visit_status_get_latest(
origin, visit, require_snapshot=True
)
if visit_status and visit_status.snapshot:
return self.snapshot_get(visit_status.snapshot)
return None
def snapshot_count_branches(
self, snapshot_id: Sha1Git
) -> Optional[Dict[Optional[str], int]]:
if self._cql_runner.snapshot_missing([snapshot_id]):
# Makes sure we don't fetch branches for a snapshot that is
# being added.
return None
return self._cql_runner.snapshot_count_branches(snapshot_id)
def snapshot_get_branches(
self,
snapshot_id: Sha1Git,
branches_from: bytes = b"",
branches_count: int = 1000,
target_types: Optional[List[str]] = None,
) -> Optional[PartialBranches]:
if self._cql_runner.snapshot_missing([snapshot_id]):
# Makes sure we don't fetch branches for a snapshot that is
# being added.
return None
branches: List = []
while len(branches) < branches_count + 1:
new_branches = list(
self._cql_runner.snapshot_branch_get(
snapshot_id, branches_from, branches_count + 1
)
)
if not new_branches:
break
branches_from = new_branches[-1].name
new_branches_filtered = new_branches
# Filter by target_type
if target_types:
new_branches_filtered = [
branch
for branch in new_branches_filtered
if branch.target is not None and branch.target_type in target_types
]
branches.extend(new_branches_filtered)
if len(new_branches) < branches_count + 1:
break
if len(branches) > branches_count:
last_branch = branches.pop(-1).name
else:
last_branch = None
return PartialBranches(
id=snapshot_id,
branches={
branch.name: None
if branch.target is None
else SnapshotBranch(
target=branch.target, target_type=TargetType(branch.target_type)
)
for branch in branches
},
next_branch=last_branch,
)
def snapshot_get_random(self) -> Sha1Git:
- return self._cql_runner.snapshot_get_random().id
+ snapshot = self._cql_runner.snapshot_get_random()
+ assert snapshot, "Could not find any snapshot"
+ return snapshot.id
def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]:
results: Dict[Sha1Git, List[Dict]] = {id_: [] for id_ in ids}
missing_ids = set(ids)
# Mind the order, revision is the most likely one for a given ID,
# so we check revisions first.
- queries = [
+ queries: List[Tuple[str, Callable[[List[Sha1Git]], List[Sha1Git]]]] = [
("revision", self._cql_runner.revision_missing),
("release", self._cql_runner.release_missing),
("content", self._cql_runner.content_missing_by_sha1_git),
("directory", self._cql_runner.directory_missing),
]
for (object_type, query_fn) in queries:
- found_ids = missing_ids - set(query_fn(missing_ids))
+ found_ids = missing_ids - set(query_fn(list(missing_ids)))
for sha1_git in found_ids:
results[sha1_git].append(
{"sha1_git": sha1_git, "type": object_type,}
)
missing_ids.remove(sha1_git)
if not missing_ids:
# We found everything, skipping the next queries.
break
return results
def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]:
return [self.origin_get_one(origin) for origin in origins]
def origin_get_one(self, origin_url: str) -> Optional[Origin]:
"""Given an origin url, return the origin if it exists, None otherwise
"""
rows = list(self._cql_runner.origin_get_by_url(origin_url))
if rows:
assert len(rows) == 1
return Origin(url=rows[0].url)
else:
return None
def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
results = []
for sha1 in sha1s:
rows = list(self._cql_runner.origin_get_by_sha1(sha1))
origin = {"url": rows[0].url} if rows else None
results.append(origin)
return results
def origin_list(
self, page_token: Optional[str] = None, limit: int = 100
) -> PagedResult[Origin]:
# Compute what token to begin the listing from
start_token = TOKEN_BEGIN
if page_token:
start_token = int(page_token)
if not (TOKEN_BEGIN <= start_token <= TOKEN_END):
raise StorageArgumentException("Invalid page_token.")
next_page_token = None
origins = []
# Take one more origin so we can reuse it as the next page token if any
for (tok, row) in self._cql_runner.origin_list(start_token, limit + 1):
origins.append(Origin(url=row.url))
# keep reference of the last id for pagination purposes
last_id = tok
if len(origins) > limit:
# last origin id is the next page token
next_page_token = str(last_id)
# excluding that origin from the result to respect the limit size
origins = origins[:limit]
assert len(origins) <= limit
return PagedResult(results=origins, next_page_token=next_page_token)
def origin_search(
self,
url_pattern: str,
page_token: Optional[str] = None,
limit: int = 50,
regexp: bool = False,
with_visit: bool = False,
) -> PagedResult[Origin]:
# TODO: remove this endpoint, swh-search should be used instead.
next_page_token = None
offset = int(page_token) if page_token else 0
origin_rows = [row for row in self._cql_runner.origin_iter_all()]
if regexp:
pat = re.compile(url_pattern)
origin_rows = [row for row in origin_rows if pat.search(row.url)]
else:
origin_rows = [row for row in origin_rows if url_pattern in row.url]
if with_visit:
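            # next_visit_id starts at 1 when an origin is created (see
            # origin_add below), so a value above 1 means at least one visit
            # was allocated for this origin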
origin_rows = [row for row in origin_rows if row.next_visit_id > 1]
origins = [Origin(url=row.url) for row in origin_rows]
origins = origins[offset : offset + limit + 1]
if len(origins) > limit:
# next offset
next_page_token = str(offset + limit)
            # exclude that origin from the results so the page stays within the limit
origins = origins[:limit]
assert len(origins) <= limit
return PagedResult(results=origins, next_page_token=next_page_token)
def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None]
self.journal_writer.origin_add(to_add)
for origin in to_add:
self._cql_runner.origin_add_one(
OriginRow(sha1=hash_url(origin.url), url=origin.url, next_visit_id=1)
)
return {"origin:add": len(to_add)}
def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
for visit in visits:
origin = self.origin_get_one(visit.origin)
if not origin: # Cannot add a visit without an origin
                raise StorageArgumentException(f"Unknown origin {visit.origin}")
all_visits = []
nb_visits = 0
for visit in visits:
nb_visits += 1
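            # allocate a per-origin visit id when the caller did not provide one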
if not visit.visit:
visit_id = self._cql_runner.origin_generate_unique_visit_id(
visit.origin
)
visit = attr.evolve(visit, visit=visit_id)
self.journal_writer.origin_visit_add([visit])
self._cql_runner.origin_visit_add_one(OriginVisitRow(**visit.to_dict()))
assert visit.visit is not None
all_visits.append(visit)
self._origin_visit_status_add(
OriginVisitStatus(
origin=visit.origin,
visit=visit.visit,
date=visit.date,
status="created",
snapshot=None,
)
)
return all_visits
def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None:
"""Add an origin visit status"""
self.journal_writer.origin_visit_status_add([visit_status])
self._cql_runner.origin_visit_status_add_one(
converters.visit_status_to_row(visit_status)
)
def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus]) -> None:
# First round to check existence (fail early if any is ko)
for visit_status in visit_statuses:
origin_url = self.origin_get_one(visit_status.origin)
if not origin_url:
raise StorageArgumentException(f"Unknown origin {visit_status.origin}")
for visit_status in visit_statuses:
self._origin_visit_status_add(visit_status)
def _origin_visit_apply_last_status(self, visit: Dict[str, Any]) -> Dict[str, Any]:
"""Retrieve the latest visit status information for the origin visit.
Then merge it with the visit and return it.
"""
row = self._cql_runner.origin_visit_status_get_latest(
visit["origin"], visit["visit"]
)
assert row is not None
visit_status = converters.row_to_visit_status(row)
return {
# default to the values in visit
**visit,
# override with the last update
**visit_status.to_dict(),
# visit['origin'] is the URL (via a join), while
# visit_status['origin'] is only an id.
"origin": visit["origin"],
# but keep the date of the creation of the origin visit
"date": visit["date"],
}
def _origin_visit_get_latest_status(self, visit: OriginVisit) -> OriginVisitStatus:
"""Retrieve the latest visit status information for the origin visit object.
"""
+ assert visit.visit
row = self._cql_runner.origin_visit_status_get_latest(visit.origin, visit.visit)
assert row is not None
visit_status = converters.row_to_visit_status(row)
return attr.evolve(visit_status, origin=visit.origin)
@staticmethod
def _format_origin_visit_row(visit):
return {
**visit.to_dict(),
"origin": visit.origin,
"date": visit.date.replace(tzinfo=datetime.timezone.utc),
}
def origin_visit_get(
self,
origin: str,
page_token: Optional[str] = None,
order: ListOrder = ListOrder.ASC,
limit: int = 10,
) -> PagedResult[OriginVisit]:
if not isinstance(order, ListOrder):
raise StorageArgumentException("order must be a ListOrder value")
if page_token and not isinstance(page_token, str):
raise StorageArgumentException("page_token must be a string.")
next_page_token = None
- visit_from = page_token and int(page_token)
+ visit_from = None if page_token is None else int(page_token)
visits: List[OriginVisit] = []
extra_limit = limit + 1
rows = self._cql_runner.origin_visit_get(origin, visit_from, extra_limit, order)
for row in rows:
visits.append(converters.row_to_visit(row))
assert len(visits) <= extra_limit
if len(visits) == extra_limit:
visits = visits[:limit]
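            # the id of the last returned visit becomes the next page token;
            # the next call is expected to resume strictly past it, in the
            # requested order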
next_page_token = str(visits[-1].visit)
return PagedResult(results=visits, next_page_token=next_page_token)
def origin_visit_status_get(
self,
origin: str,
visit: int,
page_token: Optional[str] = None,
order: ListOrder = ListOrder.ASC,
limit: int = 10,
) -> PagedResult[OriginVisitStatus]:
next_page_token = None
date_from = None
if page_token is not None:
date_from = datetime.datetime.fromisoformat(page_token)
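            # the token is the stringified date of the first status of the
            # next page, as produced below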
# Take one more visit status so we can reuse it as the next page token if any
rows = self._cql_runner.origin_visit_status_get_range(
origin, visit, date_from, limit + 1, order
)
visit_statuses = [converters.row_to_visit_status(row) for row in rows]
if len(visit_statuses) > limit:
# last visit status date is the next page token
next_page_token = str(visit_statuses[-1].date)
            # exclude that visit status from the results so the page stays within the limit
visit_statuses = visit_statuses[:limit]
return PagedResult(results=visit_statuses, next_page_token=next_page_token)
def origin_visit_find_by_date(
self, origin: str, visit_date: datetime.datetime
) -> Optional[OriginVisit]:
# Iterator over all the visits of the origin
# This should be ok for now, as there aren't too many visits
# per origin.
rows = list(self._cql_runner.origin_visit_get_all(origin))
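        # Pick the visit whose date is closest to visit_date; ties are broken
        # by preferring the most recent visit id (hence the negation below).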
def key(visit):
dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date
return (abs(dt), -visit.visit)
if rows:
return converters.row_to_visit(min(rows, key=key))
return None
def origin_visit_get_by(self, origin: str, visit: int) -> Optional[OriginVisit]:
row = self._cql_runner.origin_visit_get_one(origin, visit)
if row:
return converters.row_to_visit(row)
return None
def origin_visit_get_latest(
self,
origin: str,
type: Optional[str] = None,
allowed_statuses: Optional[List[str]] = None,
require_snapshot: bool = False,
) -> Optional[OriginVisit]:
if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES):
raise StorageArgumentException(
f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
f"{','.join(VISIT_STATUSES)} authorized"
)
# TODO: Do not fetch all visits
rows = self._cql_runner.origin_visit_get_all(origin)
latest_visit = None
for row in rows:
visit = self._format_origin_visit_row(row)
updated_visit = self._origin_visit_apply_last_status(visit)
if type is not None and updated_visit["type"] != type:
continue
if allowed_statuses and updated_visit["status"] not in allowed_statuses:
continue
if require_snapshot and updated_visit["snapshot"] is None:
continue
# updated_visit is a candidate
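            # keep it only if it is at least as recent as the current best,
            # both by date and by visit id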
if latest_visit is not None:
if updated_visit["date"] < latest_visit["date"]:
continue
if updated_visit["visit"] < latest_visit["visit"]:
continue
latest_visit = updated_visit
if latest_visit is None:
return None
return OriginVisit(
origin=latest_visit["origin"],
visit=latest_visit["visit"],
date=latest_visit["date"],
type=latest_visit["type"],
)
def origin_visit_status_get_latest(
self,
origin_url: str,
visit: int,
allowed_statuses: Optional[List[str]] = None,
require_snapshot: bool = False,
) -> Optional[OriginVisitStatus]:
if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES):
raise StorageArgumentException(
f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
f"{','.join(VISIT_STATUSES)} authorized"
)
rows = list(
self._cql_runner.origin_visit_status_get(
origin_url, visit, allowed_statuses, require_snapshot
)
)
        # filtering is done on the Python side, as we cannot do it server-side
if allowed_statuses:
rows = [row for row in rows if row.status in allowed_statuses]
if require_snapshot:
rows = [row for row in rows if row.snapshot is not None]
if not rows:
return None
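        # rows are assumed to be ordered from most recent to oldest, so the
        # first remaining row is the latest matching status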
return converters.row_to_visit_status(rows[0])
def origin_visit_status_get_random(
self, type: str
) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]:
back_in_the_day = now() - datetime.timedelta(weeks=12) # 3 months back
# Random position to start iteration at
start_token = random.randint(TOKEN_BEGIN, TOKEN_END)
# Iterator over all visits, ordered by token(origins) then visit_id
rows = self._cql_runner.origin_visit_iter(start_token)
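        # scan forward from the random position and return the first visit
        # that is both recent enough and fully done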
for row in rows:
visit = converters.row_to_visit(row)
visit_status = self._origin_visit_get_latest_status(visit)
if visit.date > back_in_the_day and visit_status.status == "full":
return visit, visit_status
return None
def stat_counters(self):
rows = self._cql_runner.stat_counters()
keys = (
"content",
"directory",
"origin",
"origin_visit",
"release",
"revision",
"skipped_content",
"snapshot",
)
stats = {key: 0 for key in keys}
stats.update({row.object_type: row.count for row in rows})
return stats
def refresh_stat_counters(self):
pass
def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata]) -> None:
self.journal_writer.raw_extrinsic_metadata_add(metadata)
for metadata_entry in metadata:
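            # reject metadata referencing an authority or fetcher that was not
            # registered beforehand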
if not self._cql_runner.metadata_authority_get(
metadata_entry.authority.type.value, metadata_entry.authority.url
):
raise StorageArgumentException(
f"Unknown authority {metadata_entry.authority}"
)
if not self._cql_runner.metadata_fetcher_get(
metadata_entry.fetcher.name, metadata_entry.fetcher.version
):
raise StorageArgumentException(
f"Unknown fetcher {metadata_entry.fetcher}"
)
try:
row = RawExtrinsicMetadataRow(
type=metadata_entry.type.value,
id=str(metadata_entry.id),
authority_type=metadata_entry.authority.type.value,
authority_url=metadata_entry.authority.url,
discovery_date=metadata_entry.discovery_date,
fetcher_name=metadata_entry.fetcher.name,
fetcher_version=metadata_entry.fetcher.version,
format=metadata_entry.format,
metadata=metadata_entry.metadata,
origin=metadata_entry.origin,
visit=metadata_entry.visit,
snapshot=map_optional(str, metadata_entry.snapshot),
release=map_optional(str, metadata_entry.release),
revision=map_optional(str, metadata_entry.revision),
path=metadata_entry.path,
directory=map_optional(str, metadata_entry.directory),
)
self._cql_runner.raw_extrinsic_metadata_add(row)
except TypeError as e:
raise StorageArgumentException(*e.args)
def raw_extrinsic_metadata_get(
self,
type: MetadataTargetType,
id: Union[str, SWHID],
authority: MetadataAuthority,
after: Optional[datetime.datetime] = None,
page_token: Optional[bytes] = None,
limit: int = 1000,
) -> PagedResult[RawExtrinsicMetadata]:
if type == MetadataTargetType.ORIGIN:
if isinstance(id, SWHID):
raise StorageArgumentException(
f"raw_extrinsic_metadata_get called with type='origin', "
f"but provided id is an SWHID: {id!r}"
)
else:
if not isinstance(id, SWHID):
raise StorageArgumentException(
f"raw_extrinsic_metadata_get called with type!='origin', "
f"but provided id is not an SWHID: {id!r}"
)
if page_token is not None:
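            # the token encodes the position (date and fetcher) of the last
            # entry of the previous page, as produced at the end of this method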
(after_date, after_fetcher_name, after_fetcher_url) = msgpack_loads(
base64.b64decode(page_token)
)
if after and after_date < after:
raise StorageArgumentException(
"page_token is inconsistent with the value of 'after'."
)
entries = self._cql_runner.raw_extrinsic_metadata_get_after_date_and_fetcher( # noqa
str(id),
authority.type.value,
authority.url,
after_date,
after_fetcher_name,
after_fetcher_url,
)
elif after is not None:
entries = self._cql_runner.raw_extrinsic_metadata_get_after_date(
str(id), authority.type.value, authority.url, after
)
else:
entries = self._cql_runner.raw_extrinsic_metadata_get(
str(id), authority.type.value, authority.url
)
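        # fetch one extra entry to detect whether there is a next page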
if limit:
entries = itertools.islice(entries, 0, limit + 1)
results = []
for entry in entries:
discovery_date = entry.discovery_date.replace(tzinfo=datetime.timezone.utc)
assert str(id) == entry.id
result = RawExtrinsicMetadata(
type=MetadataTargetType(entry.type),
id=id,
authority=MetadataAuthority(
type=MetadataAuthorityType(entry.authority_type),
url=entry.authority_url,
),
fetcher=MetadataFetcher(
name=entry.fetcher_name, version=entry.fetcher_version,
),
discovery_date=discovery_date,
format=entry.format,
metadata=entry.metadata,
origin=entry.origin,
visit=entry.visit,
snapshot=map_optional(parse_swhid, entry.snapshot),
release=map_optional(parse_swhid, entry.release),
revision=map_optional(parse_swhid, entry.revision),
path=entry.path,
directory=map_optional(parse_swhid, entry.directory),
)
results.append(result)
if len(results) > limit:
results.pop()
assert len(results) == limit
last_result = results[-1]
next_page_token: Optional[str] = base64.b64encode(
msgpack_dumps(
(
last_result.discovery_date,
last_result.fetcher.name,
last_result.fetcher.version,
)
)
).decode()
else:
next_page_token = None
return PagedResult(next_page_token=next_page_token, results=results,)
def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> None:
self.journal_writer.metadata_fetcher_add(fetchers)
for fetcher in fetchers:
self._cql_runner.metadata_fetcher_add(
MetadataFetcherRow(
name=fetcher.name,
version=fetcher.version,
metadata=json.dumps(map_optional(dict, fetcher.metadata)),
)
)
def metadata_fetcher_get(
self, name: str, version: str
) -> Optional[MetadataFetcher]:
fetcher = self._cql_runner.metadata_fetcher_get(name, version)
if fetcher:
return MetadataFetcher(
name=fetcher.name,
version=fetcher.version,
metadata=json.loads(fetcher.metadata),
)
else:
return None
def metadata_authority_add(self, authorities: List[MetadataAuthority]) -> None:
self.journal_writer.metadata_authority_add(authorities)
for authority in authorities:
self._cql_runner.metadata_authority_add(
MetadataAuthorityRow(
url=authority.url,
type=authority.type.value,
metadata=json.dumps(map_optional(dict, authority.metadata)),
)
)
def metadata_authority_get(
self, type: MetadataAuthorityType, url: str
) -> Optional[MetadataAuthority]:
authority = self._cql_runner.metadata_authority_get(type.value, url)
if authority:
return MetadataAuthority(
type=MetadataAuthorityType(authority.type),
url=authority.url,
metadata=json.loads(authority.metadata),
)
else:
return None
def clear_buffers(self, object_types: Optional[List[str]] = None) -> None:
"""Do nothing
"""
return None
def flush(self, object_types: Optional[List[str]] = None) -> Dict:
return {}