diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ cassandra-driver >= 3.19.0, != 3.21.0 deprecated typing-extensions +mypy_extensions diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -31,6 +31,7 @@ wait_random_exponential, retry_if_exception_type, ) +from mypy_extensions import NamedArg from swh.model.model import ( Content, @@ -103,10 +104,12 @@ session.execute(query) -T = TypeVar("T") +TRet = TypeVar("TRet") -def _prepared_statement(query: str) -> Callable[[Callable[..., T]], Callable[..., T]]: +def _prepared_statement( + query: str, +) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]: """Returns a decorator usable on methods of CqlRunner, to inject them with a 'statement' argument, that is a prepared statement corresponding to the query. @@ -116,7 +119,7 @@ def decorator(f): @functools.wraps(f) - def newf(self, *args, **kwargs) -> T: + def newf(self, *args, **kwargs) -> TRet: if f.__name__ not in self._prepared_statements: statement: PreparedStatement = self._session.prepare(query) self._prepared_statements[f.__name__] = statement @@ -129,7 +132,16 @@ return decorator -def _prepared_insert_statement(table_name: str, columns: List[str]): +TArg = TypeVar("TArg") +TSelf = TypeVar("TSelf") + + +def _prepared_insert_statement( + table_name: str, columns: List[str] +) -> Callable[ + [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa + Callable[[TSelf, TArg], TRet], +]: """Shorthand for using `_prepared_statement` for `INSERT INTO` statements.""" return _prepared_statement( @@ -138,7 +150,12 @@ ) -def _prepared_exists_statement(table_name: str): +def _prepared_exists_statement( + table_name: str, +) -> Callable[ + [Callable[[TSelf, TArg, NamedArg(Any, "statement")], TRet]], # noqa + Callable[[TSelf, TArg], TRet], +]: """Shorthand for using `_prepared_statement` for queries that only check which ids in a list exist in the table.""" return _prepared_statement(f"SELECT id FROM {table_name} WHERE id IN ?") diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -9,7 +9,7 @@ import json import random import re -from typing import Any, Dict, List, Iterable, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple, Union import attr @@ -74,9 +74,9 @@ class CassandraStorage: def __init__(self, hosts, keyspace, objstorage, port=9042, journal_writer=None): - self._cql_runner = CqlRunner(hosts, keyspace, port) - self.journal_writer = JournalWriter(journal_writer) - self.objstorage = ObjStorage(objstorage) + self._cql_runner: CqlRunner = CqlRunner(hosts, keyspace, port) + self.journal_writer: JournalWriter = JournalWriter(journal_writer) + self.objstorage: ObjStorage = ObjStorage(objstorage) def check_config(self, *, check_write: bool) -> bool: self._cql_runner.check_read() @@ -92,6 +92,7 @@ found_tokens = self._cql_runner.content_get_tokens_from_single_hash(algo, hash_) for token in found_tokens: + assert isinstance(token, int), found_tokens # Query the main table ('content'). res = self._cql_runner.content_get_from_token(token) @@ -294,7 +295,9 @@ ) def content_get_random(self) -> Sha1Git: - return self._cql_runner.content_get_random().sha1_git + content = self._cql_runner.content_get_random() + assert content, "Could not find any content" + return content.sha1_git def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict: # Filter-out content already in the database. @@ -441,7 +444,9 @@ yield from self._directory_ls(directory, recursive) def directory_get_random(self) -> Sha1Git: - return self._cql_runner.directory_get_random().id + directory = self._cql_runner.directory_get_random() + assert directory, "Could not find any directory" + return directory.id def revision_add(self, revisions: List[Revision]) -> Dict: # Filter-out revisions already in the database @@ -553,7 +558,9 @@ yield from self._get_parent_revs(revisions, seen, limit, True) def revision_get_random(self) -> Sha1Git: - return self._cql_runner.revision_get_random().id + revision = self._cql_runner.revision_get_random() + assert revision, "Could not find any revision" + return revision.id def release_add(self, releases: List[Release]) -> Dict: to_add = [] @@ -587,7 +594,9 @@ yield rels.get(rel_id) def release_get_random(self) -> Sha1Git: - return self._cql_runner.release_get_random().id + release = self._cql_runner.release_get_random() + assert release, "Could not find any release" + return release.id def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: missing = self._cql_runner.snapshot_missing([snp.id for snp in snapshots]) @@ -714,7 +723,9 @@ ) def snapshot_get_random(self) -> Sha1Git: - return self._cql_runner.snapshot_get_random().id + snapshot = self._cql_runner.snapshot_get_random() + assert snapshot, "Could not find any snapshot" + return snapshot.id def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]: results: Dict[Sha1Git, List[Dict]] = {id_: [] for id_ in ids} @@ -722,7 +733,7 @@ # Mind the order, revision is the most likely one for a given ID, # so we check revisions first. - queries = [ + queries: List[Tuple[str, Callable[[List[Sha1Git]], List[Sha1Git]]]] = [ ("revision", self._cql_runner.revision_missing), ("release", self._cql_runner.release_missing), ("content", self._cql_runner.content_missing_by_sha1_git), @@ -730,7 +741,7 @@ ] for (object_type, query_fn) in queries: - found_ids = missing_ids - set(query_fn(missing_ids)) + found_ids = missing_ids - set(query_fn(list(missing_ids))) for sha1_git in found_ids: results[sha1_git].append( {"sha1_git": sha1_git, "type": object_type,} @@ -910,6 +921,7 @@ """Retrieve the latest visit status information for the origin visit object. """ + assert visit.visit row = self._cql_runner.origin_visit_status_get_latest(visit.origin, visit.visit) assert row is not None visit_status = converters.row_to_visit_status(row) @@ -936,7 +948,7 @@ raise StorageArgumentException("page_token must be a string.") next_page_token = None - visit_from = page_token and int(page_token) + visit_from = None if page_token is None else int(page_token) visits: List[OriginVisit] = [] extra_limit = limit + 1