diff --git a/swh/storage/cassandra/common.py b/swh/storage/cassandra/common.py --- a/swh/storage/cassandra/common.py +++ b/swh/storage/cassandra/common.py @@ -5,9 +5,7 @@ import hashlib - - -Row = tuple +from typing import Any, Dict, Tuple TOKEN_BEGIN = -(2 ** 63) @@ -18,3 +16,7 @@ def hash_url(url: str) -> bytes: return hashlib.sha1(url.encode("ascii")).digest() + + +def remove_keys(d: Dict[str, Any], keys: Tuple[str, ...]) -> Dict[str, Any]: + return {k: v for (k, v) in d.items() if k not in keys} diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py --- a/swh/storage/cassandra/converters.py +++ b/swh/storage/cassandra/converters.py @@ -8,7 +8,7 @@ import attr from copy import deepcopy -from typing import Any, Dict, Tuple +from typing import Dict, Tuple from swh.model.model import ( ObjectType, @@ -21,10 +21,11 @@ ) from swh.model.hashutil import DEFAULT_ALGORITHMS -from .common import Row +from .common import remove_keys +from .model import OriginVisitRow, OriginVisitStatusRow, RevisionRow, ReleaseRow -def revision_to_db(revision: Revision) -> Dict[str, Any]: +def revision_to_db(revision: Revision) -> RevisionRow: # we use a deepcopy of the dict because we do not want to recurse the # Model->dict conversion (to keep Timestamp & al. entities), BUT we do not # want to modify original metadata (embedded in the Model entity), so we @@ -39,11 +40,13 @@ ) db_revision["extra_headers"] = extra_headers db_revision["type"] = db_revision["type"].value - return db_revision + return RevisionRow(**remove_keys(db_revision, ("parents",))) -def revision_from_db(db_revision: Row, parents: Tuple[Sha1Git, ...]) -> Revision: - revision = db_revision._asdict() # type: ignore +def revision_from_db( + db_revision: RevisionRow, parents: Tuple[Sha1Git, ...] +) -> Revision: + revision = db_revision.to_dict() metadata = json.loads(revision.pop("metadata", None)) extra_headers = revision.pop("extra_headers", ()) if not extra_headers and metadata and "extra_headers" in metadata: @@ -59,18 +62,18 @@ ) -def release_to_db(release: Release) -> Dict[str, Any]: +def release_to_db(release: Release) -> ReleaseRow: db_release = attr.asdict(release, recurse=False) db_release["target_type"] = db_release["target_type"].value - return db_release + return ReleaseRow(**remove_keys(db_release, ("metadata",))) -def release_from_db(db_release: Row) -> Release: - release = db_release._asdict() # type: ignore +def release_from_db(db_release: ReleaseRow) -> Release: + release = db_release.to_dict() return Release(target_type=ObjectType(release.pop("target_type")), **release,) -def row_to_content_hashes(row: Row) -> Dict[str, bytes]: +def row_to_content_hashes(row: ReleaseRow) -> Dict[str, bytes]: """Convert cassandra row to a content hashes """ @@ -80,7 +83,7 @@ return hashes -def row_to_visit(row) -> OriginVisit: +def row_to_visit(row: OriginVisitRow) -> OriginVisit: """Format a row representing an origin_visit to an actual OriginVisit. """ @@ -92,15 +95,19 @@ ) -def row_to_visit_status(row) -> OriginVisitStatus: +def row_to_visit_status(row: OriginVisitStatusRow) -> OriginVisitStatus: """Format a row representing a visit_status to an actual OriginVisitStatus. 
""" return OriginVisitStatus.from_dict( { - **row._asdict(), - "origin": row.origin, + **row.to_dict(), "date": row.date.replace(tzinfo=datetime.timezone.utc), "metadata": (json.loads(row.metadata) if row.metadata else None), } ) + + +def visit_status_to_row(status: OriginVisitStatus) -> OriginVisitStatusRow: + d = status.to_dict() + return OriginVisitStatusRow.from_dict({**d, "metadata": json.dumps(d["metadata"])}) diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -3,9 +3,9 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import dataclasses import datetime import functools -import json import logging import random from typing import ( @@ -17,13 +17,14 @@ List, Optional, Tuple, + Type, TypeVar, ) from cassandra import CoordinationFailure from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy -from cassandra.query import PreparedStatement, BoundStatement +from cassandra.query import PreparedStatement, BoundStatement, dict_factory from tenacity import ( retry, stop_after_attempt, @@ -32,20 +33,36 @@ ) from swh.model.model import ( + Content, + SkippedContent, Sha1Git, TimestampWithTimezone, Timestamp, Person, - Content, - SkippedContent, - OriginVisit, - OriginVisitStatus, - Origin, ) from swh.storage.interface import ListOrder -from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url +from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys +from .model import ( + BaseRow, + ContentRow, + DirectoryEntryRow, + DirectoryRow, + MetadataAuthorityRow, + MetadataFetcherRow, + ObjectCountRow, + OriginRow, + OriginVisitRow, + OriginVisitStatusRow, + RawExtrinsicMetadataRow, + ReleaseRow, + RevisionParentRow, + RevisionRow, + SkippedContentRow, + SnapshotBranchRow, + SnapshotRow, +) from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS @@ -54,7 +71,8 @@ _execution_profiles = { EXEC_PROFILE_DEFAULT: ExecutionProfile( - load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()) + load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()), + row_factory=dict_factory, ), } # Configuration for cassandra-driver's access to servers: @@ -166,11 +184,14 @@ ) -> None: self._execute_with_retries(statement, [nb, object_type]) - def _add_one(self, statement, object_type: str, obj, keys: List[str]) -> None: - self._increment_counter(object_type, 1) - self._execute_with_retries(statement, [getattr(obj, key) for key in keys]) + def _add_one(self, statement, object_type: Optional[str], obj: BaseRow) -> None: + if object_type: + self._increment_counter(object_type, 1) + self._execute_with_retries(statement, dataclasses.astuple(obj)) - def _get_random_row(self, statement) -> Optional[Row]: + _T = TypeVar("_T", bound=BaseRow) + + def _get_random_row(self, row_class: Type[_T], statement) -> Optional[_T]: # noqa """Takes a prepared statement of the form "SELECT * FROM WHERE token() > ? 
LIMIT 1" and uses it to return a random row""" @@ -181,13 +202,13 @@ # the row with the smallest token rows = self._execute_with_retries(statement, [TOKEN_BEGIN]) if rows: - return rows.one() + return row_class.from_dict(rows.one()) # type: ignore else: return None def _missing(self, statement, ids): - res = self._execute_with_retries(statement, [ids]) - found_ids = {id_ for (id_,) in res} + rows = self._execute_with_retries(statement, [ids]) + found_ids = {row["id"] for row in rows} return [id_ for id_ in ids if id_ not in found_ids] ########################## @@ -195,15 +216,6 @@ ########################## _content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"] - _content_keys = [ - "sha1", - "sha1_git", - "sha256", - "blake2s256", - "length", - "ctime", - "status", - ] def _content_add_finalize(self, statement: BoundStatement) -> None: """Returned currified by content_add_prepare, to be called when the @@ -211,16 +223,14 @@ self._execute_with_retries(statement, None) self._increment_counter("content", 1) - @_prepared_insert_statement("content", _content_keys) + @_prepared_insert_statement("content", ContentRow.cols()) def content_add_prepare( - self, content, *, statement + self, content: ContentRow, *, statement ) -> Tuple[int, Callable[[], None]]: """Prepares insertion of a Content to the main 'content' table. Returns a token (to be used in secondary tables), and a function to be called to perform the insertion in the main table.""" - statement = statement.bind( - [getattr(content, key) for key in self._content_keys] - ) + statement = statement.bind(dataclasses.astuple(content)) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) @@ -245,7 +255,7 @@ ) def content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement - ) -> Optional[Row]: + ) -> Optional[ContentRow]: rows = list( self._execute_with_retries( statement, [content_hashes[algo] for algo in HASH_ALGORITHMS] @@ -253,38 +263,44 @@ ) assert len(rows) <= 1 if rows: - return rows[0] + return ContentRow(**rows[0]) else: return None @_prepared_statement( "SELECT * FROM content WHERE token(" + ", ".join(_content_pk) + ") = ?" ) - def content_get_from_token(self, token, *, statement) -> Iterable[Row]: - return self._execute_with_retries(statement, [token]) + def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]: + return map(ContentRow.from_dict, self._execute_with_retries(statement, [token])) @_prepared_statement( "SELECT * FROM content WHERE token(%s) > ? LIMIT 1" % ", ".join(_content_pk) ) - def content_get_random(self, *, statement) -> Optional[Row]: - return self._get_random_row(statement) + def content_get_random(self, *, statement) -> Optional[ContentRow]: + return self._get_random_row(ContentRow, statement) @_prepared_statement( ( "SELECT token({0}) AS tok, {1} FROM content " "WHERE token({0}) >= ? AND token({0}) <= ? LIMIT ?" 
- ).format(", ".join(_content_pk), ", ".join(_content_keys)) + ).format(", ".join(_content_pk), ", ".join(ContentRow.cols())) ) def content_get_token_range( self, start: int, end: int, limit: int, *, statement - ) -> Iterable[Row]: - return self._execute_with_retries(statement, [start, end, limit]) + ) -> Iterable[Tuple[int, ContentRow]]: + """Returns an iterable of (token, row)""" + return ( + (row["tok"], ContentRow.from_dict(remove_keys(row, ("tok",)))) + for row in self._execute_with_retries(statement, [start, end, limit]) + ) ########################## # 'content_by_*' tables ########################## - @_prepared_statement("SELECT sha1_git FROM content_by_sha1_git WHERE sha1_git IN ?") + @_prepared_statement( + "SELECT sha1_git AS id FROM content_by_sha1_git WHERE sha1_git IN ?" + ) def content_missing_by_sha1_git( self, ids: List[bytes], *, statement ) -> List[bytes]: @@ -303,24 +319,15 @@ ) -> Iterable[int]: assert algo in HASH_ALGORITHMS query = f"SELECT target_token FROM content_by_{algo} WHERE {algo} = %s" - return (tok for (tok,) in self._execute_with_retries(query, [hash_])) + return ( + row["target_token"] for row in self._execute_with_retries(query, [hash_]) + ) ########################## # 'skipped_content' table ########################## _skipped_content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"] - _skipped_content_keys = [ - "sha1", - "sha1_git", - "sha256", - "blake2s256", - "length", - "ctime", - "status", - "reason", - "origin", - ] _magic_null_pk = b"" """ NULLs (or all-empty blobs) are not allowed in primary keys; instead use a @@ -333,7 +340,7 @@ self._execute_with_retries(statement, None) self._increment_counter("skipped_content", 1) - @_prepared_insert_statement("skipped_content", _skipped_content_keys) + @_prepared_insert_statement("skipped_content", SkippedContentRow.cols()) def skipped_content_add_prepare( self, content, *, statement ) -> Tuple[int, Callable[[], None]]: @@ -343,14 +350,11 @@ # Replace NULLs (which are not allowed in the partition key) with # an empty byte string - content = content.to_dict() for key in self._skipped_content_pk: - if content[key] is None: - content[key] = self._magic_null_pk + if getattr(content, key) is None: + setattr(content, key, self._magic_null_pk) - statement = statement.bind( - [content.get(key) for key in self._skipped_content_keys] - ) + statement = statement.bind(dataclasses.astuple(content)) # Type used for hashing keys (usually, it will be # cassandra.metadata.Murmur3Token) @@ -376,7 +380,7 @@ ) def skipped_content_get_from_pk( self, content_hashes: Dict[str, bytes], *, statement - ) -> Optional[Row]: + ) -> Optional[SkippedContentRow]: rows = list( self._execute_with_retries( statement, @@ -389,7 +393,7 @@ assert len(rows) <= 1 if rows: # TODO: convert _magic_null_pk back to None? 
- return rows[0] + return SkippedContentRow.from_dict(rows[0]) else: return None @@ -414,218 +418,198 @@ # 'revision' table ########################## - _revision_keys = [ - "id", - "date", - "committer_date", - "type", - "directory", - "message", - "author", - "committer", - "synthetic", - "metadata", - "extra_headers", - ] - @_prepared_exists_statement("revision") def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) - @_prepared_insert_statement("revision", _revision_keys) - def revision_add_one(self, revision: Dict[str, Any], *, statement) -> None: - self._execute_with_retries( - statement, [revision[key] for key in self._revision_keys] - ) - self._increment_counter("revision", 1) + @_prepared_insert_statement("revision", RevisionRow.cols()) + def revision_add_one(self, revision: RevisionRow, *, statement) -> None: + self._add_one(statement, "revision", revision) @_prepared_statement("SELECT id FROM revision WHERE id IN ?") - def revision_get_ids(self, revision_ids, *, statement) -> ResultSet: - return self._execute_with_retries(statement, [revision_ids]) + def revision_get_ids(self, revision_ids, *, statement) -> Iterable[int]: + return ( + row["id"] for row in self._execute_with_retries(statement, [revision_ids]) + ) @_prepared_statement("SELECT * FROM revision WHERE id IN ?") - def revision_get(self, revision_ids, *, statement) -> ResultSet: - return self._execute_with_retries(statement, [revision_ids]) + def revision_get(self, revision_ids, *, statement) -> Iterable[RevisionRow]: + return map( + RevisionRow.from_dict, self._execute_with_retries(statement, [revision_ids]) + ) @_prepared_statement("SELECT * FROM revision WHERE token(id) > ? LIMIT 1") - def revision_get_random(self, *, statement) -> Optional[Row]: - return self._get_random_row(statement) + def revision_get_random(self, *, statement) -> Optional[RevisionRow]: + return self._get_random_row(RevisionRow, statement) ########################## # 'revision_parent' table ########################## - _revision_parent_keys = ["id", "parent_rank", "parent_id"] - - @_prepared_insert_statement("revision_parent", _revision_parent_keys) + @_prepared_insert_statement("revision_parent", RevisionParentRow.cols()) def revision_parent_add_one( - self, id_: Sha1Git, parent_rank: int, parent_id: Sha1Git, *, statement + self, revision_parent: RevisionParentRow, *, statement ) -> None: - self._execute_with_retries(statement, [id_, parent_rank, parent_id]) + self._add_one(statement, None, revision_parent) @_prepared_statement("SELECT parent_id FROM revision_parent WHERE id = ?") - def revision_parent_get(self, revision_id: Sha1Git, *, statement) -> ResultSet: - return self._execute_with_retries(statement, [revision_id]) + def revision_parent_get( + self, revision_id: Sha1Git, *, statement + ) -> Iterable[bytes]: + return ( + row["parent_id"] + for row in self._execute_with_retries(statement, [revision_id]) + ) ########################## # 'release' table ########################## - _release_keys = [ - "id", - "target", - "target_type", - "date", - "name", - "message", - "author", - "synthetic", - ] - @_prepared_exists_statement("release") def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) - @_prepared_insert_statement("release", _release_keys) - def release_add_one(self, release: Dict[str, Any], *, statement) -> None: - self._execute_with_retries( - statement, [release[key] for key in self._release_keys] - ) - 
self._increment_counter("release", 1) + @_prepared_insert_statement("release", ReleaseRow.cols()) + def release_add_one(self, release: ReleaseRow, *, statement) -> None: + self._add_one(statement, "release", release) @_prepared_statement("SELECT * FROM release WHERE id in ?") - def release_get(self, release_ids: List[str], *, statement) -> None: - return self._execute_with_retries(statement, [release_ids]) + def release_get(self, release_ids: List[str], *, statement) -> Iterable[ReleaseRow]: + return map( + ReleaseRow.from_dict, self._execute_with_retries(statement, [release_ids]) + ) @_prepared_statement("SELECT * FROM release WHERE token(id) > ? LIMIT 1") - def release_get_random(self, *, statement) -> Optional[Row]: - return self._get_random_row(statement) + def release_get_random(self, *, statement) -> Optional[ReleaseRow]: + return self._get_random_row(ReleaseRow, statement) ########################## # 'directory' table ########################## - _directory_keys = ["id"] - @_prepared_exists_statement("directory") def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) - @_prepared_insert_statement("directory", _directory_keys) - def directory_add_one(self, directory_id: Sha1Git, *, statement) -> None: + @_prepared_insert_statement("directory", DirectoryRow.cols()) + def directory_add_one(self, directory: DirectoryRow, *, statement) -> None: """Called after all calls to directory_entry_add_one, to commit/finalize the directory.""" - self._execute_with_retries(statement, [directory_id]) - self._increment_counter("directory", 1) + self._add_one(statement, "directory", directory) @_prepared_statement("SELECT * FROM directory WHERE token(id) > ? LIMIT 1") - def directory_get_random(self, *, statement) -> Optional[Row]: - return self._get_random_row(statement) + def directory_get_random(self, *, statement) -> Optional[DirectoryRow]: + return self._get_random_row(DirectoryRow, statement) ########################## # 'directory_entry' table ########################## - _directory_entry_keys = ["directory_id", "name", "type", "target", "perms"] - - @_prepared_insert_statement("directory_entry", _directory_entry_keys) - def directory_entry_add_one(self, entry: Dict[str, Any], *, statement) -> None: - self._execute_with_retries( - statement, [entry[key] for key in self._directory_entry_keys] - ) + @_prepared_insert_statement("directory_entry", DirectoryEntryRow.cols()) + def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None: + self._add_one(statement, None, entry) @_prepared_statement("SELECT * FROM directory_entry WHERE directory_id IN ?") - def directory_entry_get(self, directory_ids, *, statement) -> ResultSet: - return self._execute_with_retries(statement, [directory_ids]) + def directory_entry_get( + self, directory_ids, *, statement + ) -> Iterable[DirectoryEntryRow]: + return map( + DirectoryEntryRow.from_dict, + self._execute_with_retries(statement, [directory_ids]), + ) ########################## # 'snapshot' table ########################## - _snapshot_keys = ["id"] - @_prepared_exists_statement("snapshot") def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) - @_prepared_insert_statement("snapshot", _snapshot_keys) - def snapshot_add_one(self, snapshot_id: Sha1Git, *, statement) -> None: - self._execute_with_retries(statement, [snapshot_id]) - self._increment_counter("snapshot", 1) + @_prepared_insert_statement("snapshot", 
SnapshotRow.cols()) + def snapshot_add_one(self, snapshot: SnapshotRow, *, statement) -> None: + self._add_one(statement, "snapshot", snapshot) @_prepared_statement("SELECT * FROM snapshot WHERE id = ?") def snapshot_get(self, snapshot_id: Sha1Git, *, statement) -> ResultSet: - return self._execute_with_retries(statement, [snapshot_id]) + return map( + SnapshotRow.from_dict, self._execute_with_retries(statement, [snapshot_id]) + ) @_prepared_statement("SELECT * FROM snapshot WHERE token(id) > ? LIMIT 1") - def snapshot_get_random(self, *, statement) -> Optional[Row]: - return self._get_random_row(statement) + def snapshot_get_random(self, *, statement) -> Optional[SnapshotRow]: + return self._get_random_row(SnapshotRow, statement) ########################## # 'snapshot_branch' table ########################## - _snapshot_branch_keys = ["snapshot_id", "name", "target_type", "target"] - - @_prepared_insert_statement("snapshot_branch", _snapshot_branch_keys) - def snapshot_branch_add_one(self, branch: Dict[str, Any], *, statement) -> None: - self._execute_with_retries( - statement, [branch[key] for key in self._snapshot_branch_keys] - ) + @_prepared_insert_statement("snapshot_branch", SnapshotBranchRow.cols()) + def snapshot_branch_add_one(self, branch: SnapshotBranchRow, *, statement) -> None: + self._add_one(statement, None, branch) @_prepared_statement( "SELECT ascii_bins_count(target_type) AS counts " "FROM snapshot_branch " "WHERE snapshot_id = ? " ) - def snapshot_count_branches(self, snapshot_id: Sha1Git, *, statement) -> ResultSet: - return self._execute_with_retries(statement, [snapshot_id]) + def snapshot_count_branches( + self, snapshot_id: Sha1Git, *, statement + ) -> Dict[Optional[str], int]: + """Returns a dictionary from type names to the number of branches + of that type.""" + row = self._execute_with_retries(statement, [snapshot_id]).one() + (nb_none, counts) = row["counts"] + return {None: nb_none, **counts} @_prepared_statement( "SELECT * FROM snapshot_branch WHERE snapshot_id = ? AND name >= ? LIMIT ?" 
) def snapshot_branch_get( self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement - ) -> ResultSet: - return self._execute_with_retries(statement, [snapshot_id, from_, limit]) + ) -> Iterable[SnapshotBranchRow]: + return map( + SnapshotBranchRow.from_dict, + self._execute_with_retries(statement, [snapshot_id, from_, limit]), + ) ########################## # 'origin' table ########################## - origin_keys = ["sha1", "url", "type", "next_visit_id"] - - @_prepared_statement( - "INSERT INTO origin (sha1, url, next_visit_id) " - "VALUES (?, ?, 1) IF NOT EXISTS" - ) - def origin_add_one(self, origin: Origin, *, statement) -> None: - self._execute_with_retries(statement, [hash_url(origin.url), origin.url]) - self._increment_counter("origin", 1) + @_prepared_insert_statement("origin", OriginRow.cols()) + def origin_add_one(self, origin: OriginRow, *, statement) -> None: + self._add_one(statement, "origin", origin) @_prepared_statement("SELECT * FROM origin WHERE sha1 = ?") - def origin_get_by_sha1(self, sha1: bytes, *, statement) -> ResultSet: - return self._execute_with_retries(statement, [sha1]) + def origin_get_by_sha1(self, sha1: bytes, *, statement) -> Iterable[OriginRow]: + return map(OriginRow.from_dict, self._execute_with_retries(statement, [sha1])) - def origin_get_by_url(self, url: str) -> ResultSet: + def origin_get_by_url(self, url: str) -> Iterable[OriginRow]: return self.origin_get_by_sha1(hash_url(url)) @_prepared_statement( - f'SELECT token(sha1) AS tok, {", ".join(origin_keys)} ' + f'SELECT token(sha1) AS tok, {", ".join(OriginRow.cols())} ' f"FROM origin WHERE token(sha1) >= ? LIMIT ?" ) - def origin_list(self, start_token: int, limit: int, *, statement) -> ResultSet: - return self._execute_with_retries(statement, [start_token, limit]) + def origin_list( + self, start_token: int, limit: int, *, statement + ) -> Iterable[Tuple[int, OriginRow]]: + """Returns an iterable of (token, origin)""" + return ( + (row["tok"], OriginRow.from_dict(remove_keys(row, ("tok",)))) + for row in self._execute_with_retries(statement, [start_token, limit]) + ) @_prepared_statement("SELECT * FROM origin") - def origin_iter_all(self, *, statement) -> ResultSet: - return self._execute_with_retries(statement, []) + def origin_iter_all(self, *, statement) -> Iterable[OriginRow]: + return map(OriginRow.from_dict, self._execute_with_retries(statement, [])) @_prepared_statement("SELECT next_visit_id FROM origin WHERE sha1 = ?") def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int: rows = list(self._execute_with_retries(statement, [origin_sha1])) assert len(rows) == 1 # TODO: error handling - return rows[0].next_visit_id + return rows[0]["next_visit_id"] @_prepared_statement( "UPDATE origin SET next_visit_id=? WHERE sha1 = ? IF next_visit_id=?" @@ -640,12 +624,12 @@ ) ) assert len(res) == 1 - if res[0].applied: + if res[0]["[applied]"]: # No data race return next_id else: # Someone else updated it before we did, let's try again - next_id = res[0].next_visit_id + next_id = res[0]["next_visit_id"] # TODO: abort after too many attempts return next_id @@ -654,13 +638,6 @@ # 'origin_visit' table ########################## - _origin_visit_keys = [ - "origin", - "visit", - "type", - "date", - ] - @_prepared_statement( "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? 
" "ORDER BY visit ASC" @@ -737,7 +714,7 @@ last_visit: Optional[int], limit: Optional[int], order: ListOrder, - ) -> ResultSet: + ) -> Iterable[OriginVisitRow]: args: List[Any] = [origin_url] if last_visit is not None: @@ -754,7 +731,7 @@ method_name = f"_origin_visit_get_{page_name}_{order.value}_{limit_name}" origin_visit_get_method = getattr(self, method_name) - return origin_visit_get_method(*args) + return map(OriginVisitRow.from_dict, origin_visit_get_method(*args)) @_prepared_statement( "SELECT * FROM origin_visit_status WHERE origin = ? " @@ -817,7 +794,7 @@ date_from: Optional[datetime.datetime], limit: int, order: ListOrder, - ) -> ResultSet: + ) -> Iterable[OriginVisitStatusRow]: args: List[Any] = [origin, visit] if date_from is not None: @@ -830,41 +807,27 @@ method_name = f"_origin_visit_status_get_with_{date_name}_{order.value}_limit" origin_visit_status_get_method = getattr(self, method_name) - return origin_visit_status_get_method(*args) - - @_prepared_insert_statement("origin_visit", _origin_visit_keys) - def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None: - self._add_one(statement, "origin_visit", visit, self._origin_visit_keys) - - _origin_visit_status_keys = [ - "origin", - "visit", - "date", - "status", - "snapshot", - "metadata", - ] - - @_prepared_insert_statement("origin_visit_status", _origin_visit_status_keys) + return map( + OriginVisitStatusRow.from_dict, origin_visit_status_get_method(*args) + ) + + @_prepared_insert_statement("origin_visit", OriginVisitRow.cols()) + def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None: + self._add_one(statement, "origin_visit", visit) + + @_prepared_insert_statement("origin_visit_status", OriginVisitStatusRow.cols()) def origin_visit_status_add_one( - self, visit_update: OriginVisitStatus, *, statement + self, visit_update: OriginVisitStatusRow, *, statement ) -> None: - assert self._origin_visit_status_keys[-1] == "metadata" - keys = self._origin_visit_status_keys + self._add_one(statement, None, visit_update) - metadata = json.dumps( - dict(visit_update.metadata) if visit_update.metadata is not None else None - ) - self._execute_with_retries( - statement, [getattr(visit_update, key) for key in keys[:-1]] + [metadata] - ) - - def origin_visit_status_get_latest(self, origin: str, visit: int,) -> Optional[Row]: + def origin_visit_status_get_latest( + self, origin: str, visit: int, + ) -> Optional[OriginVisitStatusRow]: """Given an origin visit id, return its latest origin_visit_status """ - rows = self.origin_visit_status_get(origin, visit) - return rows[0] if rows else None + return next(self.origin_visit_status_get(origin, visit), None) @_prepared_statement( "SELECT * FROM origin_visit_status " @@ -879,36 +842,52 @@ require_snapshot: bool = False, *, statement, - ) -> List[Row]: + ) -> Iterator[OriginVisitStatusRow]: """Return all origin visit statuses for a given visit """ - return list(self._execute_with_retries(statement, [origin, visit])) + return map( + OriginVisitStatusRow.from_dict, + self._execute_with_retries(statement, [origin, visit]), + ) @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? 
AND visit = ?") def origin_visit_get_one( self, origin_url: str, visit_id: int, *, statement - ) -> Optional[Row]: + ) -> Optional[OriginVisitRow]: # TODO: error handling rows = list(self._execute_with_retries(statement, [origin_url, visit_id])) if rows: - return rows[0] + return OriginVisitRow.from_dict(rows[0]) else: return None @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ?") - def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet: - return self._execute_with_retries(statement, [origin_url]) + def origin_visit_get_all( + self, origin_url: str, *, statement + ) -> Iterable[OriginVisitRow]: + return map( + OriginVisitRow.from_dict, + self._execute_with_retries(statement, [origin_url]), + ) @_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) >= ?") - def _origin_visit_iter_from(self, min_token: int, *, statement) -> Iterator[Row]: - yield from self._execute_with_retries(statement, [min_token]) + def _origin_visit_iter_from( + self, min_token: int, *, statement + ) -> Iterable[OriginVisitRow]: + return map( + OriginVisitRow.from_dict, self._execute_with_retries(statement, [min_token]) + ) @_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) < ?") - def _origin_visit_iter_to(self, max_token: int, *, statement) -> Iterator[Row]: - yield from self._execute_with_retries(statement, [max_token]) + def _origin_visit_iter_to( + self, max_token: int, *, statement + ) -> Iterable[OriginVisitRow]: + return map( + OriginVisitRow.from_dict, self._execute_with_retries(statement, [max_token]) + ) - def origin_visit_iter(self, start_token: int) -> Iterator[Row]: + def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]: """Returns all origin visits in order from this token, and wraps around the token space.""" yield from self._origin_visit_iter_from(start_token) @@ -918,68 +897,49 @@ # 'metadata_authority' table ########################## - _metadata_authority_keys = ["url", "type", "metadata"] - - @_prepared_insert_statement("metadata_authority", _metadata_authority_keys) - def metadata_authority_add(self, url, type, metadata, *, statement): - return self._execute_with_retries(statement, [url, type, metadata]) + @_prepared_insert_statement("metadata_authority", MetadataAuthorityRow.cols()) + def metadata_authority_add(self, authority: MetadataAuthorityRow, *, statement): + self._add_one(statement, None, authority) @_prepared_statement("SELECT * from metadata_authority WHERE type = ? AND url = ?") - def metadata_authority_get(self, type, url, *, statement) -> Optional[Row]: - return next(iter(self._execute_with_retries(statement, [type, url])), None) + def metadata_authority_get( + self, type, url, *, statement + ) -> Optional[MetadataAuthorityRow]: + rows = list(self._execute_with_retries(statement, [type, url])) + if rows: + return MetadataAuthorityRow.from_dict(rows[0]) + else: + return None ########################## # 'metadata_fetcher' table ########################## - _metadata_fetcher_keys = ["name", "version", "metadata"] - - @_prepared_insert_statement("metadata_fetcher", _metadata_fetcher_keys) - def metadata_fetcher_add(self, name, version, metadata, *, statement): - return self._execute_with_retries(statement, [name, version, metadata]) + @_prepared_insert_statement("metadata_fetcher", MetadataFetcherRow.cols()) + def metadata_fetcher_add(self, fetcher, *, statement): + self._add_one(statement, None, fetcher) @_prepared_statement( "SELECT * from metadata_fetcher WHERE name = ? AND version = ?" 
) - def metadata_fetcher_get(self, name, version, *, statement) -> Optional[Row]: - return next(iter(self._execute_with_retries(statement, [name, version])), None) + def metadata_fetcher_get( + self, name, version, *, statement + ) -> Optional[MetadataFetcherRow]: + rows = list(self._execute_with_retries(statement, [name, version])) + if rows: + return MetadataFetcherRow.from_dict(rows[0]) + else: + return None ######################### # 'raw_extrinsic_metadata' table ######################### - _raw_extrinsic_metadata_keys = [ - "type", - "id", - "authority_type", - "authority_url", - "discovery_date", - "fetcher_name", - "fetcher_version", - "format", - "metadata", - "origin", - "visit", - "snapshot", - "release", - "revision", - "path", - "directory", - ] - - @_prepared_statement( - f"INSERT INTO raw_extrinsic_metadata " - f" ({', '.join(_raw_extrinsic_metadata_keys)}) " - f"VALUES ({', '.join('?' for _ in _raw_extrinsic_metadata_keys)})" + @_prepared_insert_statement( + "raw_extrinsic_metadata", RawExtrinsicMetadataRow.cols() ) - def raw_extrinsic_metadata_add( - self, statement, **kwargs, - ): - assert set(kwargs) == set( - self._raw_extrinsic_metadata_keys - ), f"Bad kwargs: {set(kwargs)}" - params = [kwargs[key] for key in self._raw_extrinsic_metadata_keys] - return self._execute_with_retries(statement, params,) + def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata, *, statement): + self._add_one(statement, None, raw_extrinsic_metadata) @_prepared_statement( "SELECT * from raw_extrinsic_metadata " @@ -993,9 +953,12 @@ after: datetime.datetime, *, statement, - ): - return self._execute_with_retries( - statement, [id, authority_url, after, authority_type] + ) -> Iterable[RawExtrinsicMetadataRow]: + return map( + RawExtrinsicMetadataRow.from_dict, + self._execute_with_retries( + statement, [id, authority_url, after, authority_type] + ), ) @_prepared_statement( @@ -1013,17 +976,20 @@ after_fetcher_version: str, *, statement, - ): - return self._execute_with_retries( - statement, - [ - id, - authority_type, - authority_url, - after_date, - after_fetcher_name, - after_fetcher_version, - ], + ) -> Iterable[RawExtrinsicMetadataRow]: + return map( + RawExtrinsicMetadataRow.from_dict, + self._execute_with_retries( + statement, + [ + id, + authority_type, + authority_url, + after_date, + after_fetcher_name, + after_fetcher_version, + ], + ), ) @_prepared_statement( @@ -1032,9 +998,10 @@ ) def raw_extrinsic_metadata_get( self, id: str, authority_type: str, authority_url: str, *, statement - ) -> Iterable[Row]: - return self._execute_with_retries( - statement, [id, authority_url, authority_type] + ) -> Iterable[RawExtrinsicMetadataRow]: + return map( + RawExtrinsicMetadataRow.from_dict, + self._execute_with_retries(statement, [id, authority_url, authority_type]), ) ########################## @@ -1045,8 +1012,6 @@ def check_read(self, *, statement): self._execute_with_retries(statement, []) - @_prepared_statement( - "SELECT object_type, count FROM object_count WHERE partition_key=0" - ) + @_prepared_statement("SELECT * FROM object_count WHERE partition_key=0") def stat_counters(self, *, statement) -> ResultSet: - return self._execute_with_retries(statement, []) + return map(ObjectCountRow.from_dict, self._execute_with_retries(statement, [])) diff --git a/swh/storage/cassandra/model.py b/swh/storage/cassandra/model.py new file mode 100644 --- /dev/null +++ b/swh/storage/cassandra/model.py @@ -0,0 +1,196 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at 
the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +"""Classes representing tables in the Cassandra database. + +They are very close to classes found in swh.model.model, but most of +them are subtly different: + +* Large objects are split into other classes (eg. RevisionRow has no + 'parents' field, because parents are stored in a different table, + represented by RevisionParentRow) +* They have a "cols" field, which returns the list of column names + of the table +* They only use types that map directly to Cassandra's schema (ie. no enums) + +Therefore, this model doesn't reuse swh.model.model, except for types +that can be mapped to UDTs (Person and TimestampWithTimezone). +""" + +import dataclasses +import datetime +from typing import Any, Dict, List, Optional, Type, TypeVar + +from swh.model.model import Person, TimestampWithTimezone + + +T = TypeVar("T", bound="BaseRow") + + +class BaseRow: + @classmethod + def from_dict(cls: Type[T], d: Dict[str, Any]) -> T: + return cls(**d) # type: ignore + + @classmethod + def cols(cls) -> List[str]: + return [field.name for field in dataclasses.fields(cls)] + + def to_dict(self) -> Dict[str, Any]: + return dataclasses.asdict(self) + + +@dataclasses.dataclass +class ContentRow(BaseRow): + sha1: bytes + sha1_git: bytes + sha256: bytes + blake2s256: bytes + length: int + ctime: datetime.datetime + status: str + + +@dataclasses.dataclass +class SkippedContentRow(BaseRow): + sha1: Optional[bytes] + sha1_git: Optional[bytes] + sha256: Optional[bytes] + blake2s256: Optional[bytes] + length: Optional[int] + ctime: Optional[datetime.datetime] + status: str + reason: str + origin: str + + +@dataclasses.dataclass +class DirectoryRow(BaseRow): + id: bytes + + +@dataclasses.dataclass +class DirectoryEntryRow(BaseRow): + directory_id: bytes + name: bytes + target: bytes + perms: int + type: str + + +@dataclasses.dataclass +class RevisionRow(BaseRow): + id: bytes + date: Optional[TimestampWithTimezone] + committer_date: Optional[TimestampWithTimezone] + type: str + directory: bytes + message: bytes + author: Person + committer: Person + synthetic: bool + metadata: str + extra_headers: dict + + +@dataclasses.dataclass +class RevisionParentRow(BaseRow): + id: bytes + parent_rank: int + parent_id: bytes + + +@dataclasses.dataclass +class ReleaseRow(BaseRow): + id: bytes + target_type: str + target: bytes + date: TimestampWithTimezone + name: bytes + message: bytes + author: Person + synthetic: bool + + +@dataclasses.dataclass +class SnapshotRow(BaseRow): + id: bytes + + +@dataclasses.dataclass +class SnapshotBranchRow(BaseRow): + snapshot_id: bytes + name: bytes + target_type: Optional[str] + target: Optional[bytes] + + +@dataclasses.dataclass +class OriginVisitRow(BaseRow): + origin: str + visit: int + date: datetime.datetime + type: str + + +@dataclasses.dataclass +class OriginVisitStatusRow(BaseRow): + origin: str + visit: int + date: datetime.datetime + status: str + metadata: str + snapshot: bytes + + +@dataclasses.dataclass +class OriginRow(BaseRow): + sha1: bytes + url: str + next_visit_id: int + + +@dataclasses.dataclass +class MetadataAuthorityRow(BaseRow): + url: str + type: str + metadata: str + + +@dataclasses.dataclass +class MetadataFetcherRow(BaseRow): + name: str + version: str + metadata: str + + +@dataclasses.dataclass +class RawExtrinsicMetadataRow(BaseRow): + type: str + id: str + + authority_type: str + authority_url: str + 
discovery_date: datetime.datetime + fetcher_name: str + fetcher_version: str + + format: str + metadata: bytes + + origin: Optional[str] + visit: Optional[int] + snapshot: Optional[str] + release: Optional[str] + revision: Optional[str] + path: Optional[bytes] + directory: Optional[str] + + +@dataclasses.dataclass +class ObjectCountRow(BaseRow): + partition_key: int + object_type: str + count: int diff --git a/swh/storage/cassandra/schema.py b/swh/storage/cassandra/schema.py --- a/swh/storage/cassandra/schema.py +++ b/swh/storage/cassandra/schema.py @@ -181,7 +181,6 @@ CREATE TABLE IF NOT EXISTS origin ( sha1 blob PRIMARY KEY, url text, - type text, next_visit_id int, -- We need integer visit ids for compatibility with the pgsql -- storage, so we're using lightweight transactions with this trick: diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -48,10 +48,24 @@ from swh.storage.utils import map_optional, now from ..exc import StorageArgumentException, HashCollision -from .common import TOKEN_BEGIN, TOKEN_END +from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys from . import converters from .cql import CqlRunner from .schema import HASH_ALGORITHMS +from .model import ( + ContentRow, + DirectoryEntryRow, + DirectoryRow, + MetadataAuthorityRow, + MetadataFetcherRow, + OriginRow, + OriginVisitRow, + RawExtrinsicMetadataRow, + RevisionParentRow, + SkippedContentRow, + SnapshotBranchRow, + SnapshotRow, +) # Max block size of contents to return @@ -139,7 +153,9 @@ collisions.append(content.hashes()) raise HashCollision(algo, content.get_hash(algo), collisions) - (token, insertion_finalizer) = self._cql_runner.content_add_prepare(content) + (token, insertion_finalizer) = self._cql_runner.content_add_prepare( + ContentRow(**remove_keys(content.to_dict(), ("data",))) + ) # Then add to index tables for algo in HASH_ALGORITHMS: @@ -203,14 +219,12 @@ range_start, range_end, limit + 1 ) contents = [] - last_id: Optional[int] = None - for counter, row in enumerate(rows): + for counter, (tok, row) in enumerate(rows): if row.status == "absent": continue - row_d = row._asdict() - last_id = row_d.pop("tok") + row_d = row.to_dict() if counter >= limit: - next_page_token = str(last_id) + next_page_token = str(tok) break contents.append(Content(**row_d)) @@ -223,7 +237,7 @@ # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1 # matches the argument, from the index table ('content_by_sha1') for row in self._content_get_from_hash("sha1", sha1): - row_d = row._asdict() + row_d = row.to_dict() row_d.pop("ctime") content = Content(**row_d) contents_by_sha1[content.sha1] = content @@ -251,7 +265,7 @@ break else: # All hashes match, keep this row. 
- row_d = row._asdict() + row_d = row.to_dict() row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc) results.append(Content(**row_d)) return results @@ -314,7 +328,7 @@ for content in contents: # Compute token of the row in the main table (token, insertion_finalizer) = self._cql_runner.skipped_content_add_prepare( - content + SkippedContentRow.from_dict({"origin": None, **content.to_dict()}) ) # Then add to index tables @@ -348,13 +362,13 @@ # Add directory entries to the 'directory_entry' table for entry in directory.entries: self._cql_runner.directory_entry_add_one( - {**entry.to_dict(), "directory_id": directory.id} + DirectoryEntryRow(directory_id=directory.id, **entry.to_dict()) ) # Add the directory *after* adding all the entries, so someone # calling snapshot_get_branch in the meantime won't end up # with half the entries. - self._cql_runner.directory_add_one(directory.id) + self._cql_runner.directory_add_one(DirectoryRow(id=directory.id)) return {"directory:add": len(directories)} @@ -387,10 +401,10 @@ rows = list(self._cql_runner.directory_entry_get([directory_id])) for row in rows: + entry_d = row.to_dict() # Build and yield the directory entry dict - entry = row._asdict() - del entry["directory_id"] - entry = DirectoryEntry.from_dict(entry) + del entry_d["directory_id"] + entry = DirectoryEntry.from_dict(entry_d) ret = self._join_dentry_to_content(entry) ret["name"] = prefix + ret["name"] ret["dir_id"] = directory_id @@ -458,9 +472,11 @@ revobject = converters.revision_to_db(revision) if revobject: # Add parents first - for (rank, parent) in enumerate(revobject["parents"]): + for (rank, parent) in enumerate(revision.parents): self._cql_runner.revision_parent_add_one( - revobject["id"], rank, parent + RevisionParentRow( + id=revobject.id, parent_rank=rank, parent_id=parent + ) ) # Then write the main revision row. @@ -484,10 +500,9 @@ # (it might have lower latency, but requires more code and more # bandwidth, because revision id would be part of each returned # row) - parent_rows = self._cql_runner.revision_parent_get(row.id) + parents = tuple(self._cql_runner.revision_parent_get(row.id)) # parent_rank is the clustering key, so results are already # sorted by rank. - parents = tuple(row.parent_id for row in parent_rows) rev = converters.revision_from_db(row, parents=parents) revs[rev.id] = rev.to_dict() @@ -514,27 +529,35 @@ # results (ie. not return only a subset of a revision's parents # if it is being written) if short: - rows = self._cql_runner.revision_get_ids(rev_ids) + ids = self._cql_runner.revision_get_ids(rev_ids) + for id_ in ids: + # TODO: use a single query to get all parents? + # (it might have less latency, but requires less code and more + # bandwidth (because revision id would be part of each returned + # row) + parents = tuple(self._cql_runner.revision_parent_get(id_)) + + # parent_rank is the clustering key, so results are already + # sorted by rank. + + yield (id_, parents) + yield from self._get_parent_revs(parents, seen, limit, short) else: rows = self._cql_runner.revision_get(rev_ids) - for row in rows: - # TODO: use a single query to get all parents? - # (it might have less latency, but requires less code and more - # bandwidth (because revision id would be part of each returned - # row) - parent_rows = self._cql_runner.revision_parent_get(row.id) + for row in rows: + # TODO: use a single query to get all parents? 
+ # (it might have less latency, but requires less code and more + # bandwidth (because revision id would be part of each returned + # row) + parents = tuple(self._cql_runner.revision_parent_get(row.id)) - # parent_rank is the clustering key, so results are already - # sorted by rank. - parents = tuple(row.parent_id for row in parent_rows) + # parent_rank is the clustering key, so results are already + # sorted by rank. - if short: - yield (row.id, parents) - else: rev = converters.revision_from_db(row, parents=parents) yield rev.to_dict() - yield from self._get_parent_revs(parents, seen, limit, short) + yield from self._get_parent_revs(parents, seen, limit, short) def revision_log( self, revisions: List[Sha1Git], limit: Optional[int] = None @@ -595,24 +618,24 @@ # Add branches for (branch_name, branch) in snapshot.branches.items(): if branch is None: - target_type = None - target = None + target_type: Optional[str] = None + target: Optional[bytes] = None else: target_type = branch.target_type.value target = branch.target self._cql_runner.snapshot_branch_add_one( - { - "snapshot_id": snapshot.id, - "name": branch_name, - "target_type": target_type, - "target": target, - } + SnapshotBranchRow( + snapshot_id=snapshot.id, + name=branch_name, + target_type=target_type, + target=target, + ) ) # Add the snapshot *after* adding all the branches, so someone # calling snapshot_get_branch in the meantime won't end up # with half the branches. - self._cql_runner.snapshot_add_one(snapshot.id) + self._cql_runner.snapshot_add_one(SnapshotRow(id=snapshot.id)) return {"snapshot:add": len(snapshots)} @@ -647,13 +670,7 @@ # Makes sure we don't fetch branches for a snapshot that is # being added. return None - rows = list(self._cql_runner.snapshot_count_branches(snapshot_id)) - assert len(rows) == 1 - (nb_none, counts) = rows[0].counts - counts = dict(counts) - if nb_none: - counts[None] = nb_none - return counts + return self._cql_runner.snapshot_count_branches(snapshot_id) def snapshot_get_branches( self, @@ -760,8 +777,8 @@ def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]: results = [] for sha1 in sha1s: - rows = self._cql_runner.origin_get_by_sha1(sha1) - origin = {"url": rows.one().url} if rows else None + rows = list(self._cql_runner.origin_get_by_sha1(sha1)) + origin = {"url": rows[0].url} if rows else None results.append(origin) return results @@ -778,10 +795,10 @@ origins = [] # Take one more origin so we can reuse it as the next page token if any - for row in self._cql_runner.origin_list(start_token, limit + 1): + for (tok, row) in self._cql_runner.origin_list(start_token, limit + 1): origins.append(Origin(url=row.url)) # keep reference of the last id for pagination purposes - last_id = row.tok + last_id = tok if len(origins) > limit: # last origin id is the next page token @@ -805,15 +822,17 @@ next_page_token = None offset = int(page_token) if page_token else 0 - origins = self._cql_runner.origin_iter_all() + origin_rows = [row for row in self._cql_runner.origin_iter_all()] if regexp: pat = re.compile(url_pattern) - origins = [Origin(orig.url) for orig in origins if pat.search(orig.url)] + origin_rows = [row for row in origin_rows if pat.search(row.url)] else: - origins = [Origin(orig.url) for orig in origins if url_pattern in orig.url] + origin_rows = [row for row in origin_rows if url_pattern in row.url] if with_visit: - origins = [Origin(orig.url) for orig in origins if orig.next_visit_id > 1] + origin_rows = [row for row in origin_rows if row.next_visit_id > 
1] + + origins = [Origin(url=row.url) for row in origin_rows] origins = origins[offset : offset + limit + 1] if len(origins) > limit: @@ -829,7 +848,9 @@ to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None] self.journal_writer.origin_add(to_add) for origin in to_add: - self._cql_runner.origin_add_one(origin) + self._cql_runner.origin_add_one( + OriginRow(sha1=hash_url(origin.url), url=origin.url, next_visit_id=1) + ) return {"origin:add": len(to_add)} def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: @@ -848,7 +869,7 @@ ) visit = attr.evolve(visit, visit=visit_id) self.journal_writer.origin_visit_add([visit]) - self._cql_runner.origin_visit_add_one(visit) + self._cql_runner.origin_visit_add_one(OriginVisitRow(**visit.to_dict())) assert visit.visit is not None all_visits.append(visit) self._origin_visit_status_add( @@ -866,7 +887,9 @@ def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None: """Add an origin visit status""" self.journal_writer.origin_visit_status_add([visit_status]) - self._cql_runner.origin_visit_status_add_one(visit_status) + self._cql_runner.origin_visit_status_add_one( + converters.visit_status_to_row(visit_status) + ) def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus]) -> None: # First round to check existence (fail early if any is ko) @@ -912,7 +935,7 @@ @staticmethod def _format_origin_visit_row(visit): return { - **visit._asdict(), + **visit.to_dict(), "origin": visit.origin, "date": visit.date.replace(tzinfo=datetime.timezone.utc), } @@ -1048,8 +1071,10 @@ f"Unknown allowed statuses {','.join(allowed_statuses)}, only " f"{','.join(VISIT_STATUSES)} authorized" ) - rows = self._cql_runner.origin_visit_status_get( - origin_url, visit, allowed_statuses, require_snapshot + rows = list( + self._cql_runner.origin_visit_status_get( + origin_url, visit, allowed_statuses, require_snapshot + ) ) # filtering is done python side as we cannot do it server side if allowed_statuses: @@ -1113,7 +1138,7 @@ ) try: - self._cql_runner.raw_extrinsic_metadata_add( + row = RawExtrinsicMetadataRow( type=metadata_entry.type.value, id=str(metadata_entry.id), authority_type=metadata_entry.authority.type.value, @@ -1131,6 +1156,7 @@ path=metadata_entry.path, directory=map_optional(str, metadata_entry.directory), ) + self._cql_runner.raw_extrinsic_metadata_add(row) except TypeError as e: raise StorageArgumentException(*e.args) @@ -1236,9 +1262,11 @@ self.journal_writer.metadata_fetcher_add(fetchers) for fetcher in fetchers: self._cql_runner.metadata_fetcher_add( - fetcher.name, - fetcher.version, - json.dumps(map_optional(dict, fetcher.metadata)), + MetadataFetcherRow( + name=fetcher.name, + version=fetcher.version, + metadata=json.dumps(map_optional(dict, fetcher.metadata)), + ) ) def metadata_fetcher_get( @@ -1258,9 +1286,11 @@ self.journal_writer.metadata_authority_add(authorities) for authority in authorities: self._cql_runner.metadata_authority_add( - authority.url, - authority.type.value, - json.dumps(map_optional(dict, authority.metadata)), + MetadataAuthorityRow( + url=authority.url, + type=authority.type.value, + metadata=json.dumps(map_optional(dict, authority.metadata)), + ) ) def metadata_authority_get( diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -629,7 +629,9 @@ else: return None - def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Optional[Dict[str, int]]: + def snapshot_count_branches( + self, 
snapshot_id: Sha1Git + ) -> Optional[Dict[Optional[str], int]]: snapshot = self._snapshots[snapshot_id] return collections.Counter( branch.target_type.value if branch else None diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -715,7 +715,9 @@ ... @remote_api_endpoint("snapshot/count_branches") - def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Optional[Dict[str, int]]: + def snapshot_count_branches( + self, snapshot_id: Sha1Git + ) -> Optional[Dict[Optional[str], int]]: """Count the number of branches in the snapshot with the given id Args: diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -771,7 +771,7 @@ @db_transaction(statement_timeout=2000) def snapshot_count_branches( self, snapshot_id: Sha1Git, db=None, cur=None - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[Optional[str], int]]: return dict([bc for bc in db.snapshot_count_branches(snapshot_id, cur)]) @timed diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py --- a/swh/storage/tests/test_cassandra.py +++ b/swh/storage/tests/test_cassandra.py @@ -3,22 +3,22 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import attr +import datetime import os import signal import socket import subprocess import time - -from collections import namedtuple from typing import Dict +import attr import pytest from swh.core.api.classes import stream_results from swh.storage import get_storage from swh.storage.cassandra import create_keyspace from swh.storage.cassandra.schema import TABLES, HASH_ALGORITHMS +from swh.storage.cassandra.model import ContentRow from swh.storage.utils import now from swh.storage.tests.test_storage import TestStorage as _TestStorage @@ -209,12 +209,17 @@ ) # For all tokens, always return cont - Row = namedtuple("Row", HASH_ALGORITHMS) - def mock_cgft(token): nonlocal called called += 1 - return [Row(**{algo: getattr(cont, algo) for algo in HASH_ALGORITHMS})] + return [ + ContentRow( + length=10, + ctime=datetime.datetime.now(), + status="present", + **{algo: getattr(cont, algo) for algo in HASH_ALGORITHMS}, + ) + ] mocker.patch.object( swh_storage._cql_runner, "content_get_from_token", mock_cgft @@ -253,13 +258,12 @@ # For all tokens, always return cont and cont2 cols = list(set(cont.to_dict()) - {"data"}) - Row = namedtuple("Row", cols) def mock_cgft(token): nonlocal called called += 1 return [ - Row(**{col: getattr(cont, col) for col in cols}) + ContentRow(**{col: getattr(cont, col) for col in cols},) for cont in [cont, cont2] ] @@ -299,13 +303,12 @@ # For all tokens, always return cont and cont2 cols = list(set(cont.to_dict()) - {"data"}) - Row = namedtuple("Row", cols) def mock_cgft(token): nonlocal called called += 1 return [ - Row(**{col: getattr(cont, col) for col in cols}) + ContentRow(**{col: getattr(cont, col) for col in cols}) for cont in [cont, cont2] ] @@ -342,16 +345,15 @@ rows[tok] = row_d # For all tokens, always return cont - keys = set(["tok"] + list(content.to_dict().keys())).difference(set(["data"])) - Row = namedtuple("Row", keys) def mock_content_get_token_range(range_start, range_end, limit): nonlocal called called += 1 for tok in list(rows.keys()) * 3: # yield multiple times the same tok - row_d = rows[tok] - yield Row(**row_d) + row_d = dict(rows[tok].items()) + row_d.pop("tok") + yield (tok, ContentRow(**row_d)) 
mocker.patch.object( swh_storage._cql_runner,