D3756.id.diff
diff --git a/swh/storage/cassandra/common.py b/swh/storage/cassandra/common.py
--- a/swh/storage/cassandra/common.py
+++ b/swh/storage/cassandra/common.py
@@ -5,9 +5,7 @@
import hashlib
-
-
-Row = tuple
+from typing import Any, Dict, Tuple
TOKEN_BEGIN = -(2 ** 63)
@@ -18,3 +16,7 @@
def hash_url(url: str) -> bytes:
return hashlib.sha1(url.encode("ascii")).digest()
+
+
+def remove_keys(d: Dict[str, Any], keys: Tuple[str, ...]) -> Dict[str, Any]:
+ return {k: v for (k, v) in d.items() if k not in keys}
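The remove_keys() helper introduced here is used throughout the rest of the diff to drop keys that live in a separate table (or are not stored at all) before building a row object. A minimal, self-contained sketch of its behaviour, with a made-up revision dict:

    from typing import Any, Dict, Tuple

    def remove_keys(d: Dict[str, Any], keys: Tuple[str, ...]) -> Dict[str, Any]:
        # same definition as in swh/storage/cassandra/common.py above
        return {k: v for (k, v) in d.items() if k not in keys}

    revision = {"id": b"\x01" * 20, "message": b"fix", "parents": (b"\x02" * 20,)}
    # 'parents' is stored in the revision_parent table, so it is stripped
    # before the remaining keys are handed to RevisionRow(**...)
    assert remove_keys(revision, ("parents",)) == {"id": b"\x01" * 20, "message": b"fix"}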
diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py
--- a/swh/storage/cassandra/converters.py
+++ b/swh/storage/cassandra/converters.py
@@ -8,7 +8,7 @@
import attr
from copy import deepcopy
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple
from swh.model.model import (
ObjectType,
@@ -21,10 +21,11 @@
)
from swh.model.hashutil import DEFAULT_ALGORITHMS
-from .common import Row
+from .common import remove_keys
+from .model import OriginVisitRow, OriginVisitStatusRow, RevisionRow, ReleaseRow
-def revision_to_db(revision: Revision) -> Dict[str, Any]:
+def revision_to_db(revision: Revision) -> RevisionRow:
# we use a deepcopy of the dict because we do not want to recurse the
# Model->dict conversion (to keep Timestamp & al. entities), BUT we do not
# want to modify original metadata (embedded in the Model entity), so we
@@ -39,11 +40,13 @@
)
db_revision["extra_headers"] = extra_headers
db_revision["type"] = db_revision["type"].value
- return db_revision
+ return RevisionRow(**remove_keys(db_revision, ("parents",)))
-def revision_from_db(db_revision: Row, parents: Tuple[Sha1Git, ...]) -> Revision:
- revision = db_revision._asdict() # type: ignore
+def revision_from_db(
+ db_revision: RevisionRow, parents: Tuple[Sha1Git, ...]
+) -> Revision:
+ revision = db_revision.to_dict()
metadata = json.loads(revision.pop("metadata", None))
extra_headers = revision.pop("extra_headers", ())
if not extra_headers and metadata and "extra_headers" in metadata:
@@ -59,18 +62,18 @@
)
-def release_to_db(release: Release) -> Dict[str, Any]:
+def release_to_db(release: Release) -> ReleaseRow:
db_release = attr.asdict(release, recurse=False)
db_release["target_type"] = db_release["target_type"].value
- return db_release
+ return ReleaseRow(**remove_keys(db_release, ("metadata",)))
-def release_from_db(db_release: Row) -> Release:
- release = db_release._asdict() # type: ignore
+def release_from_db(db_release: ReleaseRow) -> Release:
+ release = db_release.to_dict()
return Release(target_type=ObjectType(release.pop("target_type")), **release,)
-def row_to_content_hashes(row: Row) -> Dict[str, bytes]:
+def row_to_content_hashes(row: ReleaseRow) -> Dict[str, bytes]:
"""Convert cassandra row to a content hashes
"""
@@ -80,7 +83,7 @@
return hashes
-def row_to_visit(row) -> OriginVisit:
+def row_to_visit(row: OriginVisitRow) -> OriginVisit:
"""Format a row representing an origin_visit to an actual OriginVisit.
"""
@@ -92,15 +95,19 @@
)
-def row_to_visit_status(row) -> OriginVisitStatus:
+def row_to_visit_status(row: OriginVisitStatusRow) -> OriginVisitStatus:
"""Format a row representing a visit_status to an actual OriginVisitStatus.
"""
return OriginVisitStatus.from_dict(
{
- **row._asdict(),
- "origin": row.origin,
+ **row.to_dict(),
"date": row.date.replace(tzinfo=datetime.timezone.utc),
"metadata": (json.loads(row.metadata) if row.metadata else None),
}
)
+
+
+def visit_status_to_row(status: OriginVisitStatus) -> OriginVisitStatusRow:
+ d = status.to_dict()
+ return OriginVisitStatusRow.from_dict({**d, "metadata": json.dumps(d["metadata"])})
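The converters above now translate between swh.model entities and the row dataclasses: a Revision becomes a RevisionRow (without its parents, which go to revision_parent), and visit statuses get their metadata JSON-encoded on the way into Cassandra and decoded on the way out. A rough sketch of that JSON round trip, using a stand-in dataclass rather than the real OriginVisitStatusRow and a made-up status dict (the real converter also re-attaches the UTC timezone to the date):

    import dataclasses
    import datetime
    import json
    from typing import Any, Dict

    @dataclasses.dataclass
    class StatusRow:                       # stand-in for model.OriginVisitStatusRow
        origin: str
        visit: int
        date: datetime.datetime
        status: str
        metadata: str                      # stored JSON-encoded in Cassandra
        snapshot: bytes

    def status_to_row(d: Dict[str, Any]) -> StatusRow:
        # mirrors converters.visit_status_to_row: encode metadata before storage
        return StatusRow(**{**d, "metadata": json.dumps(d["metadata"])})

    def row_to_status(row: StatusRow) -> Dict[str, Any]:
        # mirrors converters.row_to_visit_status: decode metadata after reading
        d = dataclasses.asdict(row)
        return {**d, "metadata": json.loads(row.metadata) if row.metadata else None}

    status = {"origin": "https://example.org/repo", "visit": 1,
              "date": datetime.datetime.now(datetime.timezone.utc),
              "status": "full", "metadata": {"key": "value"}, "snapshot": b"\x00" * 20}
    assert row_to_status(status_to_row(status)) == status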
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -3,9 +3,9 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+import dataclasses
import datetime
import functools
-import json
import logging
import random
from typing import (
@@ -17,13 +17,14 @@
List,
Optional,
Tuple,
+ Type,
TypeVar,
)
from cassandra import CoordinationFailure
from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
-from cassandra.query import PreparedStatement, BoundStatement
+from cassandra.query import PreparedStatement, BoundStatement, dict_factory
from tenacity import (
retry,
stop_after_attempt,
@@ -32,20 +33,36 @@
)
from swh.model.model import (
+ Content,
+ SkippedContent,
Sha1Git,
TimestampWithTimezone,
Timestamp,
Person,
- Content,
- SkippedContent,
- OriginVisit,
- OriginVisitStatus,
- Origin,
)
from swh.storage.interface import ListOrder
-from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url
+from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys
+from .model import (
+ BaseRow,
+ ContentRow,
+ DirectoryEntryRow,
+ DirectoryRow,
+ MetadataAuthorityRow,
+ MetadataFetcherRow,
+ ObjectCountRow,
+ OriginRow,
+ OriginVisitRow,
+ OriginVisitStatusRow,
+ RawExtrinsicMetadataRow,
+ ReleaseRow,
+ RevisionParentRow,
+ RevisionRow,
+ SkippedContentRow,
+ SnapshotBranchRow,
+ SnapshotRow,
+)
from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS
@@ -54,7 +71,8 @@
_execution_profiles = {
EXEC_PROFILE_DEFAULT: ExecutionProfile(
- load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy())
+ load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
+ row_factory=dict_factory,
),
}
# Configuration for cassandra-driver's access to servers:
@@ -166,11 +184,14 @@
) -> None:
self._execute_with_retries(statement, [nb, object_type])
- def _add_one(self, statement, object_type: str, obj, keys: List[str]) -> None:
- self._increment_counter(object_type, 1)
- self._execute_with_retries(statement, [getattr(obj, key) for key in keys])
+ def _add_one(self, statement, object_type: Optional[str], obj: BaseRow) -> None:
+ if object_type:
+ self._increment_counter(object_type, 1)
+ self._execute_with_retries(statement, dataclasses.astuple(obj))
- def _get_random_row(self, statement) -> Optional[Row]:
+ _T = TypeVar("_T", bound=BaseRow)
+
+ def _get_random_row(self, row_class: Type[_T], statement) -> Optional[_T]: # noqa
"""Takes a prepared statement of the form
"SELECT * FROM <table> WHERE token(<keys>) > ? LIMIT 1"
and uses it to return a random row"""
@@ -181,13 +202,13 @@
# the row with the smallest token
rows = self._execute_with_retries(statement, [TOKEN_BEGIN])
if rows:
- return rows.one()
+ return row_class.from_dict(rows.one()) # type: ignore
else:
return None
def _missing(self, statement, ids):
- res = self._execute_with_retries(statement, [ids])
- found_ids = {id_ for (id_,) in res}
+ rows = self._execute_with_retries(statement, [ids])
+ found_ids = {row["id"] for row in rows}
return [id_ for id_ in ids if id_ not in found_ids]
##########################
@@ -195,15 +216,6 @@
##########################
_content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"]
- _content_keys = [
- "sha1",
- "sha1_git",
- "sha256",
- "blake2s256",
- "length",
- "ctime",
- "status",
- ]
def _content_add_finalize(self, statement: BoundStatement) -> None:
"""Returned currified by content_add_prepare, to be called when the
@@ -211,16 +223,14 @@
self._execute_with_retries(statement, None)
self._increment_counter("content", 1)
- @_prepared_insert_statement("content", _content_keys)
+ @_prepared_insert_statement("content", ContentRow.cols())
def content_add_prepare(
- self, content, *, statement
+ self, content: ContentRow, *, statement
) -> Tuple[int, Callable[[], None]]:
"""Prepares insertion of a Content to the main 'content' table.
Returns a token (to be used in secondary tables), and a function to be
called to perform the insertion in the main table."""
- statement = statement.bind(
- [getattr(content, key) for key in self._content_keys]
- )
+ statement = statement.bind(dataclasses.astuple(content))
# Type used for hashing keys (usually, it will be
# cassandra.metadata.Murmur3Token)
@@ -245,7 +255,7 @@
)
def content_get_from_pk(
self, content_hashes: Dict[str, bytes], *, statement
- ) -> Optional[Row]:
+ ) -> Optional[ContentRow]:
rows = list(
self._execute_with_retries(
statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]
@@ -253,38 +263,44 @@
)
assert len(rows) <= 1
if rows:
- return rows[0]
+ return ContentRow(**rows[0])
else:
return None
@_prepared_statement(
"SELECT * FROM content WHERE token(" + ", ".join(_content_pk) + ") = ?"
)
- def content_get_from_token(self, token, *, statement) -> Iterable[Row]:
- return self._execute_with_retries(statement, [token])
+ def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]:
+ return map(ContentRow.from_dict, self._execute_with_retries(statement, [token]))
@_prepared_statement(
"SELECT * FROM content WHERE token(%s) > ? LIMIT 1" % ", ".join(_content_pk)
)
- def content_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def content_get_random(self, *, statement) -> Optional[ContentRow]:
+ return self._get_random_row(ContentRow, statement)
@_prepared_statement(
(
"SELECT token({0}) AS tok, {1} FROM content "
"WHERE token({0}) >= ? AND token({0}) <= ? LIMIT ?"
- ).format(", ".join(_content_pk), ", ".join(_content_keys))
+ ).format(", ".join(_content_pk), ", ".join(ContentRow.cols()))
)
def content_get_token_range(
self, start: int, end: int, limit: int, *, statement
- ) -> Iterable[Row]:
- return self._execute_with_retries(statement, [start, end, limit])
+ ) -> Iterable[Tuple[int, ContentRow]]:
+ """Returns an iterable of (token, row)"""
+ return (
+ (row["tok"], ContentRow.from_dict(remove_keys(row, ("tok",))))
+ for row in self._execute_with_retries(statement, [start, end, limit])
+ )
##########################
# 'content_by_*' tables
##########################
- @_prepared_statement("SELECT sha1_git FROM content_by_sha1_git WHERE sha1_git IN ?")
+ @_prepared_statement(
+ "SELECT sha1_git AS id FROM content_by_sha1_git WHERE sha1_git IN ?"
+ )
def content_missing_by_sha1_git(
self, ids: List[bytes], *, statement
) -> List[bytes]:
@@ -303,24 +319,15 @@
) -> Iterable[int]:
assert algo in HASH_ALGORITHMS
query = f"SELECT target_token FROM content_by_{algo} WHERE {algo} = %s"
- return (tok for (tok,) in self._execute_with_retries(query, [hash_]))
+ return (
+ row["target_token"] for row in self._execute_with_retries(query, [hash_])
+ )
##########################
# 'skipped_content' table
##########################
_skipped_content_pk = ["sha1", "sha1_git", "sha256", "blake2s256"]
- _skipped_content_keys = [
- "sha1",
- "sha1_git",
- "sha256",
- "blake2s256",
- "length",
- "ctime",
- "status",
- "reason",
- "origin",
- ]
_magic_null_pk = b"<null>"
"""
NULLs (or all-empty blobs) are not allowed in primary keys; instead use a
@@ -333,7 +340,7 @@
self._execute_with_retries(statement, None)
self._increment_counter("skipped_content", 1)
- @_prepared_insert_statement("skipped_content", _skipped_content_keys)
+ @_prepared_insert_statement("skipped_content", SkippedContentRow.cols())
def skipped_content_add_prepare(
self, content, *, statement
) -> Tuple[int, Callable[[], None]]:
@@ -343,14 +350,11 @@
# Replace NULLs (which are not allowed in the partition key) with
# an empty byte string
- content = content.to_dict()
for key in self._skipped_content_pk:
- if content[key] is None:
- content[key] = self._magic_null_pk
+ if getattr(content, key) is None:
+ setattr(content, key, self._magic_null_pk)
- statement = statement.bind(
- [content.get(key) for key in self._skipped_content_keys]
- )
+ statement = statement.bind(dataclasses.astuple(content))
# Type used for hashing keys (usually, it will be
# cassandra.metadata.Murmur3Token)
@@ -376,7 +380,7 @@
)
def skipped_content_get_from_pk(
self, content_hashes: Dict[str, bytes], *, statement
- ) -> Optional[Row]:
+ ) -> Optional[SkippedContentRow]:
rows = list(
self._execute_with_retries(
statement,
@@ -389,7 +393,7 @@
assert len(rows) <= 1
if rows:
# TODO: convert _magic_null_pk back to None?
- return rows[0]
+ return SkippedContentRow.from_dict(rows[0])
else:
return None
@@ -414,218 +418,198 @@
# 'revision' table
##########################
- _revision_keys = [
- "id",
- "date",
- "committer_date",
- "type",
- "directory",
- "message",
- "author",
- "committer",
- "synthetic",
- "metadata",
- "extra_headers",
- ]
-
@_prepared_exists_statement("revision")
def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
- @_prepared_insert_statement("revision", _revision_keys)
- def revision_add_one(self, revision: Dict[str, Any], *, statement) -> None:
- self._execute_with_retries(
- statement, [revision[key] for key in self._revision_keys]
- )
- self._increment_counter("revision", 1)
+ @_prepared_insert_statement("revision", RevisionRow.cols())
+ def revision_add_one(self, revision: RevisionRow, *, statement) -> None:
+ self._add_one(statement, "revision", revision)
@_prepared_statement("SELECT id FROM revision WHERE id IN ?")
- def revision_get_ids(self, revision_ids, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [revision_ids])
+ def revision_get_ids(self, revision_ids, *, statement) -> Iterable[int]:
+ return (
+ row["id"] for row in self._execute_with_retries(statement, [revision_ids])
+ )
@_prepared_statement("SELECT * FROM revision WHERE id IN ?")
- def revision_get(self, revision_ids, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [revision_ids])
+ def revision_get(self, revision_ids, *, statement) -> Iterable[RevisionRow]:
+ return map(
+ RevisionRow.from_dict, self._execute_with_retries(statement, [revision_ids])
+ )
@_prepared_statement("SELECT * FROM revision WHERE token(id) > ? LIMIT 1")
- def revision_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def revision_get_random(self, *, statement) -> Optional[RevisionRow]:
+ return self._get_random_row(RevisionRow, statement)
##########################
# 'revision_parent' table
##########################
- _revision_parent_keys = ["id", "parent_rank", "parent_id"]
-
- @_prepared_insert_statement("revision_parent", _revision_parent_keys)
+ @_prepared_insert_statement("revision_parent", RevisionParentRow.cols())
def revision_parent_add_one(
- self, id_: Sha1Git, parent_rank: int, parent_id: Sha1Git, *, statement
+ self, revision_parent: RevisionParentRow, *, statement
) -> None:
- self._execute_with_retries(statement, [id_, parent_rank, parent_id])
+ self._add_one(statement, None, revision_parent)
@_prepared_statement("SELECT parent_id FROM revision_parent WHERE id = ?")
- def revision_parent_get(self, revision_id: Sha1Git, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [revision_id])
+ def revision_parent_get(
+ self, revision_id: Sha1Git, *, statement
+ ) -> Iterable[bytes]:
+ return (
+ row["parent_id"]
+ for row in self._execute_with_retries(statement, [revision_id])
+ )
##########################
# 'release' table
##########################
- _release_keys = [
- "id",
- "target",
- "target_type",
- "date",
- "name",
- "message",
- "author",
- "synthetic",
- ]
-
@_prepared_exists_statement("release")
def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
- @_prepared_insert_statement("release", _release_keys)
- def release_add_one(self, release: Dict[str, Any], *, statement) -> None:
- self._execute_with_retries(
- statement, [release[key] for key in self._release_keys]
- )
- self._increment_counter("release", 1)
+ @_prepared_insert_statement("release", ReleaseRow.cols())
+ def release_add_one(self, release: ReleaseRow, *, statement) -> None:
+ self._add_one(statement, "release", release)
@_prepared_statement("SELECT * FROM release WHERE id in ?")
- def release_get(self, release_ids: List[str], *, statement) -> None:
- return self._execute_with_retries(statement, [release_ids])
+ def release_get(self, release_ids: List[str], *, statement) -> Iterable[ReleaseRow]:
+ return map(
+ ReleaseRow.from_dict, self._execute_with_retries(statement, [release_ids])
+ )
@_prepared_statement("SELECT * FROM release WHERE token(id) > ? LIMIT 1")
- def release_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def release_get_random(self, *, statement) -> Optional[ReleaseRow]:
+ return self._get_random_row(ReleaseRow, statement)
##########################
# 'directory' table
##########################
- _directory_keys = ["id"]
-
@_prepared_exists_statement("directory")
def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
- @_prepared_insert_statement("directory", _directory_keys)
- def directory_add_one(self, directory_id: Sha1Git, *, statement) -> None:
+ @_prepared_insert_statement("directory", DirectoryRow.cols())
+ def directory_add_one(self, directory: DirectoryRow, *, statement) -> None:
"""Called after all calls to directory_entry_add_one, to
commit/finalize the directory."""
- self._execute_with_retries(statement, [directory_id])
- self._increment_counter("directory", 1)
+ self._add_one(statement, "directory", directory)
@_prepared_statement("SELECT * FROM directory WHERE token(id) > ? LIMIT 1")
- def directory_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def directory_get_random(self, *, statement) -> Optional[DirectoryRow]:
+ return self._get_random_row(DirectoryRow, statement)
##########################
# 'directory_entry' table
##########################
- _directory_entry_keys = ["directory_id", "name", "type", "target", "perms"]
-
- @_prepared_insert_statement("directory_entry", _directory_entry_keys)
- def directory_entry_add_one(self, entry: Dict[str, Any], *, statement) -> None:
- self._execute_with_retries(
- statement, [entry[key] for key in self._directory_entry_keys]
- )
+ @_prepared_insert_statement("directory_entry", DirectoryEntryRow.cols())
+ def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None:
+ self._add_one(statement, None, entry)
@_prepared_statement("SELECT * FROM directory_entry WHERE directory_id IN ?")
- def directory_entry_get(self, directory_ids, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [directory_ids])
+ def directory_entry_get(
+ self, directory_ids, *, statement
+ ) -> Iterable[DirectoryEntryRow]:
+ return map(
+ DirectoryEntryRow.from_dict,
+ self._execute_with_retries(statement, [directory_ids]),
+ )
##########################
# 'snapshot' table
##########################
- _snapshot_keys = ["id"]
-
@_prepared_exists_statement("snapshot")
def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
return self._missing(statement, ids)
- @_prepared_insert_statement("snapshot", _snapshot_keys)
- def snapshot_add_one(self, snapshot_id: Sha1Git, *, statement) -> None:
- self._execute_with_retries(statement, [snapshot_id])
- self._increment_counter("snapshot", 1)
+ @_prepared_insert_statement("snapshot", SnapshotRow.cols())
+ def snapshot_add_one(self, snapshot: SnapshotRow, *, statement) -> None:
+ self._add_one(statement, "snapshot", snapshot)
@_prepared_statement("SELECT * FROM snapshot WHERE id = ?")
def snapshot_get(self, snapshot_id: Sha1Git, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [snapshot_id])
+ return map(
+ SnapshotRow.from_dict, self._execute_with_retries(statement, [snapshot_id])
+ )
@_prepared_statement("SELECT * FROM snapshot WHERE token(id) > ? LIMIT 1")
- def snapshot_get_random(self, *, statement) -> Optional[Row]:
- return self._get_random_row(statement)
+ def snapshot_get_random(self, *, statement) -> Optional[SnapshotRow]:
+ return self._get_random_row(SnapshotRow, statement)
##########################
# 'snapshot_branch' table
##########################
- _snapshot_branch_keys = ["snapshot_id", "name", "target_type", "target"]
-
- @_prepared_insert_statement("snapshot_branch", _snapshot_branch_keys)
- def snapshot_branch_add_one(self, branch: Dict[str, Any], *, statement) -> None:
- self._execute_with_retries(
- statement, [branch[key] for key in self._snapshot_branch_keys]
- )
+ @_prepared_insert_statement("snapshot_branch", SnapshotBranchRow.cols())
+ def snapshot_branch_add_one(self, branch: SnapshotBranchRow, *, statement) -> None:
+ self._add_one(statement, None, branch)
@_prepared_statement(
"SELECT ascii_bins_count(target_type) AS counts "
"FROM snapshot_branch "
"WHERE snapshot_id = ? "
)
- def snapshot_count_branches(self, snapshot_id: Sha1Git, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [snapshot_id])
+ def snapshot_count_branches(
+ self, snapshot_id: Sha1Git, *, statement
+ ) -> Dict[Optional[str], int]:
+ """Returns a dictionary from type names to the number of branches
+ of that type."""
+ row = self._execute_with_retries(statement, [snapshot_id]).one()
+ (nb_none, counts) = row["counts"]
+ return {None: nb_none, **counts}
@_prepared_statement(
"SELECT * FROM snapshot_branch WHERE snapshot_id = ? AND name >= ? LIMIT ?"
)
def snapshot_branch_get(
self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement
- ) -> ResultSet:
- return self._execute_with_retries(statement, [snapshot_id, from_, limit])
+ ) -> Iterable[SnapshotBranchRow]:
+ return map(
+ SnapshotBranchRow.from_dict,
+ self._execute_with_retries(statement, [snapshot_id, from_, limit]),
+ )
##########################
# 'origin' table
##########################
- origin_keys = ["sha1", "url", "type", "next_visit_id"]
-
- @_prepared_statement(
- "INSERT INTO origin (sha1, url, next_visit_id) "
- "VALUES (?, ?, 1) IF NOT EXISTS"
- )
- def origin_add_one(self, origin: Origin, *, statement) -> None:
- self._execute_with_retries(statement, [hash_url(origin.url), origin.url])
- self._increment_counter("origin", 1)
+ @_prepared_insert_statement("origin", OriginRow.cols())
+ def origin_add_one(self, origin: OriginRow, *, statement) -> None:
+ self._add_one(statement, "origin", origin)
@_prepared_statement("SELECT * FROM origin WHERE sha1 = ?")
- def origin_get_by_sha1(self, sha1: bytes, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [sha1])
+ def origin_get_by_sha1(self, sha1: bytes, *, statement) -> Iterable[OriginRow]:
+ return map(OriginRow.from_dict, self._execute_with_retries(statement, [sha1]))
- def origin_get_by_url(self, url: str) -> ResultSet:
+ def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
return self.origin_get_by_sha1(hash_url(url))
@_prepared_statement(
- f'SELECT token(sha1) AS tok, {", ".join(origin_keys)} '
+ f'SELECT token(sha1) AS tok, {", ".join(OriginRow.cols())} '
f"FROM origin WHERE token(sha1) >= ? LIMIT ?"
)
- def origin_list(self, start_token: int, limit: int, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [start_token, limit])
+ def origin_list(
+ self, start_token: int, limit: int, *, statement
+ ) -> Iterable[Tuple[int, OriginRow]]:
+ """Returns an iterable of (token, origin)"""
+ return (
+ (row["tok"], OriginRow.from_dict(remove_keys(row, ("tok",))))
+ for row in self._execute_with_retries(statement, [start_token, limit])
+ )
@_prepared_statement("SELECT * FROM origin")
- def origin_iter_all(self, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [])
+ def origin_iter_all(self, *, statement) -> Iterable[OriginRow]:
+ return map(OriginRow.from_dict, self._execute_with_retries(statement, []))
@_prepared_statement("SELECT next_visit_id FROM origin WHERE sha1 = ?")
def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int:
rows = list(self._execute_with_retries(statement, [origin_sha1]))
assert len(rows) == 1 # TODO: error handling
- return rows[0].next_visit_id
+ return rows[0]["next_visit_id"]
@_prepared_statement(
"UPDATE origin SET next_visit_id=? WHERE sha1 = ? IF next_visit_id=?"
@@ -640,12 +624,12 @@
)
)
assert len(res) == 1
- if res[0].applied:
+ if res[0]["[applied]"]:
# No data race
return next_id
else:
# Someone else updated it before we did, let's try again
- next_id = res[0].next_visit_id
+ next_id = res[0]["next_visit_id"]
# TODO: abort after too many attempts
return next_id
@@ -654,13 +638,6 @@
# 'origin_visit' table
##########################
- _origin_visit_keys = [
- "origin",
- "visit",
- "type",
- "date",
- ]
-
@_prepared_statement(
"SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
"ORDER BY visit ASC"
@@ -737,7 +714,7 @@
last_visit: Optional[int],
limit: Optional[int],
order: ListOrder,
- ) -> ResultSet:
+ ) -> Iterable[OriginVisitRow]:
args: List[Any] = [origin_url]
if last_visit is not None:
@@ -754,7 +731,7 @@
method_name = f"_origin_visit_get_{page_name}_{order.value}_{limit_name}"
origin_visit_get_method = getattr(self, method_name)
- return origin_visit_get_method(*args)
+ return map(OriginVisitRow.from_dict, origin_visit_get_method(*args))
@_prepared_statement(
"SELECT * FROM origin_visit_status WHERE origin = ? "
@@ -817,7 +794,7 @@
date_from: Optional[datetime.datetime],
limit: int,
order: ListOrder,
- ) -> ResultSet:
+ ) -> Iterable[OriginVisitStatusRow]:
args: List[Any] = [origin, visit]
if date_from is not None:
@@ -830,41 +807,27 @@
method_name = f"_origin_visit_status_get_with_{date_name}_{order.value}_limit"
origin_visit_status_get_method = getattr(self, method_name)
- return origin_visit_status_get_method(*args)
-
- @_prepared_insert_statement("origin_visit", _origin_visit_keys)
- def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None:
- self._add_one(statement, "origin_visit", visit, self._origin_visit_keys)
-
- _origin_visit_status_keys = [
- "origin",
- "visit",
- "date",
- "status",
- "snapshot",
- "metadata",
- ]
-
- @_prepared_insert_statement("origin_visit_status", _origin_visit_status_keys)
+ return map(
+ OriginVisitStatusRow.from_dict, origin_visit_status_get_method(*args)
+ )
+
+ @_prepared_insert_statement("origin_visit", OriginVisitRow.cols())
+ def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None:
+ self._add_one(statement, "origin_visit", visit)
+
+ @_prepared_insert_statement("origin_visit_status", OriginVisitStatusRow.cols())
def origin_visit_status_add_one(
- self, visit_update: OriginVisitStatus, *, statement
+ self, visit_update: OriginVisitStatusRow, *, statement
) -> None:
- assert self._origin_visit_status_keys[-1] == "metadata"
- keys = self._origin_visit_status_keys
+ self._add_one(statement, None, visit_update)
- metadata = json.dumps(
- dict(visit_update.metadata) if visit_update.metadata is not None else None
- )
- self._execute_with_retries(
- statement, [getattr(visit_update, key) for key in keys[:-1]] + [metadata]
- )
-
- def origin_visit_status_get_latest(self, origin: str, visit: int,) -> Optional[Row]:
+ def origin_visit_status_get_latest(
+ self, origin: str, visit: int,
+ ) -> Optional[OriginVisitStatusRow]:
"""Given an origin visit id, return its latest origin_visit_status
"""
- rows = self.origin_visit_status_get(origin, visit)
- return rows[0] if rows else None
+ return next(self.origin_visit_status_get(origin, visit), None)
@_prepared_statement(
"SELECT * FROM origin_visit_status "
@@ -879,36 +842,52 @@
require_snapshot: bool = False,
*,
statement,
- ) -> List[Row]:
+ ) -> Iterator[OriginVisitStatusRow]:
"""Return all origin visit statuses for a given visit
"""
- return list(self._execute_with_retries(statement, [origin, visit]))
+ return map(
+ OriginVisitStatusRow.from_dict,
+ self._execute_with_retries(statement, [origin, visit]),
+ )
@_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? AND visit = ?")
def origin_visit_get_one(
self, origin_url: str, visit_id: int, *, statement
- ) -> Optional[Row]:
+ ) -> Optional[OriginVisitRow]:
# TODO: error handling
rows = list(self._execute_with_retries(statement, [origin_url, visit_id]))
if rows:
- return rows[0]
+ return OriginVisitRow.from_dict(rows[0])
else:
return None
@_prepared_statement("SELECT * FROM origin_visit WHERE origin = ?")
- def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [origin_url])
+ def origin_visit_get_all(
+ self, origin_url: str, *, statement
+ ) -> Iterable[OriginVisitRow]:
+ return map(
+ OriginVisitRow.from_dict,
+ self._execute_with_retries(statement, [origin_url]),
+ )
@_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) >= ?")
- def _origin_visit_iter_from(self, min_token: int, *, statement) -> Iterator[Row]:
- yield from self._execute_with_retries(statement, [min_token])
+ def _origin_visit_iter_from(
+ self, min_token: int, *, statement
+ ) -> Iterable[OriginVisitRow]:
+ return map(
+ OriginVisitRow.from_dict, self._execute_with_retries(statement, [min_token])
+ )
@_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) < ?")
- def _origin_visit_iter_to(self, max_token: int, *, statement) -> Iterator[Row]:
- yield from self._execute_with_retries(statement, [max_token])
+ def _origin_visit_iter_to(
+ self, max_token: int, *, statement
+ ) -> Iterable[OriginVisitRow]:
+ return map(
+ OriginVisitRow.from_dict, self._execute_with_retries(statement, [max_token])
+ )
- def origin_visit_iter(self, start_token: int) -> Iterator[Row]:
+ def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]:
"""Returns all origin visits in order from this token,
and wraps around the token space."""
yield from self._origin_visit_iter_from(start_token)
@@ -918,68 +897,49 @@
# 'metadata_authority' table
##########################
- _metadata_authority_keys = ["url", "type", "metadata"]
-
- @_prepared_insert_statement("metadata_authority", _metadata_authority_keys)
- def metadata_authority_add(self, url, type, metadata, *, statement):
- return self._execute_with_retries(statement, [url, type, metadata])
+ @_prepared_insert_statement("metadata_authority", MetadataAuthorityRow.cols())
+ def metadata_authority_add(self, authority: MetadataAuthorityRow, *, statement):
+ self._add_one(statement, None, authority)
@_prepared_statement("SELECT * from metadata_authority WHERE type = ? AND url = ?")
- def metadata_authority_get(self, type, url, *, statement) -> Optional[Row]:
- return next(iter(self._execute_with_retries(statement, [type, url])), None)
+ def metadata_authority_get(
+ self, type, url, *, statement
+ ) -> Optional[MetadataAuthorityRow]:
+ rows = list(self._execute_with_retries(statement, [type, url]))
+ if rows:
+ return MetadataAuthorityRow.from_dict(rows[0])
+ else:
+ return None
##########################
# 'metadata_fetcher' table
##########################
- _metadata_fetcher_keys = ["name", "version", "metadata"]
-
- @_prepared_insert_statement("metadata_fetcher", _metadata_fetcher_keys)
- def metadata_fetcher_add(self, name, version, metadata, *, statement):
- return self._execute_with_retries(statement, [name, version, metadata])
+ @_prepared_insert_statement("metadata_fetcher", MetadataFetcherRow.cols())
+ def metadata_fetcher_add(self, fetcher, *, statement):
+ self._add_one(statement, None, fetcher)
@_prepared_statement(
"SELECT * from metadata_fetcher WHERE name = ? AND version = ?"
)
- def metadata_fetcher_get(self, name, version, *, statement) -> Optional[Row]:
- return next(iter(self._execute_with_retries(statement, [name, version])), None)
+ def metadata_fetcher_get(
+ self, name, version, *, statement
+ ) -> Optional[MetadataFetcherRow]:
+ rows = list(self._execute_with_retries(statement, [name, version]))
+ if rows:
+ return MetadataFetcherRow.from_dict(rows[0])
+ else:
+ return None
#########################
# 'raw_extrinsic_metadata' table
#########################
- _raw_extrinsic_metadata_keys = [
- "type",
- "id",
- "authority_type",
- "authority_url",
- "discovery_date",
- "fetcher_name",
- "fetcher_version",
- "format",
- "metadata",
- "origin",
- "visit",
- "snapshot",
- "release",
- "revision",
- "path",
- "directory",
- ]
-
- @_prepared_statement(
- f"INSERT INTO raw_extrinsic_metadata "
- f" ({', '.join(_raw_extrinsic_metadata_keys)}) "
- f"VALUES ({', '.join('?' for _ in _raw_extrinsic_metadata_keys)})"
+ @_prepared_insert_statement(
+ "raw_extrinsic_metadata", RawExtrinsicMetadataRow.cols()
)
- def raw_extrinsic_metadata_add(
- self, statement, **kwargs,
- ):
- assert set(kwargs) == set(
- self._raw_extrinsic_metadata_keys
- ), f"Bad kwargs: {set(kwargs)}"
- params = [kwargs[key] for key in self._raw_extrinsic_metadata_keys]
- return self._execute_with_retries(statement, params,)
+ def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata, *, statement):
+ self._add_one(statement, None, raw_extrinsic_metadata)
@_prepared_statement(
"SELECT * from raw_extrinsic_metadata "
@@ -993,9 +953,12 @@
after: datetime.datetime,
*,
statement,
- ):
- return self._execute_with_retries(
- statement, [id, authority_url, after, authority_type]
+ ) -> Iterable[RawExtrinsicMetadataRow]:
+ return map(
+ RawExtrinsicMetadataRow.from_dict,
+ self._execute_with_retries(
+ statement, [id, authority_url, after, authority_type]
+ ),
)
@_prepared_statement(
@@ -1013,17 +976,20 @@
after_fetcher_version: str,
*,
statement,
- ):
- return self._execute_with_retries(
- statement,
- [
- id,
- authority_type,
- authority_url,
- after_date,
- after_fetcher_name,
- after_fetcher_version,
- ],
+ ) -> Iterable[RawExtrinsicMetadataRow]:
+ return map(
+ RawExtrinsicMetadataRow.from_dict,
+ self._execute_with_retries(
+ statement,
+ [
+ id,
+ authority_type,
+ authority_url,
+ after_date,
+ after_fetcher_name,
+ after_fetcher_version,
+ ],
+ ),
)
@_prepared_statement(
@@ -1032,9 +998,10 @@
)
def raw_extrinsic_metadata_get(
self, id: str, authority_type: str, authority_url: str, *, statement
- ) -> Iterable[Row]:
- return self._execute_with_retries(
- statement, [id, authority_url, authority_type]
+ ) -> Iterable[RawExtrinsicMetadataRow]:
+ return map(
+ RawExtrinsicMetadataRow.from_dict,
+ self._execute_with_retries(statement, [id, authority_url, authority_type]),
)
##########################
@@ -1045,8 +1012,6 @@
def check_read(self, *, statement):
self._execute_with_retries(statement, [])
- @_prepared_statement(
- "SELECT object_type, count FROM object_count WHERE partition_key=0"
- )
+ @_prepared_statement("SELECT * FROM object_count WHERE partition_key=0")
def stat_counters(self, *, statement) -> ResultSet:
- return self._execute_with_retries(statement, [])
+ return map(ObjectCountRow.from_dict, self._execute_with_retries(statement, []))
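cql.py now registers dict_factory as the driver's row_factory, so every query yields plain dicts keyed by column name; reads wrap those dicts into row dataclasses, and writes bind dataclasses.astuple(row) against a statement whose column list comes from Row.cols(). A small self-contained sketch of that pattern, with a trimmed OriginRow and a fake result set standing in for what the driver would return:

    import dataclasses
    from typing import Any, Dict, Iterable, Iterator

    @dataclasses.dataclass
    class OriginRow:                       # trimmed copy of model.OriginRow
        sha1: bytes
        url: str
        next_visit_id: int

    def iter_origins(resultset: Iterable[Dict[str, Any]]) -> Iterator[OriginRow]:
        # with row_factory=dict_factory each row is a dict keyed by column name
        return (OriginRow(**row) for row in resultset)

    fake_resultset = [
        {"sha1": b"\x01" * 20, "url": "https://example.org/repo", "next_visit_id": 1},
    ]
    (origin,) = iter_origins(fake_resultset)
    assert origin.url == "https://example.org/repo"

    # an insert binds the values positionally, in field-declaration order:
    values = dataclasses.astuple(origin)
    assert values == (b"\x01" * 20, "https://example.org/repo", 1)

Because dataclasses.fields(), cols() and astuple() all follow field declaration order, the positional binding stays aligned with the prepared statement's column list.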
diff --git a/swh/storage/cassandra/model.py b/swh/storage/cassandra/model.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/cassandra/model.py
@@ -0,0 +1,196 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""Classes representing tables in the Cassandra database.
+
+They are very close to classes found in swh.model.model, but most of
+them are subtly different:
+
+* Large objects are split into other classes (eg. RevisionRow has no
+ 'parents' field, because parents are stored in a different table,
+ represented by RevisionParentRow)
+* They have a "cols" field, which returns the list of column names
+ of the table
+* They only use types that map directly to Cassandra's schema (ie. no enums)
+
+Therefore, this model doesn't reuse swh.model.model, except for types
+that can be mapped to UDTs (Person and TimestampWithTimezone).
+"""
+
+import dataclasses
+import datetime
+from typing import Any, Dict, List, Optional, Type, TypeVar
+
+from swh.model.model import Person, TimestampWithTimezone
+
+
+T = TypeVar("T", bound="BaseRow")
+
+
+class BaseRow:
+ @classmethod
+ def from_dict(cls: Type[T], d: Dict[str, Any]) -> T:
+ return cls(**d) # type: ignore
+
+ @classmethod
+ def cols(cls) -> List[str]:
+ return [field.name for field in dataclasses.fields(cls)]
+
+ def to_dict(self) -> Dict[str, Any]:
+ return dataclasses.asdict(self)
+
+
+@dataclasses.dataclass
+class ContentRow(BaseRow):
+ sha1: bytes
+ sha1_git: bytes
+ sha256: bytes
+ blake2s256: bytes
+ length: int
+ ctime: datetime.datetime
+ status: str
+
+
+@dataclasses.dataclass
+class SkippedContentRow(BaseRow):
+ sha1: Optional[bytes]
+ sha1_git: Optional[bytes]
+ sha256: Optional[bytes]
+ blake2s256: Optional[bytes]
+ length: Optional[int]
+ ctime: Optional[datetime.datetime]
+ status: str
+ reason: str
+ origin: str
+
+
+@dataclasses.dataclass
+class DirectoryRow(BaseRow):
+ id: bytes
+
+
+@dataclasses.dataclass
+class DirectoryEntryRow(BaseRow):
+ directory_id: bytes
+ name: bytes
+ target: bytes
+ perms: int
+ type: str
+
+
+@dataclasses.dataclass
+class RevisionRow(BaseRow):
+ id: bytes
+ date: Optional[TimestampWithTimezone]
+ committer_date: Optional[TimestampWithTimezone]
+ type: str
+ directory: bytes
+ message: bytes
+ author: Person
+ committer: Person
+ synthetic: bool
+ metadata: str
+ extra_headers: dict
+
+
+@dataclasses.dataclass
+class RevisionParentRow(BaseRow):
+ id: bytes
+ parent_rank: int
+ parent_id: bytes
+
+
+@dataclasses.dataclass
+class ReleaseRow(BaseRow):
+ id: bytes
+ target_type: str
+ target: bytes
+ date: TimestampWithTimezone
+ name: bytes
+ message: bytes
+ author: Person
+ synthetic: bool
+
+
+@dataclasses.dataclass
+class SnapshotRow(BaseRow):
+ id: bytes
+
+
+@dataclasses.dataclass
+class SnapshotBranchRow(BaseRow):
+ snapshot_id: bytes
+ name: bytes
+ target_type: Optional[str]
+ target: Optional[bytes]
+
+
+@dataclasses.dataclass
+class OriginVisitRow(BaseRow):
+ origin: str
+ visit: int
+ date: datetime.datetime
+ type: str
+
+
+@dataclasses.dataclass
+class OriginVisitStatusRow(BaseRow):
+ origin: str
+ visit: int
+ date: datetime.datetime
+ status: str
+ metadata: str
+ snapshot: bytes
+
+
+@dataclasses.dataclass
+class OriginRow(BaseRow):
+ sha1: bytes
+ url: str
+ next_visit_id: int
+
+
+@dataclasses.dataclass
+class MetadataAuthorityRow(BaseRow):
+ url: str
+ type: str
+ metadata: str
+
+
+@dataclasses.dataclass
+class MetadataFetcherRow(BaseRow):
+ name: str
+ version: str
+ metadata: str
+
+
+@dataclasses.dataclass
+class RawExtrinsicMetadataRow(BaseRow):
+ type: str
+ id: str
+
+ authority_type: str
+ authority_url: str
+ discovery_date: datetime.datetime
+ fetcher_name: str
+ fetcher_version: str
+
+ format: str
+ metadata: bytes
+
+ origin: Optional[str]
+ visit: Optional[int]
+ snapshot: Optional[str]
+ release: Optional[str]
+ revision: Optional[str]
+ path: Optional[bytes]
+ directory: Optional[str]
+
+
+@dataclasses.dataclass
+class ObjectCountRow(BaseRow):
+ partition_key: int
+ object_type: str
+ count: int
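The new model module is deliberately self-describing: cols() lists the column names in declaration order, and to_dict()/from_dict() convert to and from the plain dicts produced by the driver. A quick sketch, assuming a checkout with this patch applied so the module is importable; the hash values are made up and insert_query() only approximates what _prepared_insert_statement builds:

    import datetime
    from swh.storage.cassandra.model import ContentRow

    def insert_query(table, cols):
        # rough equivalent of the statement prepared from ContentRow.cols()
        return "INSERT INTO %s (%s) VALUES (%s)" % (
            table, ", ".join(cols), ", ".join("?" for _ in cols))

    print(insert_query("content", ContentRow.cols()))
    # INSERT INTO content (sha1, sha1_git, sha256, blake2s256, length, ctime, status)
    # VALUES (?, ?, ?, ?, ?, ?, ?)

    row = ContentRow(sha1=b"\x01" * 20, sha1_git=b"\x02" * 20, sha256=b"\x03" * 32,
                     blake2s256=b"\x04" * 32, length=3,
                     ctime=datetime.datetime.now(datetime.timezone.utc),
                     status="visible")
    assert ContentRow.from_dict(row.to_dict()) == row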
diff --git a/swh/storage/cassandra/schema.py b/swh/storage/cassandra/schema.py
--- a/swh/storage/cassandra/schema.py
+++ b/swh/storage/cassandra/schema.py
@@ -181,7 +181,6 @@
CREATE TABLE IF NOT EXISTS origin (
sha1 blob PRIMARY KEY,
url text,
- type text,
next_visit_id int,
-- We need integer visit ids for compatibility with the pgsql
-- storage, so we're using lightweight transactions with this trick:
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -48,10 +48,24 @@
from swh.storage.utils import map_optional, now
from ..exc import StorageArgumentException, HashCollision
-from .common import TOKEN_BEGIN, TOKEN_END
+from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys
from . import converters
from .cql import CqlRunner
from .schema import HASH_ALGORITHMS
+from .model import (
+ ContentRow,
+ DirectoryEntryRow,
+ DirectoryRow,
+ MetadataAuthorityRow,
+ MetadataFetcherRow,
+ OriginRow,
+ OriginVisitRow,
+ RawExtrinsicMetadataRow,
+ RevisionParentRow,
+ SkippedContentRow,
+ SnapshotBranchRow,
+ SnapshotRow,
+)
# Max block size of contents to return
@@ -139,7 +153,9 @@
collisions.append(content.hashes())
raise HashCollision(algo, content.get_hash(algo), collisions)
- (token, insertion_finalizer) = self._cql_runner.content_add_prepare(content)
+ (token, insertion_finalizer) = self._cql_runner.content_add_prepare(
+ ContentRow(**remove_keys(content.to_dict(), ("data",)))
+ )
# Then add to index tables
for algo in HASH_ALGORITHMS:
@@ -203,14 +219,12 @@
range_start, range_end, limit + 1
)
contents = []
- last_id: Optional[int] = None
- for counter, row in enumerate(rows):
+ for counter, (tok, row) in enumerate(rows):
if row.status == "absent":
continue
- row_d = row._asdict()
- last_id = row_d.pop("tok")
+ row_d = row.to_dict()
if counter >= limit:
- next_page_token = str(last_id)
+ next_page_token = str(tok)
break
contents.append(Content(**row_d))
@@ -223,7 +237,7 @@
# Get all (sha1, sha1_git, sha256, blake2s256) whose sha1
# matches the argument, from the index table ('content_by_sha1')
for row in self._content_get_from_hash("sha1", sha1):
- row_d = row._asdict()
+ row_d = row.to_dict()
row_d.pop("ctime")
content = Content(**row_d)
contents_by_sha1[content.sha1] = content
@@ -251,7 +265,7 @@
break
else:
# All hashes match, keep this row.
- row_d = row._asdict()
+ row_d = row.to_dict()
row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc)
results.append(Content(**row_d))
return results
@@ -314,7 +328,7 @@
for content in contents:
# Compute token of the row in the main table
(token, insertion_finalizer) = self._cql_runner.skipped_content_add_prepare(
- content
+ SkippedContentRow.from_dict({"origin": None, **content.to_dict()})
)
# Then add to index tables
@@ -348,13 +362,13 @@
# Add directory entries to the 'directory_entry' table
for entry in directory.entries:
self._cql_runner.directory_entry_add_one(
- {**entry.to_dict(), "directory_id": directory.id}
+ DirectoryEntryRow(directory_id=directory.id, **entry.to_dict())
)
# Add the directory *after* adding all the entries, so someone
# calling snapshot_get_branch in the meantime won't end up
# with half the entries.
- self._cql_runner.directory_add_one(directory.id)
+ self._cql_runner.directory_add_one(DirectoryRow(id=directory.id))
return {"directory:add": len(directories)}
@@ -387,10 +401,10 @@
rows = list(self._cql_runner.directory_entry_get([directory_id]))
for row in rows:
+ entry_d = row.to_dict()
# Build and yield the directory entry dict
- entry = row._asdict()
- del entry["directory_id"]
- entry = DirectoryEntry.from_dict(entry)
+ del entry_d["directory_id"]
+ entry = DirectoryEntry.from_dict(entry_d)
ret = self._join_dentry_to_content(entry)
ret["name"] = prefix + ret["name"]
ret["dir_id"] = directory_id
@@ -458,9 +472,11 @@
revobject = converters.revision_to_db(revision)
if revobject:
# Add parents first
- for (rank, parent) in enumerate(revobject["parents"]):
+ for (rank, parent) in enumerate(revision.parents):
self._cql_runner.revision_parent_add_one(
- revobject["id"], rank, parent
+ RevisionParentRow(
+ id=revobject.id, parent_rank=rank, parent_id=parent
+ )
)
# Then write the main revision row.
@@ -484,10 +500,9 @@
# (it might have lower latency, but requires more code and more
# bandwidth, because revision id would be part of each returned
# row)
- parent_rows = self._cql_runner.revision_parent_get(row.id)
+ parents = tuple(self._cql_runner.revision_parent_get(row.id))
# parent_rank is the clustering key, so results are already
# sorted by rank.
- parents = tuple(row.parent_id for row in parent_rows)
rev = converters.revision_from_db(row, parents=parents)
revs[rev.id] = rev.to_dict()
@@ -514,27 +529,35 @@
# results (ie. not return only a subset of a revision's parents
# if it is being written)
if short:
- rows = self._cql_runner.revision_get_ids(rev_ids)
+ ids = self._cql_runner.revision_get_ids(rev_ids)
+ for id_ in ids:
+ # TODO: use a single query to get all parents?
+ # (it might have less latency, but requires less code and more
+ # bandwidth (because revision id would be part of each returned
+ # row)
+ parents = tuple(self._cql_runner.revision_parent_get(id_))
+
+ # parent_rank is the clustering key, so results are already
+ # sorted by rank.
+
+ yield (id_, parents)
+ yield from self._get_parent_revs(parents, seen, limit, short)
else:
rows = self._cql_runner.revision_get(rev_ids)
- for row in rows:
- # TODO: use a single query to get all parents?
- # (it might have less latency, but requires less code and more
- # bandwidth (because revision id would be part of each returned
- # row)
- parent_rows = self._cql_runner.revision_parent_get(row.id)
+ for row in rows:
+ # TODO: use a single query to get all parents?
+ # (it might have less latency, but requires less code and more
+ # bandwidth (because revision id would be part of each returned
+ # row)
+ parents = tuple(self._cql_runner.revision_parent_get(row.id))
- # parent_rank is the clustering key, so results are already
- # sorted by rank.
- parents = tuple(row.parent_id for row in parent_rows)
+ # parent_rank is the clustering key, so results are already
+ # sorted by rank.
- if short:
- yield (row.id, parents)
- else:
rev = converters.revision_from_db(row, parents=parents)
yield rev.to_dict()
- yield from self._get_parent_revs(parents, seen, limit, short)
+ yield from self._get_parent_revs(parents, seen, limit, short)
def revision_log(
self, revisions: List[Sha1Git], limit: Optional[int] = None
@@ -595,24 +618,24 @@
# Add branches
for (branch_name, branch) in snapshot.branches.items():
if branch is None:
- target_type = None
- target = None
+ target_type: Optional[str] = None
+ target: Optional[bytes] = None
else:
target_type = branch.target_type.value
target = branch.target
self._cql_runner.snapshot_branch_add_one(
- {
- "snapshot_id": snapshot.id,
- "name": branch_name,
- "target_type": target_type,
- "target": target,
- }
+ SnapshotBranchRow(
+ snapshot_id=snapshot.id,
+ name=branch_name,
+ target_type=target_type,
+ target=target,
+ )
)
# Add the snapshot *after* adding all the branches, so someone
# calling snapshot_get_branch in the meantime won't end up
# with half the branches.
- self._cql_runner.snapshot_add_one(snapshot.id)
+ self._cql_runner.snapshot_add_one(SnapshotRow(id=snapshot.id))
return {"snapshot:add": len(snapshots)}
@@ -647,13 +670,7 @@
# Makes sure we don't fetch branches for a snapshot that is
# being added.
return None
- rows = list(self._cql_runner.snapshot_count_branches(snapshot_id))
- assert len(rows) == 1
- (nb_none, counts) = rows[0].counts
- counts = dict(counts)
- if nb_none:
- counts[None] = nb_none
- return counts
+ return self._cql_runner.snapshot_count_branches(snapshot_id)
def snapshot_get_branches(
self,
@@ -760,8 +777,8 @@
def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
results = []
for sha1 in sha1s:
- rows = self._cql_runner.origin_get_by_sha1(sha1)
- origin = {"url": rows.one().url} if rows else None
+ rows = list(self._cql_runner.origin_get_by_sha1(sha1))
+ origin = {"url": rows[0].url} if rows else None
results.append(origin)
return results
@@ -778,10 +795,10 @@
origins = []
# Take one more origin so we can reuse it as the next page token if any
- for row in self._cql_runner.origin_list(start_token, limit + 1):
+ for (tok, row) in self._cql_runner.origin_list(start_token, limit + 1):
origins.append(Origin(url=row.url))
# keep reference of the last id for pagination purposes
- last_id = row.tok
+ last_id = tok
if len(origins) > limit:
# last origin id is the next page token
@@ -805,15 +822,17 @@
next_page_token = None
offset = int(page_token) if page_token else 0
- origins = self._cql_runner.origin_iter_all()
+ origin_rows = [row for row in self._cql_runner.origin_iter_all()]
if regexp:
pat = re.compile(url_pattern)
- origins = [Origin(orig.url) for orig in origins if pat.search(orig.url)]
+ origin_rows = [row for row in origin_rows if pat.search(row.url)]
else:
- origins = [Origin(orig.url) for orig in origins if url_pattern in orig.url]
+ origin_rows = [row for row in origin_rows if url_pattern in row.url]
if with_visit:
- origins = [Origin(orig.url) for orig in origins if orig.next_visit_id > 1]
+ origin_rows = [row for row in origin_rows if row.next_visit_id > 1]
+
+ origins = [Origin(url=row.url) for row in origin_rows]
origins = origins[offset : offset + limit + 1]
if len(origins) > limit:
@@ -829,7 +848,9 @@
to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None]
self.journal_writer.origin_add(to_add)
for origin in to_add:
- self._cql_runner.origin_add_one(origin)
+ self._cql_runner.origin_add_one(
+ OriginRow(sha1=hash_url(origin.url), url=origin.url, next_visit_id=1)
+ )
return {"origin:add": len(to_add)}
def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
@@ -848,7 +869,7 @@
)
visit = attr.evolve(visit, visit=visit_id)
self.journal_writer.origin_visit_add([visit])
- self._cql_runner.origin_visit_add_one(visit)
+ self._cql_runner.origin_visit_add_one(OriginVisitRow(**visit.to_dict()))
assert visit.visit is not None
all_visits.append(visit)
self._origin_visit_status_add(
@@ -866,7 +887,9 @@
def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None:
"""Add an origin visit status"""
self.journal_writer.origin_visit_status_add([visit_status])
- self._cql_runner.origin_visit_status_add_one(visit_status)
+ self._cql_runner.origin_visit_status_add_one(
+ converters.visit_status_to_row(visit_status)
+ )
def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus]) -> None:
# First round to check existence (fail early if any is ko)
@@ -912,7 +935,7 @@
@staticmethod
def _format_origin_visit_row(visit):
return {
- **visit._asdict(),
+ **visit.to_dict(),
"origin": visit.origin,
"date": visit.date.replace(tzinfo=datetime.timezone.utc),
}
@@ -1048,8 +1071,10 @@
f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
f"{','.join(VISIT_STATUSES)} authorized"
)
- rows = self._cql_runner.origin_visit_status_get(
- origin_url, visit, allowed_statuses, require_snapshot
+ rows = list(
+ self._cql_runner.origin_visit_status_get(
+ origin_url, visit, allowed_statuses, require_snapshot
+ )
)
# filtering is done python side as we cannot do it server side
if allowed_statuses:
@@ -1113,7 +1138,7 @@
)
try:
- self._cql_runner.raw_extrinsic_metadata_add(
+ row = RawExtrinsicMetadataRow(
type=metadata_entry.type.value,
id=str(metadata_entry.id),
authority_type=metadata_entry.authority.type.value,
@@ -1131,6 +1156,7 @@
path=metadata_entry.path,
directory=map_optional(str, metadata_entry.directory),
)
+ self._cql_runner.raw_extrinsic_metadata_add(row)
except TypeError as e:
raise StorageArgumentException(*e.args)
@@ -1236,9 +1262,11 @@
self.journal_writer.metadata_fetcher_add(fetchers)
for fetcher in fetchers:
self._cql_runner.metadata_fetcher_add(
- fetcher.name,
- fetcher.version,
- json.dumps(map_optional(dict, fetcher.metadata)),
+ MetadataFetcherRow(
+ name=fetcher.name,
+ version=fetcher.version,
+ metadata=json.dumps(map_optional(dict, fetcher.metadata)),
+ )
)
def metadata_fetcher_get(
@@ -1258,9 +1286,11 @@
self.journal_writer.metadata_authority_add(authorities)
for authority in authorities:
self._cql_runner.metadata_authority_add(
- authority.url,
- authority.type.value,
- json.dumps(map_optional(dict, authority.metadata)),
+ MetadataAuthorityRow(
+ url=authority.url,
+ type=authority.type.value,
+ metadata=json.dumps(map_optional(dict, authority.metadata)),
+ )
)
def metadata_authority_get(
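On the storage.py side, content_get_partition now receives (token, ContentRow) pairs from content_get_token_range and uses the token of the first row past the limit as the next page token, instead of reading a "tok" attribute off the row. A condensed sketch of that paging logic (without the "absent" filtering), over made-up tokens and placeholder rows:

    from typing import Iterable, Optional, Tuple

    def page(rows: Iterable[Tuple[int, str]], limit: int):
        """Keep at most `limit` rows; if one more is available, its token
        becomes the next page token, mirroring content_get_partition."""
        contents = []
        next_page_token: Optional[str] = None
        for counter, (tok, row) in enumerate(rows):
            if counter >= limit:
                next_page_token = str(tok)
                break
            contents.append(row)
        return contents, next_page_token

    rows = [(10, "c1"), (27, "c2"), (43, "c3")]
    assert page(iter(rows), limit=2) == (["c1", "c2"], "43")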
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -3,22 +3,22 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-import attr
+import datetime
import os
import signal
import socket
import subprocess
import time
-
-from collections import namedtuple
from typing import Dict
+import attr
import pytest
from swh.core.api.classes import stream_results
from swh.storage import get_storage
from swh.storage.cassandra import create_keyspace
from swh.storage.cassandra.schema import TABLES, HASH_ALGORITHMS
+from swh.storage.cassandra.model import ContentRow
from swh.storage.utils import now
from swh.storage.tests.test_storage import TestStorage as _TestStorage
@@ -209,12 +209,17 @@
)
# For all tokens, always return cont
- Row = namedtuple("Row", HASH_ALGORITHMS)
-
def mock_cgft(token):
nonlocal called
called += 1
- return [Row(**{algo: getattr(cont, algo) for algo in HASH_ALGORITHMS})]
+ return [
+ ContentRow(
+ length=10,
+ ctime=datetime.datetime.now(),
+ status="present",
+ **{algo: getattr(cont, algo) for algo in HASH_ALGORITHMS},
+ )
+ ]
mocker.patch.object(
swh_storage._cql_runner, "content_get_from_token", mock_cgft
@@ -253,13 +258,12 @@
# For all tokens, always return cont and cont2
cols = list(set(cont.to_dict()) - {"data"})
- Row = namedtuple("Row", cols)
def mock_cgft(token):
nonlocal called
called += 1
return [
- Row(**{col: getattr(cont, col) for col in cols})
+ ContentRow(**{col: getattr(cont, col) for col in cols},)
for cont in [cont, cont2]
]
@@ -299,13 +303,12 @@
# For all tokens, always return cont and cont2
cols = list(set(cont.to_dict()) - {"data"})
- Row = namedtuple("Row", cols)
def mock_cgft(token):
nonlocal called
called += 1
return [
- Row(**{col: getattr(cont, col) for col in cols})
+ ContentRow(**{col: getattr(cont, col) for col in cols})
for cont in [cont, cont2]
]
@@ -342,16 +345,15 @@
rows[tok] = row_d
# For all tokens, always return cont
- keys = set(["tok"] + list(content.to_dict().keys())).difference(set(["data"]))
- Row = namedtuple("Row", keys)
def mock_content_get_token_range(range_start, range_end, limit):
nonlocal called
called += 1
for tok in list(rows.keys()) * 3: # yield multiple times the same tok
- row_d = rows[tok]
- yield Row(**row_d)
+ row_d = dict(rows[tok].items())
+ row_d.pop("tok")
+ yield (tok, ContentRow(**row_d))
mocker.patch.object(
swh_storage._cql_runner,
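In the tests, the namedtuples that used to be generated on the fly are replaced by the shared ContentRow dataclass. A condensed before/after with made-up hash values:

    import collections
    import dataclasses
    import datetime

    HASH_ALGORITHMS = ["sha1", "sha1_git", "sha256", "blake2s256"]
    hashes = {"sha1": b"\x01" * 20, "sha1_git": b"\x02" * 20,
              "sha256": b"\x03" * 32, "blake2s256": b"\x04" * 32}

    # before: a namedtuple generated on the fly, one field per hash column
    Row = collections.namedtuple("Row", HASH_ALGORITHMS)
    old_row = Row(**hashes)

    # after: the shared dataclass (trimmed copy of model.ContentRow)
    @dataclasses.dataclass
    class ContentRow:
        sha1: bytes
        sha1_git: bytes
        sha256: bytes
        blake2s256: bytes
        length: int
        ctime: datetime.datetime
        status: str

    new_row = ContentRow(length=10,
                         ctime=datetime.datetime.now(datetime.timezone.utc),
                         status="present", **hashes)
    assert new_row.sha1 == old_row.sha1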
Attached to: D3756: cassandra.cql: Use static dataclasses instead of generating namedtuples on the fly.