diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -137,16 +137,17 @@ def _prepared_insert_statement( - table_name: str, columns: List[str] + row_class: Type[BaseRow], ) -> Callable[ [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa Callable[[TSelf, TArg], TRet], ]: """Shorthand for using `_prepared_statement` for `INSERT INTO` statements.""" + columns = row_class.cols() return _prepared_statement( "INSERT INTO %s (%s) VALUES (%s)" - % (table_name, ", ".join(columns), ", ".join("?" for _ in columns),) + % (row_class.TABLE, ", ".join(columns), ", ".join("?" for _ in columns),) ) @@ -201,9 +202,8 @@ ) -> None: self._execute_with_retries(statement, [nb, object_type]) - def _add_one(self, statement, object_type: Optional[str], obj: BaseRow) -> None: - if object_type: - self._increment_counter(object_type, 1) + def _add_one(self, statement, obj: BaseRow) -> None: + self._increment_counter(obj.TABLE, 1) self._execute_with_retries(statement, dataclasses.astuple(obj)) _T = TypeVar("_T", bound=BaseRow) @@ -240,7 +240,7 @@ self._execute_with_retries(statement, None) self._increment_counter("content", 1) - @_prepared_insert_statement("content", ContentRow.cols()) + @_prepared_insert_statement(ContentRow) def content_add_prepare( self, content: ContentRow, *, statement ) -> Tuple[int, Callable[[], None]]: @@ -357,7 +357,7 @@ self._execute_with_retries(statement, None) self._increment_counter("skipped_content", 1) - @_prepared_insert_statement("skipped_content", SkippedContentRow.cols()) + @_prepared_insert_statement(SkippedContentRow) def skipped_content_add_prepare( self, content, *, statement ) -> Tuple[int, Callable[[], None]]: @@ -439,9 +439,9 @@ def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) - @_prepared_insert_statement("revision", RevisionRow.cols()) + @_prepared_insert_statement(RevisionRow) def revision_add_one(self, revision: RevisionRow, *, statement) -> None: - self._add_one(statement, "revision", revision) + self._add_one(statement, revision) @_prepared_statement("SELECT id FROM revision WHERE id IN ?") def revision_get_ids(self, revision_ids, *, statement) -> Iterable[int]: @@ -463,11 +463,11 @@ # 'revision_parent' table ########################## - @_prepared_insert_statement("revision_parent", RevisionParentRow.cols()) + @_prepared_insert_statement(RevisionParentRow) def revision_parent_add_one( self, revision_parent: RevisionParentRow, *, statement ) -> None: - self._add_one(statement, None, revision_parent) + self._add_one(statement, revision_parent) @_prepared_statement("SELECT parent_id FROM revision_parent WHERE id = ?") def revision_parent_get( @@ -486,9 +486,9 @@ def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) - @_prepared_insert_statement("release", ReleaseRow.cols()) + @_prepared_insert_statement(ReleaseRow) def release_add_one(self, release: ReleaseRow, *, statement) -> None: - self._add_one(statement, "release", release) + self._add_one(statement, release) @_prepared_statement("SELECT * FROM release WHERE id in ?") def release_get(self, release_ids: List[str], *, statement) -> Iterable[ReleaseRow]: @@ -508,11 +508,11 @@ def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) - @_prepared_insert_statement("directory", DirectoryRow.cols()) + @_prepared_insert_statement(DirectoryRow) def directory_add_one(self, directory: DirectoryRow, *, statement) -> None: """Called after all calls to directory_entry_add_one, to commit/finalize the directory.""" - self._add_one(statement, "directory", directory) + self._add_one(statement, directory) @_prepared_statement("SELECT * FROM directory WHERE token(id) > ? LIMIT 1") def directory_get_random(self, *, statement) -> Optional[DirectoryRow]: @@ -522,9 +522,9 @@ # 'directory_entry' table ########################## - @_prepared_insert_statement("directory_entry", DirectoryEntryRow.cols()) + @_prepared_insert_statement(DirectoryEntryRow) def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None: - self._add_one(statement, None, entry) + self._add_one(statement, entry) @_prepared_statement("SELECT * FROM directory_entry WHERE directory_id IN ?") def directory_entry_get( @@ -543,9 +543,9 @@ def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]: return self._missing(statement, ids) - @_prepared_insert_statement("snapshot", SnapshotRow.cols()) + @_prepared_insert_statement(SnapshotRow) def snapshot_add_one(self, snapshot: SnapshotRow, *, statement) -> None: - self._add_one(statement, "snapshot", snapshot) + self._add_one(statement, snapshot) @_prepared_statement("SELECT * FROM snapshot WHERE id = ?") def snapshot_get(self, snapshot_id: Sha1Git, *, statement) -> ResultSet: @@ -561,9 +561,9 @@ # 'snapshot_branch' table ########################## - @_prepared_insert_statement("snapshot_branch", SnapshotBranchRow.cols()) + @_prepared_insert_statement(SnapshotBranchRow) def snapshot_branch_add_one(self, branch: SnapshotBranchRow, *, statement) -> None: - self._add_one(statement, None, branch) + self._add_one(statement, branch) @_prepared_statement( "SELECT ascii_bins_count(target_type) AS counts " @@ -594,9 +594,9 @@ # 'origin' table ########################## - @_prepared_insert_statement("origin", OriginRow.cols()) + @_prepared_insert_statement(OriginRow) def origin_add_one(self, origin: OriginRow, *, statement) -> None: - self._add_one(statement, "origin", origin) + self._add_one(statement, origin) @_prepared_statement("SELECT * FROM origin WHERE sha1 = ?") def origin_get_by_sha1(self, sha1: bytes, *, statement) -> Iterable[OriginRow]: @@ -828,15 +828,15 @@ OriginVisitStatusRow.from_dict, origin_visit_status_get_method(*args) ) - @_prepared_insert_statement("origin_visit", OriginVisitRow.cols()) + @_prepared_insert_statement(OriginVisitRow) def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None: - self._add_one(statement, "origin_visit", visit) + self._add_one(statement, visit) - @_prepared_insert_statement("origin_visit_status", OriginVisitStatusRow.cols()) + @_prepared_insert_statement(OriginVisitStatusRow) def origin_visit_status_add_one( self, visit_update: OriginVisitStatusRow, *, statement ) -> None: - self._add_one(statement, None, visit_update) + self._add_one(statement, visit_update) def origin_visit_status_get_latest( self, origin: str, visit: int, @@ -914,9 +914,9 @@ # 'metadata_authority' table ########################## - @_prepared_insert_statement("metadata_authority", MetadataAuthorityRow.cols()) + @_prepared_insert_statement(MetadataAuthorityRow) def metadata_authority_add(self, authority: MetadataAuthorityRow, *, statement): - self._add_one(statement, None, authority) + self._add_one(statement, authority) @_prepared_statement("SELECT * from metadata_authority WHERE type = ? AND url = ?") def metadata_authority_get( @@ -932,9 +932,9 @@ # 'metadata_fetcher' table ########################## - @_prepared_insert_statement("metadata_fetcher", MetadataFetcherRow.cols()) + @_prepared_insert_statement(MetadataFetcherRow) def metadata_fetcher_add(self, fetcher, *, statement): - self._add_one(statement, None, fetcher) + self._add_one(statement, fetcher) @_prepared_statement( "SELECT * from metadata_fetcher WHERE name = ? AND version = ?" @@ -952,11 +952,9 @@ # 'raw_extrinsic_metadata' table ######################### - @_prepared_insert_statement( - "raw_extrinsic_metadata", RawExtrinsicMetadataRow.cols() - ) + @_prepared_insert_statement(RawExtrinsicMetadataRow) def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata, *, statement): - self._add_one(statement, None, raw_extrinsic_metadata) + self._add_one(statement, raw_extrinsic_metadata) @_prepared_statement( "SELECT * from raw_extrinsic_metadata " diff --git a/swh/storage/cassandra/model.py b/swh/storage/cassandra/model.py --- a/swh/storage/cassandra/model.py +++ b/swh/storage/cassandra/model.py @@ -21,7 +21,7 @@ import dataclasses import datetime -from typing import Any, Dict, List, Optional, Type, TypeVar +from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar from swh.model.model import Person, TimestampWithTimezone @@ -30,6 +30,8 @@ class BaseRow: + TABLE: ClassVar[str] + @classmethod def from_dict(cls: Type[T], d: Dict[str, Any]) -> T: return cls(**d) # type: ignore @@ -44,6 +46,8 @@ @dataclasses.dataclass class ContentRow(BaseRow): + TABLE = "content" + sha1: bytes sha1_git: bytes sha256: bytes @@ -55,6 +59,8 @@ @dataclasses.dataclass class SkippedContentRow(BaseRow): + TABLE = "skipped_content" + sha1: Optional[bytes] sha1_git: Optional[bytes] sha256: Optional[bytes] @@ -68,11 +74,15 @@ @dataclasses.dataclass class DirectoryRow(BaseRow): + TABLE = "directory" + id: bytes @dataclasses.dataclass class DirectoryEntryRow(BaseRow): + TABLE = "directory_entry" + directory_id: bytes name: bytes target: bytes @@ -82,6 +92,8 @@ @dataclasses.dataclass class RevisionRow(BaseRow): + TABLE = "revision" + id: bytes date: Optional[TimestampWithTimezone] committer_date: Optional[TimestampWithTimezone] @@ -97,6 +109,8 @@ @dataclasses.dataclass class RevisionParentRow(BaseRow): + TABLE = "revision_parent" + id: bytes parent_rank: int parent_id: bytes @@ -104,6 +118,8 @@ @dataclasses.dataclass class ReleaseRow(BaseRow): + TABLE = "release" + id: bytes target_type: str target: bytes @@ -116,11 +132,15 @@ @dataclasses.dataclass class SnapshotRow(BaseRow): + TABLE = "snapshot" + id: bytes @dataclasses.dataclass class SnapshotBranchRow(BaseRow): + TABLE = "snapshot_branch" + snapshot_id: bytes name: bytes target_type: Optional[str] @@ -129,6 +149,8 @@ @dataclasses.dataclass class OriginVisitRow(BaseRow): + TABLE = "origin_visit" + origin: str visit: int date: datetime.datetime @@ -137,6 +159,8 @@ @dataclasses.dataclass class OriginVisitStatusRow(BaseRow): + TABLE = "origin_visit_status" + origin: str visit: int date: datetime.datetime @@ -147,6 +171,8 @@ @dataclasses.dataclass class OriginRow(BaseRow): + TABLE = "origin" + sha1: bytes url: str next_visit_id: int @@ -154,6 +180,8 @@ @dataclasses.dataclass class MetadataAuthorityRow(BaseRow): + TABLE = "metadata_authority" + url: str type: str metadata: str @@ -161,6 +189,8 @@ @dataclasses.dataclass class MetadataFetcherRow(BaseRow): + TABLE = "metadata_fetcher" + name: str version: str metadata: str @@ -168,6 +198,8 @@ @dataclasses.dataclass class RawExtrinsicMetadataRow(BaseRow): + TABLE = "raw_extrinsic_metadata" + type: str id: str