diff --git a/swh/storage/cassandra/common.py b/swh/storage/cassandra/common.py
--- a/swh/storage/cassandra/common.py
+++ b/swh/storage/cassandra/common.py
@@ -5,9 +5,7 @@
import hashlib
-
-
-Row = tuple
+from typing import Any, Dict, Tuple
TOKEN_BEGIN = -(2 ** 63)
@@ -18,3 +16,7 @@
def hash_url(url: str) -> bytes:
return hashlib.sha1(url.encode("ascii")).digest()
+
+
+def remove_keys(d: Dict[str, Any], keys: Tuple[str, ...]) -> Dict[str, Any]:
+ return {k: v for (k, v) in d.items() if k not in keys}
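
For reference, a minimal sketch of how the new remove_keys helper behaves (dict values are illustrative):

    row = {"sha1": b"\x01", "tok": 42, "length": 3}
    assert remove_keys(row, ("tok",)) == {"sha1": b"\x01", "length": 3}
    # a new dict is returned; the input dict is left untouched
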
diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py
--- a/swh/storage/cassandra/converters.py
+++ b/swh/storage/cassandra/converters.py
@@ -8,7 +8,7 @@
import attr
from copy import deepcopy
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple
from swh.model.model import (
ObjectType,
@@ -21,10 +21,11 @@
)
from swh.model.hashutil import DEFAULT_ALGORITHMS
-from .common import Row
+from .common import remove_keys
+from .model import (
+    ContentRow,
+    OriginVisitRow,
+    OriginVisitStatusRow,
+    ReleaseRow,
+    RevisionRow,
+)
-def revision_to_db(revision: Revision) -> Dict[str, Any]:
+def revision_to_db(revision: Revision) -> RevisionRow:
# we use a deepcopy of the dict because we do not want to recurse the
# Model->dict conversion (to keep Timestamp & al. entities), BUT we do not
# want to modify original metadata (embedded in the Model entity), so we
@@ -39,11 +40,13 @@
)
db_revision["extra_headers"] = extra_headers
db_revision["type"] = db_revision["type"].value
- return db_revision
+ return RevisionRow(**remove_keys(db_revision, ("parents",)))
-def revision_from_db(db_revision: Row, parents: Tuple[Sha1Git, ...]) -> Revision:
- revision = db_revision._asdict() # type: ignore
+def revision_from_db(
+ db_revision: RevisionRow, parents: Tuple[Sha1Git, ...]
+) -> Revision:
+ revision = db_revision.to_dict()
metadata = json.loads(revision.pop("metadata", None))
extra_headers = revision.pop("extra_headers", ())
if not extra_headers and metadata and "extra_headers" in metadata:
@@ -59,18 +62,18 @@
)
-def release_to_db(release: Release) -> Dict[str, Any]:
+def release_to_db(release: Release) -> ReleaseRow:
db_release = attr.asdict(release, recurse=False)
db_release["target_type"] = db_release["target_type"].value
- return db_release
+ return ReleaseRow(**remove_keys(db_release, ("metadata",)))
-def release_from_db(db_release: Row) -> Release:
- release = db_release._asdict() # type: ignore
+def release_from_db(db_release: ReleaseRow) -> Release:
+ release = db_release.to_dict()
return Release(target_type=ObjectType(release.pop("target_type")), **release,)
-def row_to_content_hashes(row: Row) -> Dict[str, bytes]:
+def row_to_content_hashes(row: ContentRow) -> Dict[str, bytes]:
"""Convert cassandra row to a content hashes
"""
@@ -80,7 +83,7 @@
return hashes
-def row_to_visit(row) -> OriginVisit:
+def row_to_visit(row: OriginVisitRow) -> OriginVisit:
"""Format a row representing an origin_visit to an actual OriginVisit.
"""
@@ -92,15 +95,19 @@
)
-def row_to_visit_status(row) -> OriginVisitStatus:
+def row_to_visit_status(row: OriginVisitStatusRow) -> OriginVisitStatus:
"""Format a row representing a visit_status to an actual OriginVisitStatus.
"""
return OriginVisitStatus.from_dict(
{
- **row._asdict(),
- "origin": row.origin,
+ **row.to_dict(),
"date": row.date.replace(tzinfo=datetime.timezone.utc),
"metadata": (json.loads(row.metadata) if row.metadata else None),
}
)
+
+
+def visit_status_to_row(status: OriginVisitStatus) -> OriginVisitStatusRow:
+ d = status.to_dict()
+ return OriginVisitStatusRow.from_dict({**d, "metadata": json.dumps(d["metadata"])})
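
For reference, a rough sketch of the new visit-status conversion, using only the fields defined in this patch (values are illustrative):

    import datetime
    from swh.storage.cassandra.converters import row_to_visit_status
    from swh.storage.cassandra.model import OriginVisitStatusRow

    row = OriginVisitStatusRow(
        origin="https://example.org/repo.git",
        visit=1,
        date=datetime.datetime(2020, 8, 1, 12, 0, 0),  # naive; converter adds UTC
        status="partial",
        metadata='{"key": "value"}',  # stored as a JSON string in Cassandra
        snapshot=b"\x00" * 20,
    )
    status = row_to_visit_status(row)
    assert status.metadata == {"key": "value"}  # decoded back to a dict
    assert status.date.tzinfo == datetime.timezone.utc
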
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -3,9 +3,9 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+import dataclasses
import datetime
import functools
-import json
import logging
import random
from typing import (
@@ -17,13 +17,14 @@
List,
Optional,
Tuple,
+ Type,
TypeVar,
)
from cassandra import CoordinationFailure
from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
-from cassandra.query import PreparedStatement, BoundStatement
+from cassandra.query import PreparedStatement, BoundStatement, dict_factory
from tenacity import (
retry,
stop_after_attempt,
@@ -32,20 +33,36 @@
)
from swh.model.model import (
+ Content,
+ SkippedContent,
Sha1Git,
TimestampWithTimezone,
Timestamp,
Person,
- Content,
- SkippedContent,
- OriginVisit,
- OriginVisitStatus,
- Origin,
)
from swh.storage.interface import ListOrder
-from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url
+from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys
+from .model import (
+ BaseRow,
+ ContentRow,
+ DirectoryEntryRow,
+ DirectoryRow,
+ MetadataAuthorityRow,
+ MetadataFetcherRow,
+ ObjectCountRow,
+ OriginRow,
+ OriginVisitRow,
+ OriginVisitStatusRow,
+ RawExtrinsicMetadataRow,
+ ReleaseRow,
+ RevisionParentRow,
+ RevisionRow,
+ SkippedContentRow,
+ SnapshotBranchRow,
+ SnapshotRow,
+)
from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS
@@ -54,7 +71,8 @@
_execution_profiles = {
EXEC_PROFILE_DEFAULT: ExecutionProfile(
- load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy())
+ load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
+ row_factory=dict_factory,
),
}
# Configuration for cassandra-driver's access to servers:
@@ -166,11 +184,14 @@
) -> None:
self._execute_with_retries(statement, [nb, object_type])
- def _add_one(self, statement, object_type: str, obj, keys: List[str]) -> None:
- self._increment_counter(object_type, 1)
- self._execute_with_retries(statement, [getattr(obj, key) for key in keys])
+ def _add_one(self, statement, object_type: Optional[str], obj: BaseRow) -> None:
+ if object_type:
+ self._increment_counter(object_type, 1)
+ self._execute_with_retries(statement, dataclasses.astuple(obj))
- def _get_random_row(self, statement) -> Optional[Row]:
+ _T = TypeVar("_T", bound=BaseRow)
+
+ def _get_random_row(self, row_class: Type[_T], statement) -> Optional[_T]: # noqa
"""Takes a prepared statement of the form
"SELECT * FROM
WHERE token() > ? LIMIT 1"
and uses it to return a random row"""
@@ -181,13 +202,13 @@
# the row with the smallest token
rows = self._execute_with_retries(statement, [TOKEN_BEGIN])
if rows:
- return rows.one()
+ return row_class.from_dict(rows.one()) # type: ignore
else:
return None
def _missing(self, statement, ids):
- res = self._execute_with_retries(statement, [ids])
- found_ids = {id_ for (id_,) in res}
+ rows = self._execute_with_retries(statement, [ids])
+ found_ids = {row["id"] for row in rows}
return [id_ for id_ in ids if id_ not in found_ids]
##########################
@@ -195,15 +216,6 @@
##########################
_content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"]
- _content_keys = [
- "sha1",
- "sha1_git",
- "sha256",
- "blake2s256",
- "length",
- "ctime",
- "status",
- ]
def _content_add_finalize(self, statement: BoundStatement) -> None:
"""Returned currified by content_add_prepare, to be called when the
@@ -211,16 +223,14 @@
self._execute_with_retries(statement, None)
self._increment_counter("content", 1)
- @_prepared_insert_statement("content", _content_keys)
+ @_prepared_insert_statement("content", ContentRow.cols())
def content_add_prepare(
- self, content, *, statement
+ self, content: ContentRow, *, statement
) -> Tuple[int, Callable[[], None]]:
"""Prepares insertion of a Content to the main 'content' table.
Returns a token (to be used in secondary tables), and a function to be
called to perform the insertion in the main table."""
- statement = statement.bind(
- [getattr(content, key) for key in self._content_keys]
- )
+ statement = statement.bind(dataclasses.astuple(content))
# Type used for hashing keys (usually, it will be
# cassandra.metadata.Murmur3Token)
@@ -245,7 +255,7 @@
)
def content_get_from_pk(
self, content_hashes: Dict[str, bytes], *, statement
- ) -> Optional[Row]:
+ ) -> Optional[ContentRow]:
rows = list(
self._execute_with_retries(
statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]
@@ -253,38 +263,44 @@
)
assert len(rows) <= 1
if rows:
- return rows[0]
+ return ContentRow(**rows[0])
else:
return None
@_prepared_statement(
"SELECT * FROM content WHERE token(" + ", ".join(_content_pk) + ") = ?"
)
- def content_get_from_token(self, token, *, statement) -> Iterable[Row]:
- return self._execute_with_retries(statement, [token])
+ def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]:
+ return map(ContentRow.from_dict, self._execute_with_retries(statement, [token]))
@_prepared_statement(
"SELECT * FROM content WHERE token(%s) > ? LIMIT 1" % ", ".join(_content_pk)
)
- def content_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def content_get_random(self, *, statement) -> Optional[ContentRow]:
+ return self._get_random_row(ContentRow, statement)
@_prepared_statement(
(
"SELECT token({0}) AS tok, {1} FROM content "
"WHERE token({0}) >= ? AND token({0}) <= ? LIMIT ?"
- ).format(", ".join(_content_pk), ", ".join(_content_keys))
+ ).format(", ".join(_content_pk), ", ".join(ContentRow.cols()))
)
def content_get_token_range(
self, start: int, end: int, limit: int, *, statement
- ) -> Iterable[Row]:
- return self._execute_with_retries(statement, [start, end, limit])
+ ) -> Iterable[Tuple[int, ContentRow]]:
+ """Returns an iterable of (token, row)"""
+ return (
+ (row["tok"], ContentRow.from_dict(remove_keys(row, ("tok",))))
+ for row in self._execute_with_retries(statement, [start, end, limit])
+ )
##########################
# 'content_by_*' tables
##########################
- @_prepared_statement("SELECT sha1_git FROM content_by_sha1_git WHERE sha1_git IN ?")
+ @_prepared_statement(
+ "SELECT sha1_git AS id FROM content_by_sha1_git WHERE sha1_git IN ?"
+ )
def content_missing_by_sha1_git(
self, ids: List[bytes], *, statement
) -> List[bytes]:
@@ -303,24 +319,15 @@
) -> Iterable[int]:
assert algo in HASH_ALGORITHMS
query = f"SELECT target_token FROM content_by_{algo} WHERE {algo} = %s"
- return (tok for (tok,) in self._execute_with_retries(query, [hash_]))
+ return (
+ row["target_token"] for row in self._execute_with_retries(query, [hash_])
+ )
##########################
# 'skipped_content' table
##########################
_skipped_content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"]
- _skipped_content_keys = [
- "sha1",
- "sha1_git",
- "sha256",
- "blake2s256",
- "length",
- "ctime",
- "status",
- "reason",
- "origin",
- ]
_magic_null_pk = b""
"""
NULLs (or all-empty blobs) are not allowed in primary keys; instead use a
@@ -333,7 +340,7 @@
self._execute_with_retries(statement, None)
self._increment_counter("skipped_content", 1)
- @_prepared_insert_statement("skipped_content", _skipped_content_keys)
+ @_prepared_insert_statement("skipped_content", SkippedContentRow.cols())
def skipped_content_add_prepare(
self, content, *, statement
) -> Tuple[int, Callable[[], None]]:
@@ -343,14 +350,11 @@
# Replace NULLs (which are not allowed in the partition key) with
# an empty byte string
- content = content.to_dict()
for key in self._skipped_content_pk:
- if content[key] is None:
- content[key] = self._magic_null_pk
+ if getattr(content, key) is None:
+ setattr(content, key, self._magic_null_pk)
- statement = statement.bind(
- [content.get(key) for key in self._skipped_content_keys]
- )
+ statement = statement.bind(dataclasses.astuple(content))
# Type used for hashing keys (usually, it will be
# cassandra.metadata.Murmur3Token)
@@ -376,7 +380,7 @@
)
def skipped_content_get_from_pk(
self, content_hashes: Dict[str, bytes], *, statement
- ) -> Optional[Row]:
+ ) -> Optional[SkippedContentRow]:
rows = list(
self._execute_with_retries(
statement,
@@ -389,7 +393,7 @@
assert len(rows) <= 1
if rows:
# TODO: convert _magic_null_pk back to None?
- return rows[0]
+ return SkippedContentRow.from_dict(rows[0])
else:
return None
@@ -414,218 +418,198 @@
# 'revision' table
##########################
- _revision_keys = [
- "id",
- "date",
- "committer_date",
- "type",
- "directory",
- "message",
- "author",
- "committer",
- "synthetic",
- "metadata",
- "extra_headers",
- ]
-
@_prepared_exists_statement("revision")
def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
- @_prepared_insert_statement("revision", _revision_keys)
- def revision_add_one(self, revision: Dict[str, Any], *, statement) -> None:
- self._execute_with_retries(
- statement, [revision[key] for key in self._revision_keys]
- )
- self._increment_counter("revision", 1)
+ @_prepared_insert_statement("revision", RevisionRow.cols())
+ def revision_add_one(self, revision: RevisionRow, *, statement) -> None:
+ self._add_one(statement, "revision", revision)
@_prepared_statement("SELECT id FROM revision WHERE id IN ?")
- def revision_get_ids(self, revision_ids, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [revision_ids])
+ def revision_get_ids(self, revision_ids, *, statement) -> Iterable[Sha1Git]:
+ return (
+ row["id"] for row in self._execute_with_retries(statement, [revision_ids])
+ )
@_prepared_statement("SELECT * FROM revision WHERE id IN ?")
- def revision_get(self, revision_ids, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [revision_ids])
+ def revision_get(self, revision_ids, *, statement) -> Iterable[RevisionRow]:
+ return map(
+ RevisionRow.from_dict, self._execute_with_retries(statement, [revision_ids])
+ )
@_prepared_statement("SELECT * FROM revision WHERE token(id) > ? LIMIT 1")
- def revision_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def revision_get_random(self, *, statement) -> Optional[RevisionRow]:
+ return self._get_random_row(RevisionRow, statement)
##########################
# 'revision_parent' table
##########################
- _revision_parent_keys = ["id", "parent_rank", "parent_id"]
-
- @_prepared_insert_statement("revision_parent", _revision_parent_keys)
+ @_prepared_insert_statement("revision_parent", RevisionParentRow.cols())
def revision_parent_add_one(
- self, id_: Sha1Git, parent_rank: int, parent_id: Sha1Git, *, statement
+ self, revision_parent: RevisionParentRow, *, statement
) -> None:
- self._execute_with_retries(statement, [id_, parent_rank, parent_id])
+ self._add_one(statement, None, revision_parent)
@_prepared_statement("SELECT parent_id FROM revision_parent WHERE id = ?")
- def revision_parent_get(self, revision_id: Sha1Git, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [revision_id])
+ def revision_parent_get(
+ self, revision_id: Sha1Git, *, statement
+ ) -> Iterable[bytes]:
+ return (
+ row["parent_id"]
+ for row in self._execute_with_retries(statement, [revision_id])
+ )
##########################
# 'release' table
##########################
- _release_keys = [
- "id",
- "target",
- "target_type",
- "date",
- "name",
- "message",
- "author",
- "synthetic",
- ]
-
@_prepared_exists_statement("release")
def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
- @_prepared_insert_statement("release", _release_keys)
- def release_add_one(self, release: Dict[str, Any], *, statement) -> None:
- self._execute_with_retries(
- statement, [release[key] for key in self._release_keys]
- )
- self._increment_counter("release", 1)
+ @_prepared_insert_statement("release", ReleaseRow.cols())
+ def release_add_one(self, release: ReleaseRow, *, statement) -> None:
+ self._add_one(statement, "release", release)
@_prepared_statement("SELECT * FROM release WHERE id in ?")
- def release_get(self, release_ids: List[str], *, statement) -> None:
- return self._execute_with_retries(statement, [release_ids])
+ def release_get(self, release_ids: List[str], *, statement) -> Iterable[ReleaseRow]:
+ return map(
+ ReleaseRow.from_dict, self._execute_with_retries(statement, [release_ids])
+ )
@_prepared_statement("SELECT * FROM release WHERE token(id) > ? LIMIT 1")
- def release_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def release_get_random(self, *, statement) -> Optional[ReleaseRow]:
+ return self._get_random_row(ReleaseRow, statement)
##########################
# 'directory' table
##########################
- _directory_keys = ["id"]
-
@_prepared_exists_statement("directory")
def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
- @_prepared_insert_statement("directory", _directory_keys)
- def directory_add_one(self, directory_id: Sha1Git, *, statement) -> None:
+ @_prepared_insert_statement("directory", DirectoryRow.cols())
+ def directory_add_one(self, directory: DirectoryRow, *, statement) -> None:
"""Called after all calls to directory_entry_add_one, to
commit/finalize the directory."""
- self._execute_with_retries(statement, [directory_id])
- self._increment_counter("directory", 1)
+ self._add_one(statement, "directory", directory)
@_prepared_statement("SELECT * FROM directory WHERE token(id) > ? LIMIT 1")
- def directory_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def directory_get_random(self, *, statement) -> Optional[DirectoryRow]:
+ return self._get_random_row(DirectoryRow, statement)
##########################
# 'directory_entry' table
##########################
- _directory_entry_keys = ["directory_id", "name", "type", "target", "perms"]
-
- @_prepared_insert_statement("directory_entry", _directory_entry_keys)
- def directory_entry_add_one(self, entry: Dict[str, Any], *, statement) -> None:
- self._execute_with_retries(
- statement, [entry[key] for key in self._directory_entry_keys]
- )
+ @_prepared_insert_statement("directory_entry", DirectoryEntryRow.cols())
+ def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None:
+ self._add_one(statement, None, entry)
@_prepared_statement("SELECT * FROM directory_entry WHERE directory_id IN ?")
- def directory_entry_get(self, directory_ids, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [directory_ids])
+ def directory_entry_get(
+ self, directory_ids, *, statement
+ ) -> Iterable[DirectoryEntryRow]:
+ return map(
+ DirectoryEntryRow.from_dict,
+ self._execute_with_retries(statement, [directory_ids]),
+ )
##########################
# 'snapshot' table
##########################
- _snapshot_keys = ["id"]
-
@_prepared_exists_statement("snapshot")
def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
- @_prepared_insert_statement("snapshot", _snapshot_keys)
- def snapshot_add_one(self, snapshot_id: Sha1Git, *, statement) -> None:
- self._execute_with_retries(statement, [snapshot_id])
- self._increment_counter("snapshot", 1)
+ @_prepared_insert_statement("snapshot", SnapshotRow.cols())
+ def snapshot_add_one(self, snapshot: SnapshotRow, *, statement) -> None:
+ self._add_one(statement, "snapshot", snapshot)
@_prepared_statement("SELECT * FROM snapshot WHERE id = ?")
- def snapshot_get(self, snapshot_id: Sha1Git, *, statement) -> ResultSet:
+ def snapshot_get(
+ self, snapshot_id: Sha1Git, *, statement
+ ) -> Iterable[SnapshotRow]:
- return self._execute_with_retries(statement, [snapshot_id])
+ return map(
+ SnapshotRow.from_dict, self._execute_with_retries(statement, [snapshot_id])
+ )
@_prepared_statement("SELECT * FROM snapshot WHERE token(id) > ? LIMIT 1")
- def snapshot_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def snapshot_get_random(self, *, statement) -> Optional[SnapshotRow]:
+ return self._get_random_row(SnapshotRow, statement)
##########################
# 'snapshot_branch' table
##########################
- _snapshot_branch_keys = ["snapshot_id", "name", "target_type", "target"]
-
- @_prepared_insert_statement("snapshot_branch", _snapshot_branch_keys)
- def snapshot_branch_add_one(self, branch: Dict[str, Any], *, statement) -> None:
- self._execute_with_retries(
- statement, [branch[key] for key in self._snapshot_branch_keys]
- )
+ @_prepared_insert_statement("snapshot_branch", SnapshotBranchRow.cols())
+ def snapshot_branch_add_one(self, branch: SnapshotBranchRow, *, statement) -> None:
+ self._add_one(statement, None, branch)
@_prepared_statement(
"SELECT ascii_bins_count(target_type) AS counts "
"FROM snapshot_branch "
"WHERE snapshot_id = ? "
)
- def snapshot_count_branches(self, snapshot_id: Sha1Git, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [snapshot_id])
+ def snapshot_count_branches(
+ self, snapshot_id: Sha1Git, *, statement
+ ) -> Dict[Optional[str], int]:
+ """Returns a dictionary from type names to the number of branches
+ of that type."""
+ row = self._execute_with_retries(statement, [snapshot_id]).one()
+ (nb_none, counts) = row["counts"]
+ return {None: nb_none, **counts}
@_prepared_statement(
"SELECT * FROM snapshot_branch WHERE snapshot_id = ? AND name >= ? LIMIT ?"
)
def snapshot_branch_get(
self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement
- ) -> ResultSet:
- return self._execute_with_retries(statement, [snapshot_id, from_, limit])
+ ) -> Iterable[SnapshotBranchRow]:
+ return map(
+ SnapshotBranchRow.from_dict,
+ self._execute_with_retries(statement, [snapshot_id, from_, limit]),
+ )
##########################
# 'origin' table
##########################
- origin_keys = ["sha1", "url", "type", "next_visit_id"]
-
- @_prepared_statement(
- "INSERT INTO origin (sha1, url, next_visit_id) "
- "VALUES (?, ?, 1) IF NOT EXISTS"
- )
- def origin_add_one(self, origin: Origin, *, statement) -> None:
- self._execute_with_retries(statement, [hash_url(origin.url), origin.url])
- self._increment_counter("origin", 1)
+ @_prepared_insert_statement("origin", OriginRow.cols())
+ def origin_add_one(self, origin: OriginRow, *, statement) -> None:
+ self._add_one(statement, "origin", origin)
@_prepared_statement("SELECT * FROM origin WHERE sha1 = ?")
- def origin_get_by_sha1(self, sha1: bytes, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [sha1])
+ def origin_get_by_sha1(self, sha1: bytes, *, statement) -> Iterable[OriginRow]:
+ return map(OriginRow.from_dict, self._execute_with_retries(statement, [sha1]))
- def origin_get_by_url(self, url: str) -> ResultSet:
+ def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
return self.origin_get_by_sha1(hash_url(url))
@_prepared_statement(
- f'SELECT token(sha1) AS tok, {", ".join(origin_keys)} '
+ f'SELECT token(sha1) AS tok, {", ".join(OriginRow.cols())} '
f"FROM origin WHERE token(sha1) >= ? LIMIT ?"
)
- def origin_list(self, start_token: int, limit: int, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [start_token, limit])
+ def origin_list(
+ self, start_token: int, limit: int, *, statement
+ ) -> Iterable[Tuple[int, OriginRow]]:
+ """Returns an iterable of (token, origin)"""
+ return (
+ (row["tok"], OriginRow.from_dict(remove_keys(row, ("tok",))))
+ for row in self._execute_with_retries(statement, [start_token, limit])
+ )
@_prepared_statement("SELECT * FROM origin")
- def origin_iter_all(self, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [])
+ def origin_iter_all(self, *, statement) -> Iterable[OriginRow]:
+ return map(OriginRow.from_dict, self._execute_with_retries(statement, []))
@_prepared_statement("SELECT next_visit_id FROM origin WHERE sha1 = ?")
def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int:
rows = list(self._execute_with_retries(statement, [origin_sha1]))
assert len(rows) == 1 # TODO: error handling
- return rows[0].next_visit_id
+ return rows[0]["next_visit_id"]
@_prepared_statement(
"UPDATE origin SET next_visit_id=? WHERE sha1 = ? IF next_visit_id=?"
@@ -640,12 +624,12 @@
)
)
assert len(res) == 1
- if res[0].applied:
+ if res[0]["[applied]"]:
# No data race
return next_id
else:
# Someone else updated it before we did, let's try again
- next_id = res[0].next_visit_id
+ next_id = res[0]["next_visit_id"]
# TODO: abort after too many attempts
return next_id
@@ -654,13 +638,6 @@
# 'origin_visit' table
##########################
- _origin_visit_keys = [
- "origin",
- "visit",
- "type",
- "date",
- ]
-
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
"ORDER BY visit ASC"
@@ -737,7 +714,7 @@
last_visit: Optional[int],
limit: Optional[int],
order: ListOrder,
- ) -> ResultSet:
+ ) -> Iterable[OriginVisitRow]:
args: List[Any] = [origin_url]
if last_visit is not None:
@@ -754,7 +731,7 @@
method_name = f"_origin_visit_get_{page_name}_{order.value}_{limit_name}"
origin_visit_get_method = getattr(self, method_name)
- return origin_visit_get_method(*args)
+ return map(OriginVisitRow.from_dict, origin_visit_get_method(*args))
@_prepared_statement(
"SELECT * FROM origin_visit_status WHERE origin = ? "
@@ -817,7 +794,7 @@
date_from: Optional[datetime.datetime],
limit: int,
order: ListOrder,
- ) -> ResultSet:
+ ) -> Iterable[OriginVisitStatusRow]:
args: List[Any] = [origin, visit]
if date_from is not None:
@@ -830,41 +807,27 @@
method_name = f"_origin_visit_status_get_with_{date_name}_{order.value}_limit"
origin_visit_status_get_method = getattr(self, method_name)
- return origin_visit_status_get_method(*args)
-
- @_prepared_insert_statement("origin_visit", _origin_visit_keys)
- def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None:
- self._add_one(statement, "origin_visit", visit, self._origin_visit_keys)
-
- _origin_visit_status_keys = [
- "origin",
- "visit",
- "date",
- "status",
- "snapshot",
- "metadata",
- ]
-
- @_prepared_insert_statement("origin_visit_status", _origin_visit_status_keys)
+ return map(
+ OriginVisitStatusRow.from_dict, origin_visit_status_get_method(*args)
+ )
+
+ @_prepared_insert_statement("origin_visit", OriginVisitRow.cols())
+ def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None:
+ self._add_one(statement, "origin_visit", visit)
+
+ @_prepared_insert_statement("origin_visit_status", OriginVisitStatusRow.cols())
def origin_visit_status_add_one(
- self, visit_update: OriginVisitStatus, *, statement
+ self, visit_update: OriginVisitStatusRow, *, statement
) -> None:
- assert self._origin_visit_status_keys[-1] == "metadata"
- keys = self._origin_visit_status_keys
+ self._add_one(statement, None, visit_update)
- metadata = json.dumps(
- dict(visit_update.metadata) if visit_update.metadata is not None else None
- )
- self._execute_with_retries(
- statement, [getattr(visit_update, key) for key in keys[:-1]] + [metadata]
- )
-
- def origin_visit_status_get_latest(self, origin: str, visit: int,) -> Optional[Row]:
+ def origin_visit_status_get_latest(
+ self, origin: str, visit: int,
+ ) -> Optional[OriginVisitStatusRow]:
"""Given an origin visit id, return its latest origin_visit_status
"""
- rows = self.origin_visit_status_get(origin, visit)
- return rows[0] if rows else None
+ return next(self.origin_visit_status_get(origin, visit), None)
@_prepared_statement(
"SELECT * FROM origin_visit_status "
@@ -879,36 +842,52 @@
require_snapshot: bool = False,
*,
statement,
- ) -> List[Row]:
+ ) -> Iterator[OriginVisitStatusRow]:
"""Return all origin visit statuses for a given visit
"""
- return list(self._execute_with_retries(statement, [origin, visit]))
+ return map(
+ OriginVisitStatusRow.from_dict,
+ self._execute_with_retries(statement, [origin, visit]),
+ )
@_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? AND visit = ?")
def origin_visit_get_one(
self, origin_url: str, visit_id: int, *, statement
- ) -> Optional[Row]:
+ ) -> Optional[OriginVisitRow]:
# TODO: error handling
rows = list(self._execute_with_retries(statement, [origin_url, visit_id]))
if rows:
- return rows[0]
+ return OriginVisitRow.from_dict(rows[0])
else:
return None
@_prepared_statement("SELECT * FROM origin_visit WHERE origin = ?")
- def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [origin_url])
+ def origin_visit_get_all(
+ self, origin_url: str, *, statement
+ ) -> Iterable[OriginVisitRow]:
+ return map(
+ OriginVisitRow.from_dict,
+ self._execute_with_retries(statement, [origin_url]),
+ )
@_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) >= ?")
- def _origin_visit_iter_from(self, min_token: int, *, statement) -> Iterator[Row]:
- yield from self._execute_with_retries(statement, [min_token])
+ def _origin_visit_iter_from(
+ self, min_token: int, *, statement
+ ) -> Iterable[OriginVisitRow]:
+ return map(
+ OriginVisitRow.from_dict, self._execute_with_retries(statement, [min_token])
+ )
@_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) < ?")
- def _origin_visit_iter_to(self, max_token: int, *, statement) -> Iterator[Row]:
- yield from self._execute_with_retries(statement, [max_token])
+ def _origin_visit_iter_to(
+ self, max_token: int, *, statement
+ ) -> Iterable[OriginVisitRow]:
+ return map(
+ OriginVisitRow.from_dict, self._execute_with_retries(statement, [max_token])
+ )
- def origin_visit_iter(self, start_token: int) -> Iterator[Row]:
+ def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]:
"""Returns all origin visits in order from this token,
and wraps around the token space."""
yield from self._origin_visit_iter_from(start_token)
@@ -918,68 +897,49 @@
# 'metadata_authority' table
##########################
- _metadata_authority_keys = ["url", "type", "metadata"]
-
- @_prepared_insert_statement("metadata_authority", _metadata_authority_keys)
- def metadata_authority_add(self, url, type, metadata, *, statement):
- return self._execute_with_retries(statement, [url, type, metadata])
+ @_prepared_insert_statement("metadata_authority", MetadataAuthorityRow.cols())
+ def metadata_authority_add(self, authority: MetadataAuthorityRow, *, statement):
+ self._add_one(statement, None, authority)
@_prepared_statement("SELECT * from metadata_authority WHERE type = ? AND url = ?")
- def metadata_authority_get(self, type, url, *, statement) -> Optional[Row]:
- return next(iter(self._execute_with_retries(statement, [type, url])), None)
+ def metadata_authority_get(
+ self, type, url, *, statement
+ ) -> Optional[MetadataAuthorityRow]:
+ rows = list(self._execute_with_retries(statement, [type, url]))
+ if rows:
+ return MetadataAuthorityRow.from_dict(rows[0])
+ else:
+ return None
##########################
# 'metadata_fetcher' table
##########################
- _metadata_fetcher_keys = ["name", "version", "metadata"]
-
- @_prepared_insert_statement("metadata_fetcher", _metadata_fetcher_keys)
- def metadata_fetcher_add(self, name, version, metadata, *, statement):
- return self._execute_with_retries(statement, [name, version, metadata])
+ @_prepared_insert_statement("metadata_fetcher", MetadataFetcherRow.cols())
+ def metadata_fetcher_add(self, fetcher, *, statement):
+ self._add_one(statement, None, fetcher)
@_prepared_statement(
"SELECT * from metadata_fetcher WHERE name = ? AND version = ?"
)
- def metadata_fetcher_get(self, name, version, *, statement) -> Optional[Row]:
- return next(iter(self._execute_with_retries(statement, [name, version])), None)
+ def metadata_fetcher_get(
+ self, name, version, *, statement
+ ) -> Optional[MetadataFetcherRow]:
+ rows = list(self._execute_with_retries(statement, [name, version]))
+ if rows:
+ return MetadataFetcherRow.from_dict(rows[0])
+ else:
+ return None
#########################
# 'raw_extrinsic_metadata' table
#########################
- _raw_extrinsic_metadata_keys = [
- "type",
- "id",
- "authority_type",
- "authority_url",
- "discovery_date",
- "fetcher_name",
- "fetcher_version",
- "format",
- "metadata",
- "origin",
- "visit",
- "snapshot",
- "release",
- "revision",
- "path",
- "directory",
- ]
-
- @_prepared_statement(
- f"INSERT INTO raw_extrinsic_metadata "
- f" ({', '.join(_raw_extrinsic_metadata_keys)}) "
- f"VALUES ({', '.join('?' for _ in _raw_extrinsic_metadata_keys)})"
+ @_prepared_insert_statement(
+ "raw_extrinsic_metadata", RawExtrinsicMetadataRow.cols()
)
- def raw_extrinsic_metadata_add(
- self, statement, **kwargs,
- ):
- assert set(kwargs) == set(
- self._raw_extrinsic_metadata_keys
- ), f"Bad kwargs: {set(kwargs)}"
- params = [kwargs[key] for key in self._raw_extrinsic_metadata_keys]
- return self._execute_with_retries(statement, params,)
+ def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata, *, statement):
+ self._add_one(statement, None, raw_extrinsic_metadata)
@_prepared_statement(
"SELECT * from raw_extrinsic_metadata "
@@ -993,9 +953,12 @@
after: datetime.datetime,
*,
statement,
- ):
- return self._execute_with_retries(
- statement, [id, authority_url, after, authority_type]
+ ) -> Iterable[RawExtrinsicMetadataRow]:
+ return map(
+ RawExtrinsicMetadataRow.from_dict,
+ self._execute_with_retries(
+ statement, [id, authority_url, after, authority_type]
+ ),
)
@_prepared_statement(
@@ -1013,17 +976,20 @@
after_fetcher_version: str,
*,
statement,
- ):
- return self._execute_with_retries(
- statement,
- [
- id,
- authority_type,
- authority_url,
- after_date,
- after_fetcher_name,
- after_fetcher_version,
- ],
+ ) -> Iterable[RawExtrinsicMetadataRow]:
+ return map(
+ RawExtrinsicMetadataRow.from_dict,
+ self._execute_with_retries(
+ statement,
+ [
+ id,
+ authority_type,
+ authority_url,
+ after_date,
+ after_fetcher_name,
+ after_fetcher_version,
+ ],
+ ),
)
@_prepared_statement(
@@ -1032,9 +998,10 @@
)
def raw_extrinsic_metadata_get(
self, id: str, authority_type: str, authority_url: str, *, statement
- ) -> Iterable[Row]:
- return self._execute_with_retries(
- statement, [id, authority_url, authority_type]
+ ) -> Iterable[RawExtrinsicMetadataRow]:
+ return map(
+ RawExtrinsicMetadataRow.from_dict,
+ self._execute_with_retries(statement, [id, authority_url, authority_type]),
)
##########################
@@ -1045,8 +1012,6 @@
def check_read(self, *, statement):
self._execute_with_retries(statement, [])
- @_prepared_statement(
- "SELECT object_type, count FROM object_count WHERE partition_key=0"
- )
+ @_prepared_statement("SELECT * FROM object_count WHERE partition_key=0")
- def stat_counters(self, *, statement) -> ResultSet:
+ def stat_counters(self, *, statement) -> Iterable[ObjectCountRow]:
- return self._execute_with_retries(statement, [])
+ return map(ObjectCountRow.from_dict, self._execute_with_retries(statement, []))
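
The key enabler for the changes above is the switch to row_factory=dict_factory: the driver now yields plain dicts, which map one-to-one onto the dataclass constructors. A condensed sketch of the plumbing (not an actual driver call; values are illustrative):

    import dataclasses
    from swh.storage.cassandra.model import RevisionParentRow

    # what the driver returns for one row once dict_factory is enabled
    driver_row = {"id": b"\x12" * 20, "parent_rank": 0, "parent_id": b"\x34" * 20}

    row = RevisionParentRow.from_dict(driver_row)  # dict -> dataclass
    assert dataclasses.astuple(row) == (b"\x12" * 20, 0, b"\x34" * 20)
    # _add_one() binds this tuple to an INSERT whose column list comes from
    # RevisionParentRow.cols(), so field order and column order always agree.
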
diff --git a/swh/storage/cassandra/model.py b/swh/storage/cassandra/model.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/cassandra/model.py
@@ -0,0 +1,196 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""Classes representing tables in the Cassandra database.
+
+They are very close to classes found in swh.model.model, but most of
+them are subtly different:
+
+* Large objects are split into other classes (e.g. RevisionRow has no
+  'parents' field, because parents are stored in a different table,
+  represented by RevisionParentRow)
+* They have a "cols" classmethod, which returns the list of column names
+  of the table
+* They only use types that map directly to Cassandra's schema (i.e. no enums)
+
+Therefore, this model doesn't reuse swh.model.model, except for types
+that can be mapped to UDTs (Person and TimestampWithTimezone).
+"""
+
+import dataclasses
+import datetime
+from typing import Any, Dict, List, Optional, Type, TypeVar
+
+from swh.model.model import Person, TimestampWithTimezone
+
+
+T = TypeVar("T", bound="BaseRow")
+
+
+class BaseRow:
+ @classmethod
+ def from_dict(cls: Type[T], d: Dict[str, Any]) -> T:
+ return cls(**d) # type: ignore
+
+ @classmethod
+ def cols(cls) -> List[str]:
+ return [field.name for field in dataclasses.fields(cls)]
+
+ def to_dict(self) -> Dict[str, Any]:
+ return dataclasses.asdict(self)
+
+
+@dataclasses.dataclass
+class ContentRow(BaseRow):
+ sha1: bytes
+ sha1_git: bytes
+ sha256: bytes
+ blake2s256: bytes
+ length: int
+ ctime: datetime.datetime
+ status: str
+
+
+@dataclasses.dataclass
+class SkippedContentRow(BaseRow):
+ sha1: Optional[bytes]
+ sha1_git: Optional[bytes]
+ sha256: Optional[bytes]
+ blake2s256: Optional[bytes]
+ length: Optional[int]
+ ctime: Optional[datetime.datetime]
+ status: str
+ reason: str
+ origin: str
+
+
+@dataclasses.dataclass
+class DirectoryRow(BaseRow):
+ id: bytes
+
+
+@dataclasses.dataclass
+class DirectoryEntryRow(BaseRow):
+ directory_id: bytes
+ name: bytes
+ target: bytes
+ perms: int
+ type: str
+
+
+@dataclasses.dataclass
+class RevisionRow(BaseRow):
+ id: bytes
+ date: Optional[TimestampWithTimezone]
+ committer_date: Optional[TimestampWithTimezone]
+ type: str
+ directory: bytes
+ message: bytes
+ author: Person
+ committer: Person
+ synthetic: bool
+ metadata: str
+ extra_headers: dict
+
+
+@dataclasses.dataclass
+class RevisionParentRow(BaseRow):
+ id: bytes
+ parent_rank: int
+ parent_id: bytes
+
+
+@dataclasses.dataclass
+class ReleaseRow(BaseRow):
+ id: bytes
+ target_type: str
+ target: bytes
+ date: TimestampWithTimezone
+ name: bytes
+ message: bytes
+ author: Person
+ synthetic: bool
+
+
+@dataclasses.dataclass
+class SnapshotRow(BaseRow):
+ id: bytes
+
+
+@dataclasses.dataclass
+class SnapshotBranchRow(BaseRow):
+ snapshot_id: bytes
+ name: bytes
+ target_type: Optional[str]
+ target: Optional[bytes]
+
+
+@dataclasses.dataclass
+class OriginVisitRow(BaseRow):
+ origin: str
+ visit: int
+ date: datetime.datetime
+ type: str
+
+
+@dataclasses.dataclass
+class OriginVisitStatusRow(BaseRow):
+ origin: str
+ visit: int
+ date: datetime.datetime
+ status: str
+ metadata: str
+ snapshot: bytes
+
+
+@dataclasses.dataclass
+class OriginRow(BaseRow):
+ sha1: bytes
+ url: str
+ next_visit_id: int
+
+
+@dataclasses.dataclass
+class MetadataAuthorityRow(BaseRow):
+ url: str
+ type: str
+ metadata: str
+
+
+@dataclasses.dataclass
+class MetadataFetcherRow(BaseRow):
+ name: str
+ version: str
+ metadata: str
+
+
+@dataclasses.dataclass
+class RawExtrinsicMetadataRow(BaseRow):
+ type: str
+ id: str
+
+ authority_type: str
+ authority_url: str
+ discovery_date: datetime.datetime
+ fetcher_name: str
+ fetcher_version: str
+
+ format: str
+ metadata: bytes
+
+ origin: Optional[str]
+ visit: Optional[int]
+ snapshot: Optional[str]
+ release: Optional[str]
+ revision: Optional[str]
+ path: Optional[bytes]
+ directory: Optional[str]
+
+
+@dataclasses.dataclass
+class ObjectCountRow(BaseRow):
+ partition_key: int
+ object_type: str
+ count: int
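
For reference, the BaseRow helpers give every table class a uniform dict/tuple interface; a short sketch with ContentRow (values are illustrative):

    import datetime
    from swh.storage.cassandra.model import ContentRow

    assert ContentRow.cols() == [
        "sha1", "sha1_git", "sha256", "blake2s256", "length", "ctime", "status",
    ]

    row = ContentRow(
        sha1=b"\x01" * 20, sha1_git=b"\x02" * 20, sha256=b"\x03" * 32,
        blake2s256=b"\x04" * 32, length=42,
        ctime=datetime.datetime.now(tz=datetime.timezone.utc), status="visible",
    )
    assert ContentRow.from_dict(row.to_dict()) == row  # lossless round-trip
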
diff --git a/swh/storage/cassandra/schema.py b/swh/storage/cassandra/schema.py
--- a/swh/storage/cassandra/schema.py
+++ b/swh/storage/cassandra/schema.py
@@ -181,7 +181,6 @@
CREATE TABLE IF NOT EXISTS origin (
sha1 blob PRIMARY KEY,
url text,
- type text,
next_visit_id int,
-- We need integer visit ids for compatibility with the pgsql
-- storage, so we're using lightweight transactions with this trick:
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -48,10 +48,24 @@
from swh.storage.utils import map_optional, now
from ..exc import StorageArgumentException, HashCollision
-from .common import TOKEN_BEGIN, TOKEN_END
+from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys
from . import converters
from .cql import CqlRunner
from .schema import HASH_ALGORITHMS
+from .model import (
+ ContentRow,
+ DirectoryEntryRow,
+ DirectoryRow,
+ MetadataAuthorityRow,
+ MetadataFetcherRow,
+ OriginRow,
+ OriginVisitRow,
+ RawExtrinsicMetadataRow,
+ RevisionParentRow,
+ SkippedContentRow,
+ SnapshotBranchRow,
+ SnapshotRow,
+)
# Max block size of contents to return
@@ -139,7 +153,9 @@
collisions.append(content.hashes())
raise HashCollision(algo, content.get_hash(algo), collisions)
- (token, insertion_finalizer) = self._cql_runner.content_add_prepare(content)
+ (token, insertion_finalizer) = self._cql_runner.content_add_prepare(
+ ContentRow(**remove_keys(content.to_dict(), ("data",)))
+ )
# Then add to index tables
for algo in HASH_ALGORITHMS:
@@ -203,14 +219,12 @@
range_start, range_end, limit + 1
)
contents = []
- last_id: Optional[int] = None
- for counter, row in enumerate(rows):
+ for counter, (tok, row) in enumerate(rows):
if row.status == "absent":
continue
- row_d = row._asdict()
- last_id = row_d.pop("tok")
+ row_d = row.to_dict()
if counter >= limit:
- next_page_token = str(last_id)
+ next_page_token = str(tok)
break
contents.append(Content(**row_d))
@@ -223,7 +237,7 @@
# Get all (sha1, sha1_git, sha256, blake2s256) whose sha1
# matches the argument, from the index table ('content_by_sha1')
for row in self._content_get_from_hash("sha1", sha1):
- row_d = row._asdict()
+ row_d = row.to_dict()
row_d.pop("ctime")
content = Content(**row_d)
contents_by_sha1[content.sha1] = content
@@ -251,7 +265,7 @@
break
else:
# All hashes match, keep this row.
- row_d = row._asdict()
+ row_d = row.to_dict()
row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc)
results.append(Content(**row_d))
return results
@@ -314,7 +328,7 @@
for content in contents:
# Compute token of the row in the main table
(token, insertion_finalizer) = self._cql_runner.skipped_content_add_prepare(
- content
+ SkippedContentRow.from_dict({"origin": None, **content.to_dict()})
)
# Then add to index tables
@@ -348,13 +362,13 @@
# Add directory entries to the 'directory_entry' table
for entry in directory.entries:
self._cql_runner.directory_entry_add_one(
- {**entry.to_dict(), "directory_id": directory.id}
+ DirectoryEntryRow(directory_id=directory.id, **entry.to_dict())
)
# Add the directory *after* adding all the entries, so someone
# calling snapshot_get_branch in the meantime won't end up
# with half the entries.
- self._cql_runner.directory_add_one(directory.id)
+ self._cql_runner.directory_add_one(DirectoryRow(id=directory.id))
return {"directory:add": len(directories)}
@@ -387,10 +401,10 @@
rows = list(self._cql_runner.directory_entry_get([directory_id]))
for row in rows:
+ entry_d = row.to_dict()
# Build and yield the directory entry dict
- entry = row._asdict()
- del entry["directory_id"]
- entry = DirectoryEntry.from_dict(entry)
+ del entry_d["directory_id"]
+ entry = DirectoryEntry.from_dict(entry_d)
ret = self._join_dentry_to_content(entry)
ret["name"] = prefix + ret["name"]
ret["dir_id"] = directory_id
@@ -458,9 +472,11 @@
revobject = converters.revision_to_db(revision)
if revobject:
# Add parents first
- for (rank, parent) in enumerate(revobject["parents"]):
+ for (rank, parent) in enumerate(revision.parents):
self._cql_runner.revision_parent_add_one(
- revobject["id"], rank, parent
+ RevisionParentRow(
+ id=revobject.id, parent_rank=rank, parent_id=parent
+ )
)
# Then write the main revision row.
@@ -484,10 +500,9 @@
# (it might have lower latency, but requires more code and more
# bandwidth, because revision id would be part of each returned
# row)
- parent_rows = self._cql_runner.revision_parent_get(row.id)
+ parents = tuple(self._cql_runner.revision_parent_get(row.id))
# parent_rank is the clustering key, so results are already
# sorted by rank.
- parents = tuple(row.parent_id for row in parent_rows)
rev = converters.revision_from_db(row, parents=parents)
revs[rev.id] = rev.to_dict()
@@ -514,27 +529,35 @@
# results (ie. not return only a subset of a revision's parents
# if it is being written)
if short:
- rows = self._cql_runner.revision_get_ids(rev_ids)
+ ids = self._cql_runner.revision_get_ids(rev_ids)
+ for id_ in ids:
+ # TODO: use a single query to get all parents?
+ # (it might have lower latency, but requires more code and more
+ # bandwidth, because revision id would be part of each returned
+ # row)
+ parents = tuple(self._cql_runner.revision_parent_get(id_))
+
+ # parent_rank is the clustering key, so results are already
+ # sorted by rank.
+
+ yield (id_, parents)
+ yield from self._get_parent_revs(parents, seen, limit, short)
else:
rows = self._cql_runner.revision_get(rev_ids)
- for row in rows:
- # TODO: use a single query to get all parents?
- # (it might have less latency, but requires less code and more
- # bandwidth (because revision id would be part of each returned
- # row)
- parent_rows = self._cql_runner.revision_parent_get(row.id)
+ for row in rows:
+ # TODO: use a single query to get all parents?
+ # (it might have lower latency, but requires more code and more
+ # bandwidth, because revision id would be part of each returned
+ # row)
+ parents = tuple(self._cql_runner.revision_parent_get(row.id))
- # parent_rank is the clustering key, so results are already
- # sorted by rank.
- parents = tuple(row.parent_id for row in parent_rows)
+ # parent_rank is the clustering key, so results are already
+ # sorted by rank.
- if short:
- yield (row.id, parents)
- else:
rev = converters.revision_from_db(row, parents=parents)
yield rev.to_dict()
- yield from self._get_parent_revs(parents, seen, limit, short)
+ yield from self._get_parent_revs(parents, seen, limit, short)
def revision_log(
self, revisions: List[Sha1Git], limit: Optional[int] = None
@@ -595,24 +618,24 @@
# Add branches
for (branch_name, branch) in snapshot.branches.items():
if branch is None:
- target_type = None
- target = None
+ target_type: Optional[str] = None
+ target: Optional[bytes] = None
else:
target_type = branch.target_type.value
target = branch.target
self._cql_runner.snapshot_branch_add_one(
- {
- "snapshot_id": snapshot.id,
- "name": branch_name,
- "target_type": target_type,
- "target": target,
- }
+ SnapshotBranchRow(
+ snapshot_id=snapshot.id,
+ name=branch_name,
+ target_type=target_type,
+ target=target,
+ )
)
# Add the snapshot *after* adding all the branches, so someone
# calling snapshot_get_branch in the meantime won't end up
# with half the branches.
- self._cql_runner.snapshot_add_one(snapshot.id)
+ self._cql_runner.snapshot_add_one(SnapshotRow(id=snapshot.id))
return {"snapshot:add": len(snapshots)}
@@ -647,13 +670,7 @@
# Makes sure we don't fetch branches for a snapshot that is
# being added.
return None
- rows = list(self._cql_runner.snapshot_count_branches(snapshot_id))
- assert len(rows) == 1
- (nb_none, counts) = rows[0].counts
- counts = dict(counts)
- if nb_none:
- counts[None] = nb_none
- return counts
+ return self._cql_runner.snapshot_count_branches(snapshot_id)
def snapshot_get_branches(
self,
@@ -760,8 +777,8 @@
def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
results = []
for sha1 in sha1s:
- rows = self._cql_runner.origin_get_by_sha1(sha1)
- origin = {"url": rows.one().url} if rows else None
+ rows = list(self._cql_runner.origin_get_by_sha1(sha1))
+ origin = {"url": rows[0].url} if rows else None
results.append(origin)
return results
@@ -778,10 +795,10 @@
origins = []
# Take one more origin so we can reuse it as the next page token if any
- for row in self._cql_runner.origin_list(start_token, limit + 1):
+ for (tok, row) in self._cql_runner.origin_list(start_token, limit + 1):
origins.append(Origin(url=row.url))
# keep reference of the last id for pagination purposes
- last_id = row.tok
+ last_id = tok
if len(origins) > limit:
# last origin id is the next page token
@@ -805,15 +822,17 @@
next_page_token = None
offset = int(page_token) if page_token else 0
- origins = self._cql_runner.origin_iter_all()
+ origin_rows = [row for row in self._cql_runner.origin_iter_all()]
if regexp:
pat = re.compile(url_pattern)
- origins = [Origin(orig.url) for orig in origins if pat.search(orig.url)]
+ origin_rows = [row for row in origin_rows if pat.search(row.url)]
else:
- origins = [Origin(orig.url) for orig in origins if url_pattern in orig.url]
+ origin_rows = [row for row in origin_rows if url_pattern in row.url]
if with_visit:
- origins = [Origin(orig.url) for orig in origins if orig.next_visit_id > 1]
+ origin_rows = [row for row in origin_rows if row.next_visit_id > 1]
+
+ origins = [Origin(url=row.url) for row in origin_rows]
origins = origins[offset : offset + limit + 1]
if len(origins) > limit:
@@ -829,7 +848,9 @@
to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None]
self.journal_writer.origin_add(to_add)
for origin in to_add:
- self._cql_runner.origin_add_one(origin)
+ self._cql_runner.origin_add_one(
+ OriginRow(sha1=hash_url(origin.url), url=origin.url, next_visit_id=1)
+ )
return {"origin:add": len(to_add)}
def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
@@ -848,7 +869,7 @@
)
visit = attr.evolve(visit, visit=visit_id)
self.journal_writer.origin_visit_add([visit])
- self._cql_runner.origin_visit_add_one(visit)
+ self._cql_runner.origin_visit_add_one(OriginVisitRow(**visit.to_dict()))
assert visit.visit is not None
all_visits.append(visit)
self._origin_visit_status_add(
@@ -866,7 +887,9 @@
def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None:
"""Add an origin visit status"""
self.journal_writer.origin_visit_status_add([visit_status])
- self._cql_runner.origin_visit_status_add_one(visit_status)
+ self._cql_runner.origin_visit_status_add_one(
+ converters.visit_status_to_row(visit_status)
+ )
def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus]) -> None:
# First round to check existence (fail early if any is ko)
@@ -912,7 +935,7 @@
@staticmethod
def _format_origin_visit_row(visit):
return {
- **visit._asdict(),
+ **visit.to_dict(),
"origin": visit.origin,
"date": visit.date.replace(tzinfo=datetime.timezone.utc),
}
@@ -1048,8 +1071,10 @@
f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
f"{','.join(VISIT_STATUSES)} authorized"
)
- rows = self._cql_runner.origin_visit_status_get(
- origin_url, visit, allowed_statuses, require_snapshot
+ rows = list(
+ self._cql_runner.origin_visit_status_get(
+ origin_url, visit, allowed_statuses, require_snapshot
+ )
)
# filtering is done python side as we cannot do it server side
if allowed_statuses:
@@ -1113,7 +1138,7 @@
)
try:
- self._cql_runner.raw_extrinsic_metadata_add(
+ row = RawExtrinsicMetadataRow(
type=metadata_entry.type.value,
id=str(metadata_entry.id),
authority_type=metadata_entry.authority.type.value,
@@ -1131,6 +1156,7 @@
path=metadata_entry.path,
directory=map_optional(str, metadata_entry.directory),
)
+ self._cql_runner.raw_extrinsic_metadata_add(row)
except TypeError as e:
raise StorageArgumentException(*e.args)
@@ -1236,9 +1262,11 @@
self.journal_writer.metadata_fetcher_add(fetchers)
for fetcher in fetchers:
self._cql_runner.metadata_fetcher_add(
- fetcher.name,
- fetcher.version,
- json.dumps(map_optional(dict, fetcher.metadata)),
+ MetadataFetcherRow(
+ name=fetcher.name,
+ version=fetcher.version,
+ metadata=json.dumps(map_optional(dict, fetcher.metadata)),
+ )
)
def metadata_fetcher_get(
@@ -1258,9 +1286,11 @@
self.journal_writer.metadata_authority_add(authorities)
for authority in authorities:
self._cql_runner.metadata_authority_add(
- authority.url,
- authority.type.value,
- json.dumps(map_optional(dict, authority.metadata)),
+ MetadataAuthorityRow(
+ url=authority.url,
+ type=authority.type.value,
+ metadata=json.dumps(map_optional(dict, authority.metadata)),
+ )
)
def metadata_authority_get(
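
In short, storage.py now builds explicit row objects from swh.model entities instead of passing dicts or model objects down to the CQL layer. A condensed sketch of the origin and content cases above (values are illustrative):

    from swh.storage.cassandra.common import hash_url, remove_keys
    from swh.storage.cassandra.model import OriginRow

    url = "https://example.org/repo.git"
    origin_row = OriginRow(sha1=hash_url(url), url=url, next_visit_id=1)

    # for contents, the raw data is stripped before building the row, since
    # the 'content' table only stores hashes and metadata:
    #     ContentRow(**remove_keys(content.to_dict(), ("data",)))
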
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -3,22 +3,22 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-import attr
+import datetime
import os
import signal
import socket
import subprocess
import time
-
-from collections import namedtuple
from typing import Dict
+import attr
import pytest
from swh.core.api.classes import stream_results
from swh.storage import get_storage
from swh.storage.cassandra import create_keyspace
from swh.storage.cassandra.schema import TABLES, HASH_ALGORITHMS
+from swh.storage.cassandra.model import ContentRow
from swh.storage.utils import now
from swh.storage.tests.test_storage import TestStorage as _TestStorage
@@ -209,12 +209,17 @@
)
# For all tokens, always return cont
- Row = namedtuple("Row", HASH_ALGORITHMS)
-
def mock_cgft(token):
nonlocal called
called += 1
- return [Row(**{algo: getattr(cont, algo) for algo in HASH_ALGORITHMS})]
+ return [
+ ContentRow(
+ length=10,
+ ctime=datetime.datetime.now(),
+ status="present",
+ **{algo: getattr(cont, algo) for algo in HASH_ALGORITHMS},
+ )
+ ]
mocker.patch.object(
swh_storage._cql_runner, "content_get_from_token", mock_cgft
@@ -253,13 +258,12 @@
# For all tokens, always return cont and cont2
cols = list(set(cont.to_dict()) - {"data"})
- Row = namedtuple("Row", cols)
def mock_cgft(token):
nonlocal called
called += 1
return [
- Row(**{col: getattr(cont, col) for col in cols})
+ ContentRow(**{col: getattr(cont, col) for col in cols})
for cont in [cont, cont2]
]
@@ -299,13 +303,12 @@
# For all tokens, always return cont and cont2
cols = list(set(cont.to_dict()) - {"data"})
- Row = namedtuple("Row", cols)
def mock_cgft(token):
nonlocal called
called += 1
return [
- Row(**{col: getattr(cont, col) for col in cols})
+ ContentRow(**{col: getattr(cont, col) for col in cols})
for cont in [cont, cont2]
]
@@ -342,16 +345,15 @@
rows[tok] = row_d
# For all tokens, always return cont
- keys = set(["tok"] + list(content.to_dict().keys())).difference(set(["data"]))
- Row = namedtuple("Row", keys)
def mock_content_get_token_range(range_start, range_end, limit):
nonlocal called
called += 1
for tok in list(rows.keys()) * 3: # yield multiple times the same tok
- row_d = rows[tok]
- yield Row(**row_d)
+ row_d = dict(rows[tok].items())
+ row_d.pop("tok")
+ yield (tok, ContentRow(**row_d))
mocker.patch.object(
swh_storage._cql_runner,