diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
index 13a54350..f6c300d1 100644
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -1,1012 +1,976 @@
# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import dataclasses
import datetime
import functools
import logging
import random
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
)
from cassandra import CoordinationFailure
from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
from cassandra.query import PreparedStatement, BoundStatement, dict_factory
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
retry_if_exception_type,
)
from mypy_extensions import NamedArg
from swh.model.model import (
Content,
SkippedContent,
Sha1Git,
TimestampWithTimezone,
Timestamp,
Person,
)
from swh.storage.interface import ListOrder
from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys
from .model import (
BaseRow,
ContentRow,
DirectoryEntryRow,
DirectoryRow,
MetadataAuthorityRow,
MetadataFetcherRow,
ObjectCountRow,
OriginRow,
OriginVisitRow,
OriginVisitStatusRow,
RawExtrinsicMetadataRow,
ReleaseRow,
RevisionParentRow,
RevisionRow,
SkippedContentRow,
SnapshotBranchRow,
SnapshotRow,
)
from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS
logger = logging.getLogger(__name__)
_execution_profiles = {
EXEC_PROFILE_DEFAULT: ExecutionProfile(
load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
row_factory=dict_factory,
),
}
# Configuration for cassandra-driver's access to servers:
# * hit the right server directly when sending a query (TokenAwarePolicy),
# * if there's more than one, then pick one at random that's in the same
# datacenter as the client (DCAwareRoundRobinPolicy)
def create_keyspace(
hosts: List[str], keyspace: str, port: int = 9042, *, durable_writes=True
):
cluster = Cluster(hosts, port=port, execution_profiles=_execution_profiles)
session = cluster.connect()
extra_params = ""
if not durable_writes:
extra_params = "AND durable_writes = false"
session.execute(
"""CREATE KEYSPACE IF NOT EXISTS "%s"
WITH REPLICATION = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
} %s;
"""
% (keyspace, extra_params)
)
session.execute('USE "%s"' % keyspace)
for query in CREATE_TABLES_QUERIES:
session.execute(query)
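# Usage sketch (hedged; the host and keyspace names below are illustrative only,
# not taken from this diff):
#
#     create_keyspace(["127.0.0.1"], "swh_test", durable_writes=False)
#
# This connects to the cluster, creates the keyspace with SimpleStrategy and a
# replication factor of 1 if it does not already exist, then executes every
# statement in CREATE_TABLES_QUERIES against it.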
TRet = TypeVar("TRet")
def _prepared_statement(
query: str,
) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]:
"""Returns a decorator usable on methods of CqlRunner, to
inject them with a 'statement' argument, that is a prepared
statement corresponding to the query.
This only works on methods of CqlRunner, as preparing a
statement requires a connection to a Cassandra server."""
def decorator(f):
@functools.wraps(f)
def newf(self, *args, **kwargs) -> TRet:
if f.__name__ not in self._prepared_statements:
statement: PreparedStatement = self._session.prepare(query)
self._prepared_statements[f.__name__] = statement
return f(
self, *args, **kwargs, statement=self._prepared_statements[f.__name__]
)
return newf
return decorator
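# Illustrative sketch of how the decorator above is meant to be used (the concrete
# methods appear further down in this file; this hypothetical one is not part of
# the diff):
#
#     @_prepared_statement("SELECT id FROM content WHERE sha1 = ?")
#     def _content_ids_by_sha1(self, sha1, *, statement):
#         return self._execute_with_retries(statement, [sha1])
#
# On the first call, the query is prepared against the session and cached in
# self._prepared_statements under the method's name; later calls reuse it.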
TArg = TypeVar("TArg")
TSelf = TypeVar("TSelf")
def _prepared_insert_statement(
row_class: Type[BaseRow],
) -> Callable[
[Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa
Callable[[TSelf, TArg], TRet],
]:
"""Shorthand for using `_prepared_statement` for `INSERT INTO`
statements."""
columns = row_class.cols()
return _prepared_statement(
"INSERT INTO %s (%s) VALUES (%s)"
% (row_class.TABLE, ", ".join(columns), ", ".join("?" for _ in columns),)
)
def _prepared_exists_statement(
table_name: str,
) -> Callable[
[Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa
Callable[[TSelf, TArg], TRet],
]:
"""Shorthand for using `_prepared_statement` for queries that only
check which ids in a list exist in the table."""
return _prepared_statement(f"SELECT id FROM {table_name} WHERE id IN ?")
def _prepared_select_statement(
row_class: Type[BaseRow], clauses: str = "", cols: Optional[List[str]] = None,
) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]:
if cols is None:
cols = row_class.cols()
return _prepared_statement(
f"SELECT {', '.join(cols)} FROM {row_class.TABLE} {clauses}"
)
class CqlRunner:
"""Class managing prepared statements and building queries to be sent
to Cassandra."""
def __init__(self, hosts: List[str], keyspace: str, port: int):
self._cluster = Cluster(
hosts, port=port, execution_profiles=_execution_profiles
)
self._session = self._cluster.connect(keyspace)
self._cluster.register_user_type(
keyspace, "microtimestamp_with_timezone", TimestampWithTimezone
)
self._cluster.register_user_type(keyspace, "microtimestamp", Timestamp)
self._cluster.register_user_type(keyspace, "person", Person)
self._prepared_statements: Dict[str, PreparedStatement] = {}
##########################
# Common utility functions
##########################
MAX_RETRIES = 3
@retry(
wait=wait_random_exponential(multiplier=1, max=10),
stop=stop_after_attempt(MAX_RETRIES),
retry=retry_if_exception_type(CoordinationFailure),
)
def _execute_with_retries(self, statement, args) -> ResultSet:
return self._session.execute(statement, args, timeout=1000.0)
@_prepared_statement(
"UPDATE object_count SET count = count + ? "
"WHERE partition_key = 0 AND object_type = ?"
)
def _increment_counter(
self, object_type: str, nb: int, *, statement: PreparedStatement
) -> None:
self._execute_with_retries(statement, [nb, object_type])
def _add_one(self, statement, obj: BaseRow) -> None:
self._increment_counter(obj.TABLE, 1)
self._execute_with_retries(statement, dataclasses.astuple(obj))
_T = TypeVar("_T", bound=BaseRow)
def _get_random_row(self, row_class: Type[_T], statement) -> Optional[_T]: # noqa
"""Takes a prepared statement of the form
"SELECT * FROM
WHERE token() > ? LIMIT 1"
and uses it to return a random row"""
token = random.randint(TOKEN_BEGIN, TOKEN_END)
rows = self._execute_with_retries(statement, [token])
if not rows:
# There are no rows with a greater token; wrap around to get
# the row with the smallest token
rows = self._execute_with_retries(statement, [TOKEN_BEGIN])
if rows:
return row_class.from_dict(rows.one()) # type: ignore
else:
return None
def _missing(self, statement, ids):
rows = self._execute_with_retries(statement, [ids])
found_ids = {row["id"] for row in rows}
return [id_ for id_ in ids if id_ not in found_ids]
##########################
# 'content' table
##########################
def _content_add_finalize(self, statement: BoundStatement) -> None:
"""Returned currified by content_add_prepare, to be called when the
content row should be added to the primary table."""
self._execute_with_retries(statement, None)
self._increment_counter("content", 1)
@_prepared_insert_statement(ContentRow)
def content_add_prepare(
self, content: ContentRow, *, statement
) -> Tuple[int, Callable[[], None]]:
"""Prepares insertion of a Content to the main 'content' table.
Returns a token (to be used in secondary tables), and a function to be
called to perform the insertion in the main table."""
statement = statement.bind(dataclasses.astuple(content))
# Type used for hashing keys (usually, it will be
# cassandra.metadata.Murmur3Token)
token_class = self._cluster.metadata.token_map.token_class
# Token of the row when it will be inserted. This is equivalent to
# "SELECT token({', '.join(ContentRow.PARTITION_KEY)}) FROM content WHERE ..."
# after the row is inserted; but we need the token to insert in the
# index tables *before* inserting to the main 'content' table
token = token_class.from_key(statement.routing_key).value
assert TOKEN_BEGIN <= token <= TOKEN_END
# Function to be called after the indexes contain their respective
# row
finalizer = functools.partial(self._content_add_finalize, statement)
return (token, finalizer)
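# Sketch of the intended two-phase insertion (hedged; the calling code shown here
# is illustrative, the actual orchestration lives in the storage layer):
#
#     (token, finalizer) = cql_runner.content_add_prepare(content_row)
#     for algo in HASH_ALGORITHMS:
#         cql_runner.content_index_add_one(algo, content, token)
#     finalizer()  # only now is the row written to the main 'content' table
#
# The main 'content' row is only written once every content_by_* index row is in
# place, which is why the token must be computed before the row itself is inserted.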
@_prepared_select_statement(
ContentRow, f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}"
)
def content_get_from_pk(
self, content_hashes: Dict[str, bytes], *, statement
) -> Optional[ContentRow]:
rows = list(
self._execute_with_retries(
statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]
)
)
assert len(rows) <= 1
if rows:
return ContentRow(**rows[0])
else:
return None
@_prepared_select_statement(
ContentRow, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) = ?"
)
def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]:
return map(ContentRow.from_dict, self._execute_with_retries(statement, [token]))
@_prepared_select_statement(
ContentRow, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) > ? LIMIT 1"
)
def content_get_random(self, *, statement) -> Optional[ContentRow]:
return self._get_random_row(ContentRow, statement)
@_prepared_statement(
(
"SELECT token({0}) AS tok, {1} FROM content "
"WHERE token({0}) >= ? AND token({0}) <= ? LIMIT ?"
).format(", ".join(ContentRow.PARTITION_KEY), ", ".join(ContentRow.cols()))
)
def content_get_token_range(
self, start: int, end: int, limit: int, *, statement
) -> Iterable[Tuple[int, ContentRow]]:
"""Returns an iterable of (token, row)"""
return (
(row["tok"], ContentRow.from_dict(remove_keys(row, ("tok",))))
for row in self._execute_with_retries(statement, [start, end, limit])
)
##########################
# 'content_by_*' tables
##########################
@_prepared_statement(
"SELECT sha1_git AS id FROM content_by_sha1_git WHERE sha1_git IN ?"
)
def content_missing_by_sha1_git(
self, ids: List[bytes], *, statement
) -> List[bytes]:
return self._missing(statement, ids)
def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
"""Adds a row mapping content[algo] to the token of the Content in
the main 'content' table."""
query = (
f"INSERT INTO content_by_{algo} ({algo}, target_token) " f"VALUES (%s, %s)"
)
self._execute_with_retries(query, [content.get_hash(algo), token])
def content_get_tokens_from_single_hash(
self, algo: str, hash_: bytes
) -> Iterable[int]:
assert algo in HASH_ALGORITHMS
query = f"SELECT target_token FROM content_by_{algo} WHERE {algo} = %s"
return (
row["target_token"] for row in self._execute_with_retries(query, [hash_])
)
##########################
# 'skipped_content' table
##########################
_magic_null_pk = b""
"""
NULLs (or all-empty blobs) are not allowed in primary keys; instead use a
special value that can't possibly be a valid hash.
"""
def _skipped_content_add_finalize(self, statement: BoundStatement) -> None:
"""Returned currified by skipped_content_add_prepare, to be called
when the content row should be added to the primary table."""
self._execute_with_retries(statement, None)
self._increment_counter("skipped_content", 1)
@_prepared_insert_statement(SkippedContentRow)
def skipped_content_add_prepare(
self, content, *, statement
) -> Tuple[int, Callable[[], None]]:
"""Prepares insertion of a Content to the main 'skipped_content' table.
Returns a token (to be used in secondary tables), and a function to be
called to perform the insertion in the main table."""
# Replace NULLs (which are not allowed in the partition key) with
# an empty byte string
for key in SkippedContentRow.PARTITION_KEY:
if getattr(content, key) is None:
setattr(content, key, self._magic_null_pk)
statement = statement.bind(dataclasses.astuple(content))
# Type used for hashing keys (usually, it will be
# cassandra.metadata.Murmur3Token)
token_class = self._cluster.metadata.token_map.token_class
# Token of the row when it will be inserted. This is equivalent to
# "SELECT token({', '.join(SkippedContentRow.PARTITION_KEY)})
# FROM skipped_content WHERE ..."
# after the row is inserted; but we need the token to insert in the
# index tables *before* inserting to the main 'skipped_content' table
token = token_class.from_key(statement.routing_key).value
assert TOKEN_BEGIN <= token <= TOKEN_END
# Function to be called after the indexes contain their respective
# row
finalizer = functools.partial(self._skipped_content_add_finalize, statement)
return (token, finalizer)
@_prepared_select_statement(
SkippedContentRow,
f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}",
)
def skipped_content_get_from_pk(
self, content_hashes: Dict[str, bytes], *, statement
) -> Optional[SkippedContentRow]:
rows = list(
self._execute_with_retries(
statement,
[
content_hashes[algo] or self._magic_null_pk
for algo in HASH_ALGORITHMS
],
)
)
assert len(rows) <= 1
if rows:
# TODO: convert _magic_null_pk back to None?
return SkippedContentRow.from_dict(rows[0])
else:
return None
##########################
# 'skipped_content_by_*' tables
##########################
def skipped_content_index_add_one(
self, algo: str, content: SkippedContent, token: int
) -> None:
"""Adds a row mapping content[algo] to the token of the SkippedContent
in the main 'skipped_content' table."""
query = (
f"INSERT INTO skipped_content_by_{algo} ({algo}, target_token) "
f"VALUES (%s, %s)"
)
self._execute_with_retries(
query, [content.get_hash(algo) or self._magic_null_pk, token]
)
##########################
# 'revision' table
##########################
@_prepared_exists_statement("revision")
def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
@_prepared_insert_statement(RevisionRow)
def revision_add_one(self, revision: RevisionRow, *, statement) -> None:
self._add_one(statement, revision)
@_prepared_statement("SELECT id FROM revision WHERE id IN ?")
def revision_get_ids(self, revision_ids, *, statement) -> Iterable[Sha1Git]:
return (
row["id"] for row in self._execute_with_retries(statement, [revision_ids])
)
@_prepared_select_statement(RevisionRow, "WHERE id IN ?")
def revision_get(
self, revision_ids: List[Sha1Git], *, statement
) -> Iterable[RevisionRow]:
return map(
RevisionRow.from_dict, self._execute_with_retries(statement, [revision_ids])
)
@_prepared_select_statement(RevisionRow, "WHERE token(id) > ? LIMIT 1")
def revision_get_random(self, *, statement) -> Optional[RevisionRow]:
return self._get_random_row(RevisionRow, statement)
##########################
# 'revision_parent' table
##########################
@_prepared_insert_statement(RevisionParentRow)
def revision_parent_add_one(
self, revision_parent: RevisionParentRow, *, statement
) -> None:
self._add_one(statement, revision_parent)
@_prepared_statement("SELECT parent_id FROM revision_parent WHERE id = ?")
def revision_parent_get(
self, revision_id: Sha1Git, *, statement
) -> Iterable[bytes]:
return (
row["parent_id"]
for row in self._execute_with_retries(statement, [revision_id])
)
##########################
# 'release' table
##########################
@_prepared_exists_statement("release")
def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
@_prepared_insert_statement(ReleaseRow)
def release_add_one(self, release: ReleaseRow, *, statement) -> None:
self._add_one(statement, release)
@_prepared_select_statement(ReleaseRow, "WHERE id in ?")
def release_get(self, release_ids: List[str], *, statement) -> Iterable[ReleaseRow]:
return map(
ReleaseRow.from_dict, self._execute_with_retries(statement, [release_ids])
)
@_prepared_select_statement(ReleaseRow, "WHERE token(id) > ? LIMIT 1")
def release_get_random(self, *, statement) -> Optional[ReleaseRow]:
return self._get_random_row(ReleaseRow, statement)
##########################
# 'directory' table
##########################
@_prepared_exists_statement("directory")
def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
@_prepared_insert_statement(DirectoryRow)
def directory_add_one(self, directory: DirectoryRow, *, statement) -> None:
"""Called after all calls to directory_entry_add_one, to
commit/finalize the directory."""
self._add_one(statement, directory)
@_prepared_select_statement(DirectoryRow, "WHERE token(id) > ? LIMIT 1")
def directory_get_random(self, *, statement) -> Optional[DirectoryRow]:
return self._get_random_row(DirectoryRow, statement)
##########################
# 'directory_entry' table
##########################
@_prepared_insert_statement(DirectoryEntryRow)
def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None:
self._add_one(statement, entry)
@_prepared_select_statement(DirectoryEntryRow, "WHERE directory_id IN ?")
def directory_entry_get(
self, directory_ids, *, statement
) -> Iterable[DirectoryEntryRow]:
return map(
DirectoryEntryRow.from_dict,
self._execute_with_retries(statement, [directory_ids]),
)
##########################
# 'snapshot' table
##########################
@_prepared_exists_statement("snapshot")
def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
@_prepared_insert_statement(SnapshotRow)
def snapshot_add_one(self, snapshot: SnapshotRow, *, statement) -> None:
self._add_one(statement, snapshot)
@_prepared_select_statement(SnapshotRow, "WHERE token(id) > ? LIMIT 1")
def snapshot_get_random(self, *, statement) -> Optional[SnapshotRow]:
return self._get_random_row(SnapshotRow, statement)
##########################
# 'snapshot_branch' table
##########################
@_prepared_insert_statement(SnapshotBranchRow)
def snapshot_branch_add_one(self, branch: SnapshotBranchRow, *, statement) -> None:
self._add_one(statement, branch)
@_prepared_statement(
"SELECT ascii_bins_count(target_type) AS counts "
"FROM snapshot_branch "
"WHERE snapshot_id = ? "
)
def snapshot_count_branches(
self, snapshot_id: Sha1Git, *, statement
) -> Dict[Optional[str], int]:
"""Returns a dictionary from type names to the number of branches
of that type."""
row = self._execute_with_retries(statement, [snapshot_id]).one()
(nb_none, counts) = row["counts"]
return {None: nb_none, **counts}
@_prepared_select_statement(
SnapshotBranchRow, "WHERE snapshot_id = ? AND name >= ? LIMIT ?"
)
def snapshot_branch_get(
self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement
) -> Iterable[SnapshotBranchRow]:
return map(
SnapshotBranchRow.from_dict,
self._execute_with_retries(statement, [snapshot_id, from_, limit]),
)
##########################
# 'origin' table
##########################
@_prepared_insert_statement(OriginRow)
def origin_add_one(self, origin: OriginRow, *, statement) -> None:
self._add_one(statement, origin)
@_prepared_select_statement(OriginRow, "WHERE sha1 = ?")
def origin_get_by_sha1(self, sha1: bytes, *, statement) -> Iterable[OriginRow]:
return map(OriginRow.from_dict, self._execute_with_retries(statement, [sha1]))
def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
return self.origin_get_by_sha1(hash_url(url))
@_prepared_statement(
f'SELECT token(sha1) AS tok, {", ".join(OriginRow.cols())} '
f"FROM origin WHERE token(sha1) >= ? LIMIT ?"
)
def origin_list(
self, start_token: int, limit: int, *, statement
) -> Iterable[Tuple[int, OriginRow]]:
"""Returns an iterable of (token, origin)"""
return (
(row["tok"], OriginRow.from_dict(remove_keys(row, ("tok",))))
for row in self._execute_with_retries(statement, [start_token, limit])
)
@_prepared_select_statement(OriginRow)
def origin_iter_all(self, *, statement) -> Iterable[OriginRow]:
return map(OriginRow.from_dict, self._execute_with_retries(statement, []))
@_prepared_statement("SELECT next_visit_id FROM origin WHERE sha1 = ?")
def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int:
rows = list(self._execute_with_retries(statement, [origin_sha1]))
assert len(rows) == 1 # TODO: error handling
return rows[0]["next_visit_id"]
@_prepared_statement(
"UPDATE origin SET next_visit_id=? WHERE sha1 = ? IF next_visit_id=?"
)
def origin_generate_unique_visit_id(self, origin_url: str, *, statement) -> int:
origin_sha1 = hash_url(origin_url)
next_id = self._origin_get_next_visit_id(origin_sha1)
while True:
res = list(
self._execute_with_retries(
statement, [next_id + 1, origin_sha1, next_id]
)
)
assert len(res) == 1
if res[0]["[applied]"]:
# No data race
return next_id
else:
# Someone else updated it before we did, let's try again
next_id = res[0]["next_visit_id"]
# TODO: abort after too many attempts
return next_id
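# The UPDATE above is a Cassandra lightweight transaction (compare-and-set): it
# only applies if next_visit_id still holds the expected value, and the returned
# row's "[applied]" flag says whether this writer won the race. When it lost, the
# loop retries with the value the concurrent writer just stored.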
##########################
# 'origin_visit' table
##########################
- @_prepared_select_statement(
- OriginVisitRow, "WHERE origin = ? AND visit > ? ORDER BY visit ASC"
- )
- def _origin_visit_get_pagination_asc_no_limit(
- self, origin_url: str, last_visit: int, *, statement
- ) -> ResultSet:
- return self._execute_with_retries(statement, [origin_url, last_visit])
-
@_prepared_select_statement(
OriginVisitRow, "WHERE origin = ? AND visit > ? ORDER BY visit ASC LIMIT ?"
)
- def _origin_visit_get_pagination_asc_limit(
+ def _origin_visit_get_pagination_asc(
self, origin_url: str, last_visit: int, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, last_visit, limit])
- @_prepared_select_statement(
- OriginVisitRow, "WHERE origin = ? AND visit < ? ORDER BY visit DESC"
- )
- def _origin_visit_get_pagination_desc_no_limit(
- self, origin_url: str, last_visit: int, *, statement
- ) -> ResultSet:
- return self._execute_with_retries(statement, [origin_url, last_visit])
-
@_prepared_select_statement(
OriginVisitRow, "WHERE origin = ? AND visit < ? ORDER BY visit DESC LIMIT ?"
)
- def _origin_visit_get_pagination_desc_limit(
+ def _origin_visit_get_pagination_desc(
self, origin_url: str, last_visit: int, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, last_visit, limit])
@_prepared_select_statement(
OriginVisitRow, "WHERE origin = ? ORDER BY visit ASC LIMIT ?"
)
- def _origin_visit_get_no_pagination_asc_limit(
+ def _origin_visit_get_no_pagination_asc(
self, origin_url: str, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, limit])
- @_prepared_select_statement(OriginVisitRow, "WHERE origin = ? ORDER BY visit ASC ")
- def _origin_visit_get_no_pagination_asc_no_limit(
- self, origin_url: str, *, statement
- ) -> ResultSet:
- return self._execute_with_retries(statement, [origin_url])
-
- @_prepared_select_statement(OriginVisitRow, "WHERE origin = ? ORDER BY visit DESC")
- def _origin_visit_get_no_pagination_desc_no_limit(
- self, origin_url: str, *, statement
- ) -> ResultSet:
- return self._execute_with_retries(statement, [origin_url])
-
@_prepared_select_statement(
OriginVisitRow, "WHERE origin = ? ORDER BY visit DESC LIMIT ?"
)
- def _origin_visit_get_no_pagination_desc_limit(
+ def _origin_visit_get_no_pagination_desc(
self, origin_url: str, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, limit])
def origin_visit_get(
- self,
- origin_url: str,
- last_visit: Optional[int],
- limit: Optional[int],
- order: ListOrder,
+ self, origin_url: str, last_visit: Optional[int], limit: int, order: ListOrder,
) -> Iterable[OriginVisitRow]:
args: List[Any] = [origin_url]
if last_visit is not None:
page_name = "pagination"
args.append(last_visit)
else:
page_name = "no_pagination"
- if limit is not None:
- limit_name = "limit"
- args.append(limit)
- else:
- limit_name = "no_limit"
+ args.append(limit)
- method_name = f"_origin_visit_get_{page_name}_{order.value}_{limit_name}"
+ method_name = f"_origin_visit_get_{page_name}_{order.value}"
origin_visit_get_method = getattr(self, method_name)
return map(OriginVisitRow.from_dict, origin_visit_get_method(*args))
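# Dispatch note: the variant is picked by composing the method name, e.g. a call
# with last_visit=42 (illustrative value) and order=ListOrder.ASC resolves to
# _origin_visit_get_pagination_asc(origin_url, 42, limit).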
@_prepared_insert_statement(OriginVisitRow)
def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None:
self._add_one(statement, visit)
@_prepared_select_statement(OriginVisitRow, "WHERE origin = ? AND visit = ?")
def origin_visit_get_one(
self, origin_url: str, visit_id: int, *, statement
) -> Optional[OriginVisitRow]:
# TODO: error handling
rows = list(self._execute_with_retries(statement, [origin_url, visit_id]))
if rows:
return OriginVisitRow.from_dict(rows[0])
else:
return None
@_prepared_select_statement(OriginVisitRow, "WHERE origin = ?")
def origin_visit_get_all(
self, origin_url: str, *, statement
) -> Iterable[OriginVisitRow]:
return map(
OriginVisitRow.from_dict,
self._execute_with_retries(statement, [origin_url]),
)
@_prepared_select_statement(OriginVisitRow, "WHERE token(origin) >= ?")
def _origin_visit_iter_from(
self, min_token: int, *, statement
) -> Iterable[OriginVisitRow]:
return map(
OriginVisitRow.from_dict, self._execute_with_retries(statement, [min_token])
)
@_prepared_select_statement(OriginVisitRow, "WHERE token(origin) < ?")
def _origin_visit_iter_to(
self, max_token: int, *, statement
) -> Iterable[OriginVisitRow]:
return map(
OriginVisitRow.from_dict, self._execute_with_retries(statement, [max_token])
)
def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]:
"""Returns all origin visits in order from this token,
and wraps around the token space."""
yield from self._origin_visit_iter_from(start_token)
yield from self._origin_visit_iter_to(start_token)
##########################
# 'origin_visit_status' table
##########################
@_prepared_select_statement(
OriginVisitStatusRow,
"WHERE origin = ? AND visit = ? AND date >= ? ORDER BY date ASC LIMIT ?",
)
def _origin_visit_status_get_with_date_asc_limit(
self,
origin: str,
visit: int,
date_from: datetime.datetime,
limit: int,
*,
statement,
) -> ResultSet:
return self._execute_with_retries(statement, [origin, visit, date_from, limit])
@_prepared_select_statement(
OriginVisitStatusRow,
"WHERE origin = ? AND visit = ? AND date <= ? ORDER BY visit DESC LIMIT ?",
)
def _origin_visit_status_get_with_date_desc_limit(
self,
origin: str,
visit: int,
date_from: datetime.datetime,
limit: int,
*,
statement,
) -> ResultSet:
return self._execute_with_retries(statement, [origin, visit, date_from, limit])
@_prepared_select_statement(
OriginVisitStatusRow,
"WHERE origin = ? AND visit = ? ORDER BY visit ASC LIMIT ?",
)
def _origin_visit_status_get_with_no_date_asc_limit(
self, origin: str, visit: int, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin, visit, limit])
@_prepared_select_statement(
OriginVisitStatusRow,
"WHERE origin = ? AND visit = ? ORDER BY visit DESC LIMIT ?",
)
def _origin_visit_status_get_with_no_date_desc_limit(
self, origin: str, visit: int, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin, visit, limit])
def origin_visit_status_get_range(
self,
origin: str,
visit: int,
date_from: Optional[datetime.datetime],
limit: int,
order: ListOrder,
) -> Iterable[OriginVisitStatusRow]:
args: List[Any] = [origin, visit]
if date_from is not None:
date_name = "date"
args.append(date_from)
else:
date_name = "no_date"
args.append(limit)
method_name = f"_origin_visit_status_get_with_{date_name}_{order.value}_limit"
origin_visit_status_get_method = getattr(self, method_name)
return map(
OriginVisitStatusRow.from_dict, origin_visit_status_get_method(*args)
)
@_prepared_insert_statement(OriginVisitStatusRow)
def origin_visit_status_add_one(
self, visit_update: OriginVisitStatusRow, *, statement
) -> None:
self._add_one(statement, visit_update)
def origin_visit_status_get_latest(
self, origin: str, visit: int,
) -> Optional[OriginVisitStatusRow]:
"""Given an origin visit id, return its latest origin_visit_status
"""
return next(self.origin_visit_status_get(origin, visit), None)
@_prepared_select_statement(
OriginVisitStatusRow, "WHERE origin = ? AND visit = ? ORDER BY date DESC"
)
def origin_visit_status_get(
self, origin: str, visit: int, *, statement,
) -> Iterator[OriginVisitStatusRow]:
"""Return all origin visit statuses for a given visit
"""
return map(
OriginVisitStatusRow.from_dict,
self._execute_with_retries(statement, [origin, visit]),
)
##########################
# 'metadata_authority' table
##########################
@_prepared_insert_statement(MetadataAuthorityRow)
def metadata_authority_add(self, authority: MetadataAuthorityRow, *, statement):
self._add_one(statement, authority)
@_prepared_select_statement(MetadataAuthorityRow, "WHERE type = ? AND url = ?")
def metadata_authority_get(
self, type, url, *, statement
) -> Optional[MetadataAuthorityRow]:
rows = list(self._execute_with_retries(statement, [type, url]))
if rows:
return MetadataAuthorityRow.from_dict(rows[0])
else:
return None
##########################
# 'metadata_fetcher' table
##########################
@_prepared_insert_statement(MetadataFetcherRow)
def metadata_fetcher_add(self, fetcher, *, statement):
self._add_one(statement, fetcher)
@_prepared_select_statement(MetadataFetcherRow, "WHERE name = ? AND version = ?")
def metadata_fetcher_get(
self, name, version, *, statement
) -> Optional[MetadataFetcherRow]:
rows = list(self._execute_with_retries(statement, [name, version]))
if rows:
return MetadataFetcherRow.from_dict(rows[0])
else:
return None
#########################
# 'raw_extrinsic_metadata' table
#########################
@_prepared_insert_statement(RawExtrinsicMetadataRow)
def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata, *, statement):
self._add_one(statement, raw_extrinsic_metadata)
@_prepared_select_statement(
RawExtrinsicMetadataRow,
"WHERE id=? AND authority_url=? AND discovery_date>? AND authority_type=?",
)
def raw_extrinsic_metadata_get_after_date(
self,
id: str,
authority_type: str,
authority_url: str,
after: datetime.datetime,
*,
statement,
) -> Iterable[RawExtrinsicMetadataRow]:
return map(
RawExtrinsicMetadataRow.from_dict,
self._execute_with_retries(
statement, [id, authority_url, after, authority_type]
),
)
@_prepared_select_statement(
RawExtrinsicMetadataRow,
"WHERE id=? AND authority_type=? AND authority_url=? "
"AND (discovery_date, fetcher_name, fetcher_version) > (?, ?, ?)",
)
def raw_extrinsic_metadata_get_after_date_and_fetcher(
self,
id: str,
authority_type: str,
authority_url: str,
after_date: datetime.datetime,
after_fetcher_name: str,
after_fetcher_version: str,
*,
statement,
) -> Iterable[RawExtrinsicMetadataRow]:
return map(
RawExtrinsicMetadataRow.from_dict,
self._execute_with_retries(
statement,
[
id,
authority_type,
authority_url,
after_date,
after_fetcher_name,
after_fetcher_version,
],
),
)
@_prepared_select_statement(
RawExtrinsicMetadataRow, "WHERE id=? AND authority_url=? AND authority_type=?"
)
def raw_extrinsic_metadata_get(
self, id: str, authority_type: str, authority_url: str, *, statement
) -> Iterable[RawExtrinsicMetadataRow]:
return map(
RawExtrinsicMetadataRow.from_dict,
self._execute_with_retries(statement, [id, authority_url, authority_type]),
)
##########################
# Miscellaneous
##########################
@_prepared_statement("SELECT uuid() FROM revision LIMIT 1;")
def check_read(self, *, statement):
self._execute_with_retries(statement, [])
@_prepared_select_statement(ObjectCountRow, "WHERE partition_key=0")
def stat_counters(self, *, statement) -> Iterable[ObjectCountRow]:
return map(ObjectCountRow.from_dict, self._execute_with_retries(statement, []))
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
index 7eb039d9..9eda0424 100644
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -1,630 +1,625 @@
# Copyright (C) 2015-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
import functools
import random
from collections import defaultdict
from typing import (
Any,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from swh.model.model import (
Content,
SkippedContent,
Sha1Git,
)
from swh.storage.cassandra import CassandraStorage
from swh.storage.cassandra.model import (
BaseRow,
ContentRow,
DirectoryRow,
DirectoryEntryRow,
MetadataAuthorityRow,
MetadataFetcherRow,
ObjectCountRow,
OriginRow,
OriginVisitRow,
OriginVisitStatusRow,
RawExtrinsicMetadataRow,
ReleaseRow,
RevisionRow,
RevisionParentRow,
SkippedContentRow,
SnapshotRow,
SnapshotBranchRow,
)
from swh.storage.interface import ListOrder
from swh.storage.objstorage import ObjStorage
from .converters import origin_url_to_sha1
from .writer import JournalWriter
TRow = TypeVar("TRow", bound=BaseRow)
class Table(Generic[TRow]):
def __init__(self, row_class: Type[TRow]):
self.row_class = row_class
self.primary_key_cols = row_class.PARTITION_KEY + row_class.CLUSTERING_KEY
# Map from tokens to clustering keys to rows
# These are not actually partitions (or rather, there is one partition
# for each token) and they aren't sorted.
# But it is good enough if we don't care about performance;
# and makes the code a lot simpler.
self.data: Dict[int, Dict[Tuple, TRow]] = defaultdict(dict)
def __repr__(self):
return f"<__module__.Table[{self.row_class.__name__}] object>"
def partition_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple:
"""Returns the partition key of a row (ie. the cells which get hashed
into the token)."""
if isinstance(row, dict):
row_d = row
else:
row_d = row.to_dict()
return tuple(row_d[col] for col in self.row_class.PARTITION_KEY)
def clustering_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple:
"""Returns the clustering key of a row (ie. the cells which are used
for sorting rows within a partition)."""
if isinstance(row, dict):
row_d = row
else:
row_d = row.to_dict()
return tuple(row_d[col] for col in self.row_class.CLUSTERING_KEY)
def primary_key(self, row):
return self.partition_key(row) + self.clustering_key(row)
def primary_key_from_dict(self, d: Dict[str, Any]) -> Tuple:
"""Returns the primary key (ie. concatenation of partition key and
clustering key) of the given dictionary interpreted as a row."""
return tuple(d[col] for col in self.primary_key_cols)
def token(self, key: Tuple):
"""Returns the token of a row (ie. the hash of its partition key)."""
return hash(key)
def get_partition(self, token: int) -> Dict[Tuple, TRow]:
"""Returns the partition that contains this token."""
return self.data[token]
def insert(self, row: TRow):
partition = self.data[self.token(self.partition_key(row))]
partition[self.clustering_key(row)] = row
def split_primary_key(self, key: Tuple) -> Tuple[Tuple, Tuple]:
"""Returns (partition_key, clustering_key) from a partition key"""
assert len(key) == len(self.primary_key_cols)
partition_key = key[0 : len(self.row_class.PARTITION_KEY)]
clustering_key = key[len(self.row_class.PARTITION_KEY) :]
return (partition_key, clustering_key)
def get_from_partition_key(self, partition_key: Tuple) -> Iterable[TRow]:
"""Returns at most one row, from its partition key."""
token = self.token(partition_key)
for row in self.get_from_token(token):
if self.partition_key(row) == partition_key:
yield row
def get_from_primary_key(self, primary_key: Tuple) -> Optional[TRow]:
"""Returns at most one row, from its primary key."""
(partition_key, clustering_key) = self.split_primary_key(primary_key)
token = self.token(partition_key)
partition = self.get_partition(token)
return partition.get(clustering_key)
def get_from_token(self, token: int) -> Iterable[TRow]:
"""Returns all rows whose token (ie. non-cryptographic hash of the
partition key) is the one passed as argument."""
return (v for (k, v) in sorted(self.get_partition(token).items()))
def iter_all(self) -> Iterator[Tuple[Tuple, TRow]]:
return (
(self.primary_key(row), row)
for (token, partition) in self.data.items()
for (clustering_key, row) in partition.items()
)
def get_random(self) -> Optional[TRow]:
return random.choice([row for (pk, row) in self.iter_all()])
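# Usage sketch for this in-memory Table (hedged; the field values are illustrative,
# the field names follow OriginRow as used elsewhere in this file):
#
#     origins = Table(OriginRow)
#     origins.insert(OriginRow(sha1=b"\x00" * 20, url="https://example.org/", next_visit_id=1))
#     row = origins.get_from_primary_key((b"\x00" * 20,))
#
# Tokens are plain Python hash() values, so rows are bucketed per partition key but
# the ordering guarantees of real Cassandra tokens are not reproduced.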
class InMemoryCqlRunner:
def __init__(self):
self._contents = Table(ContentRow)
self._content_indexes = defaultdict(lambda: defaultdict(set))
self._skipped_contents = Table(SkippedContentRow)
self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
self._directories = Table(DirectoryRow)
self._directory_entries = Table(DirectoryEntryRow)
self._revisions = Table(RevisionRow)
self._revision_parents = Table(RevisionParentRow)
self._releases = Table(ReleaseRow)
self._snapshots = Table(SnapshotRow)
self._snapshot_branches = Table(SnapshotBranchRow)
self._origins = Table(OriginRow)
self._origin_visits = Table(OriginVisitRow)
self._origin_visit_statuses = Table(OriginVisitStatusRow)
self._metadata_authorities = Table(MetadataAuthorityRow)
self._metadata_fetchers = Table(MetadataFetcherRow)
self._raw_extrinsic_metadata = Table(RawExtrinsicMetadataRow)
self._stat_counters = defaultdict(int)
def increment_counter(self, object_type: str, nb: int):
self._stat_counters[object_type] += nb
def stat_counters(self) -> Iterable[ObjectCountRow]:
for (object_type, count) in self._stat_counters.items():
yield ObjectCountRow(partition_key=0, object_type=object_type, count=count)
##########################
# 'content' table
##########################
def _content_add_finalize(self, content: ContentRow) -> None:
self._contents.insert(content)
self.increment_counter("content", 1)
def content_add_prepare(self, content: ContentRow):
finalizer = functools.partial(self._content_add_finalize, content)
return (self._contents.token(self._contents.partition_key(content)), finalizer)
def content_get_from_pk(
self, content_hashes: Dict[str, bytes]
) -> Optional[ContentRow]:
primary_key = self._contents.primary_key_from_dict(content_hashes)
return self._contents.get_from_primary_key(primary_key)
def content_get_from_token(self, token: int) -> Iterable[ContentRow]:
return self._contents.get_from_token(token)
def content_get_random(self) -> Optional[ContentRow]:
return self._contents.get_random()
def content_get_token_range(
self, start: int, end: int, limit: int,
) -> Iterable[Tuple[int, ContentRow]]:
matches = [
(token, row)
for (token, partition) in self._contents.data.items()
for (clustering_key, row) in partition.items()
if start <= token <= end
]
matches.sort()
return matches[0:limit]
##########################
# 'content_by_*' tables
##########################
def content_missing_by_sha1_git(self, ids: List[bytes]) -> List[bytes]:
missing = []
for id_ in ids:
if id_ not in self._content_indexes["sha1_git"]:
missing.append(id_)
return missing
def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
self._content_indexes[algo][content.get_hash(algo)].add(token)
def content_get_tokens_from_single_hash(
self, algo: str, hash_: bytes
) -> Iterable[int]:
return self._content_indexes[algo][hash_]
##########################
# 'skipped_content' table
##########################
def _skipped_content_add_finalize(self, content: SkippedContentRow) -> None:
self._skipped_contents.insert(content)
self.increment_counter("skipped_content", 1)
def skipped_content_add_prepare(self, content: SkippedContentRow):
finalizer = functools.partial(self._skipped_content_add_finalize, content)
return (
self._skipped_contents.token(self._skipped_contents.partition_key(content)),
finalizer,
)
def skipped_content_get_from_pk(
self, content_hashes: Dict[str, bytes]
) -> Optional[SkippedContentRow]:
primary_key = self._skipped_contents.primary_key_from_dict(content_hashes)
return self._skipped_contents.get_from_primary_key(primary_key)
##########################
# 'skipped_content_by_*' tables
##########################
def skipped_content_index_add_one(
self, algo: str, content: SkippedContent, token: int
) -> None:
self._skipped_content_indexes[algo][content.get_hash(algo)].add(token)
##########################
# 'directory' table
##########################
def directory_missing(self, ids: List[bytes]) -> List[bytes]:
missing = []
for id_ in ids:
if self._directories.get_from_primary_key((id_,)) is None:
missing.append(id_)
return missing
def directory_add_one(self, directory: DirectoryRow) -> None:
self._directories.insert(directory)
self.increment_counter("directory", 1)
def directory_get_random(self) -> Optional[DirectoryRow]:
return self._directories.get_random()
##########################
# 'directory_entry' table
##########################
def directory_entry_add_one(self, entry: DirectoryEntryRow) -> None:
self._directory_entries.insert(entry)
def directory_entry_get(
self, directory_ids: List[Sha1Git]
) -> Iterable[DirectoryEntryRow]:
for id_ in directory_ids:
yield from self._directory_entries.get_from_partition_key((id_,))
##########################
# 'revision' table
##########################
def revision_missing(self, ids: List[bytes]) -> Iterable[bytes]:
missing = []
for id_ in ids:
if self._revisions.get_from_primary_key((id_,)) is None:
missing.append(id_)
return missing
def revision_add_one(self, revision: RevisionRow) -> None:
self._revisions.insert(revision)
self.increment_counter("revision", 1)
def revision_get_ids(self, revision_ids) -> Iterable[Sha1Git]:
for id_ in revision_ids:
if self._revisions.get_from_primary_key((id_,)) is not None:
yield id_
def revision_get(self, revision_ids: List[Sha1Git]) -> Iterable[RevisionRow]:
for id_ in revision_ids:
row = self._revisions.get_from_primary_key((id_,))
if row:
yield row
def revision_get_random(self) -> Optional[RevisionRow]:
return self._revisions.get_random()
##########################
# 'revision_parent' table
##########################
def revision_parent_add_one(self, revision_parent: RevisionParentRow) -> None:
self._revision_parents.insert(revision_parent)
def revision_parent_get(self, revision_id: Sha1Git) -> Iterable[bytes]:
for parent in self._revision_parents.get_from_partition_key((revision_id,)):
yield parent.parent_id
##########################
# 'release' table
##########################
def release_missing(self, ids: List[bytes]) -> List[bytes]:
missing = []
for id_ in ids:
if self._releases.get_from_primary_key((id_,)) is None:
missing.append(id_)
return missing
def release_add_one(self, release: ReleaseRow) -> None:
self._releases.insert(release)
self.increment_counter("release", 1)
def release_get(self, release_ids: List[str]) -> Iterable[ReleaseRow]:
for id_ in release_ids:
row = self._releases.get_from_primary_key((id_,))
if row:
yield row
def release_get_random(self) -> Optional[ReleaseRow]:
return self._releases.get_random()
##########################
# 'snapshot' table
##########################
def snapshot_missing(self, ids: List[bytes]) -> List[bytes]:
missing = []
for id_ in ids:
if self._snapshots.get_from_primary_key((id_,)) is None:
missing.append(id_)
return missing
def snapshot_add_one(self, snapshot: SnapshotRow) -> None:
self._snapshots.insert(snapshot)
self.increment_counter("snapshot", 1)
def snapshot_get_random(self) -> Optional[SnapshotRow]:
return self._snapshots.get_random()
##########################
# 'snapshot_branch' table
##########################
def snapshot_branch_add_one(self, branch: SnapshotBranchRow) -> None:
self._snapshot_branches.insert(branch)
def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Dict[Optional[str], int]:
"""Returns a dictionary from type names to the number of branches
of that type."""
counts: Dict[Optional[str], int] = defaultdict(int)
for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)):
if branch.target_type is None:
target_type = None
else:
target_type = branch.target_type
counts[target_type] += 1
return counts
def snapshot_branch_get(
self, snapshot_id: Sha1Git, from_: bytes, limit: int
) -> Iterable[SnapshotBranchRow]:
count = 0
for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)):
if branch.name >= from_:
count += 1
yield branch
if count >= limit:
break
##########################
# 'origin' table
##########################
def origin_add_one(self, origin: OriginRow) -> None:
self._origins.insert(origin)
self.increment_counter("origin", 1)
def origin_get_by_sha1(self, sha1: bytes) -> Iterable[OriginRow]:
return self._origins.get_from_partition_key((sha1,))
def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
return self.origin_get_by_sha1(origin_url_to_sha1(url))
def origin_list(
self, start_token: int, limit: int
) -> Iterable[Tuple[int, OriginRow]]:
"""Returns an iterable of (token, origin)"""
matches = [
(token, row)
for (token, partition) in self._origins.data.items()
for (clustering_key, row) in partition.items()
if token >= start_token
]
matches.sort()
return matches[0:limit]
def origin_iter_all(self) -> Iterable[OriginRow]:
return (
row
for (token, partition) in self._origins.data.items()
for (clustering_key, row) in partition.items()
)
def origin_generate_unique_visit_id(self, origin_url: str) -> int:
origin = list(self.origin_get_by_url(origin_url))[0]
visit_id = origin.next_visit_id
origin.next_visit_id += 1
return visit_id
##########################
# 'origin_visit' table
##########################
def origin_visit_get(
- self,
- origin_url: str,
- last_visit: Optional[int],
- limit: Optional[int],
- order: ListOrder,
+ self, origin_url: str, last_visit: Optional[int], limit: int, order: ListOrder,
) -> Iterable[OriginVisitRow]:
visits = list(self._origin_visits.get_from_partition_key((origin_url,)))
if last_visit is not None:
if order == ListOrder.ASC:
visits = [v for v in visits if v.visit > last_visit]
else:
visits = [v for v in visits if v.visit < last_visit]
visits.sort(key=lambda v: v.visit, reverse=order == ListOrder.DESC)
- if limit is not None:
- visits = visits[0:limit]
+ visits = visits[0:limit]
return visits
def origin_visit_add_one(self, visit: OriginVisitRow) -> None:
self._origin_visits.insert(visit)
self.increment_counter("origin_visit", 1)
def origin_visit_get_one(
self, origin_url: str, visit_id: int
) -> Optional[OriginVisitRow]:
return self._origin_visits.get_from_primary_key((origin_url, visit_id))
def origin_visit_get_all(self, origin_url: str) -> Iterable[OriginVisitRow]:
return self._origin_visits.get_from_partition_key((origin_url,))
def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]:
"""Returns all origin visits in order from this token,
and wraps around the token space."""
return (
row
for (token, partition) in self._origin_visits.data.items()
for (clustering_key, row) in partition.items()
)
##########################
# 'origin_visit_status' table
##########################
def origin_visit_status_get_range(
self,
origin: str,
visit: int,
date_from: Optional[datetime.datetime],
limit: int,
order: ListOrder,
) -> Iterable[OriginVisitStatusRow]:
statuses = list(self.origin_visit_status_get(origin, visit))
if date_from is not None:
if order == ListOrder.ASC:
statuses = [s for s in statuses if s.date >= date_from]
else:
statuses = [s for s in statuses if s.date <= date_from]
statuses.sort(key=lambda s: s.date, reverse=order == ListOrder.DESC)
return statuses[0:limit]
def origin_visit_status_add_one(self, visit_update: OriginVisitStatusRow) -> None:
self._origin_visit_statuses.insert(visit_update)
self.increment_counter("origin_visit_status", 1)
def origin_visit_status_get_latest(
self, origin: str, visit: int,
) -> Optional[OriginVisitStatusRow]:
"""Given an origin visit id, return its latest origin_visit_status
"""
return next(self.origin_visit_status_get(origin, visit), None)
def origin_visit_status_get(
self, origin: str, visit: int,
) -> Iterator[OriginVisitStatusRow]:
"""Return all origin visit statuses for a given visit
"""
statuses = [
s
for s in self._origin_visit_statuses.get_from_partition_key((origin,))
if s.visit == visit
]
statuses.sort(key=lambda s: s.date, reverse=True)
return iter(statuses)
##########################
# 'metadata_authority' table
##########################
def metadata_authority_add(self, authority: MetadataAuthorityRow):
self._metadata_authorities.insert(authority)
self.increment_counter("metadata_authority", 1)
def metadata_authority_get(self, type, url) -> Optional[MetadataAuthorityRow]:
return self._metadata_authorities.get_from_primary_key((url, type))
##########################
# 'metadata_fetcher' table
##########################
def metadata_fetcher_add(self, fetcher: MetadataFetcherRow):
self._metadata_fetchers.insert(fetcher)
self.increment_counter("metadata_fetcher", 1)
def metadata_fetcher_get(self, name, version) -> Optional[MetadataFetcherRow]:
return self._metadata_fetchers.get_from_primary_key((name, version))
#########################
# 'raw_extrinsic_metadata' table
#########################
def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata):
self._raw_extrinsic_metadata.insert(raw_extrinsic_metadata)
self.increment_counter("raw_extrinsic_metadata", 1)
def raw_extrinsic_metadata_get_after_date(
self,
id: str,
authority_type: str,
authority_url: str,
after: datetime.datetime,
) -> Iterable[RawExtrinsicMetadataRow]:
metadata = self.raw_extrinsic_metadata_get(id, authority_type, authority_url)
return (m for m in metadata if m.discovery_date > after)
def raw_extrinsic_metadata_get_after_date_and_fetcher(
self,
id: str,
authority_type: str,
authority_url: str,
after_date: datetime.datetime,
after_fetcher_name: str,
after_fetcher_version: str,
) -> Iterable[RawExtrinsicMetadataRow]:
metadata = self._raw_extrinsic_metadata.get_from_partition_key((id,))
after_tuple = (after_date, after_fetcher_name, after_fetcher_version)
return (
m
for m in metadata
if m.authority_type == authority_type
and m.authority_url == authority_url
and (m.discovery_date, m.fetcher_name, m.fetcher_version) > after_tuple
)
def raw_extrinsic_metadata_get(
self, id: str, authority_type: str, authority_url: str
) -> Iterable[RawExtrinsicMetadataRow]:
metadata = self._raw_extrinsic_metadata.get_from_partition_key((id,))
return (
m
for m in metadata
if m.authority_type == authority_type and m.authority_url == authority_url
)
class InMemoryStorage(CassandraStorage):
_cql_runner: InMemoryCqlRunner # type: ignore
def __init__(self, journal_writer=None):
self.reset()
self.journal_writer = JournalWriter(journal_writer)
def reset(self):
self._cql_runner = InMemoryCqlRunner()
self.objstorage = ObjStorage({"cls": "memory", "args": {}})
def check_config(self, *, check_write: bool) -> bool:
return True