diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -23,6 +23,7 @@
     Type,
     TypeVar,
     Union,
+    cast,
 )

 from cassandra import ConsistencyLevel, CoordinationFailure
@@ -48,7 +49,7 @@
     TimestampWithTimezone,
 )
 from swh.model.swhids import CoreSWHID
-from swh.storage.interface import ListOrder
+from swh.storage.interface import ListOrder, TotalHashDict

 from ..utils import remove_keys
 from .common import TOKEN_BEGIN, TOKEN_END, hash_url
@@ -384,11 +385,12 @@
         ContentRow, f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}"
     )
     def content_get_from_pk(
-        self, content_hashes: Dict[str, bytes], *, statement
+        self, content_hashes: TotalHashDict, *, statement
     ) -> Optional[ContentRow]:
         rows = list(
             self._execute_with_retries(
-                statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]
+                statement,
+                [cast(dict, content_hashes)[algo] for algo in HASH_ALGORITHMS],
             )
         )
         assert len(rows) <= 1
@@ -398,8 +400,8 @@
             return None

     def content_missing_from_all_hashes(
-        self, contents_hashes: List[Dict[str, bytes]]
-    ) -> Iterator[Dict[str, bytes]]:
+        self, contents_hashes: List[TotalHashDict]
+    ) -> Iterator[TotalHashDict]:
         for group in grouper(contents_hashes, PARTITION_KEY_RESTRICTION_MAX_SIZE):
             group = list(group)
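The `cast(dict, ...)` added above is a mypy workaround: a TypedDict such as `TotalHashDict` may only be subscripted with string literals, so indexing it with the loop variable `algo` would be rejected by the type checker. A minimal, self-contained sketch of the pattern (two algorithms instead of swh's real four, placeholder hash values):

    from typing import List, TypedDict, cast

    class TotalHashDict(TypedDict):  # total=True is the TypedDict default
        sha1: bytes
        sha256: bytes

    # Stand-in for swh's HASH_ALGORITHMS, reduced to two entries here.
    HASH_ALGORITHMS = ["sha1", "sha256"]

    def pk_values(content_hashes: TotalHashDict) -> List[bytes]:
        # `content_hashes[algo]` fails type checking because `algo` is not a
        # string literal; casting to a plain dict silences mypy and changes
        # nothing at runtime.
        return [cast(dict, content_hashes)[algo] for algo in HASH_ALGORITHMS]

    print(pk_values({"sha1": b"\x01" * 20, "sha256": b"\x02" * 32}))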
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -22,6 +22,7 @@
     Set,
     Tuple,
     Union,
+    cast,
 )

 import attr
@@ -53,11 +54,13 @@
 from swh.model.swhids import ObjectType as SwhidObjectType
 from swh.storage.interface import (
     VISIT_STATUSES,
+    HashDict,
     ListOrder,
     OriginVisitWithStatuses,
     PagedResult,
     PartialBranches,
     Sha1,
+    TotalHashDict,
 )
 from swh.storage.objstorage import ObjStorage
 from swh.storage.utils import map_optional, now
@@ -336,10 +339,10 @@
             contents_by_hash[key(content)] = content
         return [contents_by_hash.get(hash_) for hash_ in contents]

-    def content_find(self, content: Dict[str, Any]) -> List[Content]:
+    def content_find(self, content: HashDict) -> List[Content]:
         return self._content_find_many([content])

-    def _content_find_many(self, contents: List[Dict[str, Any]]) -> List[Content]:
+    def _content_find_many(self, contents: List[HashDict]) -> List[Content]:
         # Find an algorithm that is common to all the requested contents.
         # It will be used to do an initial filtering efficiently.
         # TODO: prioritize sha256, we can do more efficient lookups from this hash.
@@ -355,14 +358,16 @@
         results = []

         rows = self._content_get_from_hashes(
-            common_algo, [content[common_algo] for content in contents]
+            common_algo,
+            [content[common_algo] for content in cast(List[dict], contents)],
         )
         for row in rows:
             # Re-check all the hashes, in case of collisions (either of the
             # hash of the partition key, or the hashes in it)
             for content in contents:
                 for algo in HASH_ALGORITHMS:
-                    if content.get(algo) and getattr(row, algo) != content[algo]:
+                    hash_ = content.get(algo)
+                    if hash_ and getattr(row, algo) != hash_:
                         # This hash didn't match; discard the row.
                         break
                 else:
@@ -377,15 +382,15 @@
         return results

     def content_missing(
-        self, contents: List[Dict[str, Any]], key_hash: str = "sha1"
+        self, contents: List[HashDict], key_hash: str = "sha1"
     ) -> Iterable[bytes]:
         if key_hash not in DEFAULT_ALGORITHMS:
             raise StorageArgumentException(
                 "key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}"
             )

-        contents_with_all_hashes = []
-        contents_with_missing_hashes = []
+        contents_with_all_hashes: List[TotalHashDict] = []
+        contents_with_missing_hashes: List[HashDict] = []
         for content in contents:
             if DEFAULT_ALGORITHMS <= set(content):
                 contents_with_all_hashes.append(content)
@@ -396,7 +401,7 @@
         for content in self._cql_runner.content_missing_from_all_hashes(
             contents_with_all_hashes
         ):
-            yield content[key_hash]
+            yield content[key_hash]  # type: ignore

         if contents_with_missing_hashes:
             # For these, we need the expensive index lookups + main table.
@@ -426,7 +431,7 @@
                 # missing_content. (its length is at most 1, unless there is a
                 # collision)
                 found_contents_with_same_hash = found_contents_by_hash[algo][
-                    missing_content[algo]
+                    missing_content[algo]  # type: ignore
                 ]

                 # Check if there is a found_content that matches all hashes in the
@@ -444,7 +449,7 @@
                         break
                 else:
                     # Not found
-                    yield missing_content[key_hash]
+                    yield missing_content[key_hash]  # type: ignore

     def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]:
         return self.content_missing([{"sha1": c} for c in contents])
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -55,6 +55,17 @@
     the snapshot has less than the request number of branches."""


+class HashDict(TypedDict, total=False):
+    sha1: bytes
+    sha1_git: bytes
+    sha256: bytes
+    blake2s256: bytes
+
+
+class TotalHashDict(HashDict, total=True):
+    pass
+
+
 @attr.s
 class OriginVisitWithStatuses:
     visit = attr.ib(type=OriginVisit)
@@ -228,7 +239,7 @@
     @remote_api_endpoint("content/missing")
     def content_missing(
-        self, contents: List[Dict[str, Any]], key_hash: str = "sha1"
+        self, contents: List[HashDict], key_hash: str = "sha1"
     ) -> Iterable[bytes]:
         """List content missing from storage

@@ -280,7 +291,7 @@
         ...

     @remote_api_endpoint("content/present")
-    def content_find(self, content: Dict[str, Any]) -> List[Content]:
+    def content_find(self, content: HashDict) -> List[Content]:
         """Find a content hash in db.

         Args:
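These two TypedDicts are the core of the patch: `HashDict` (`total=False`) types lookups that may carry any subset of the four supported hash algorithms, and `TotalHashDict` is meant for entries known to carry all four. A short sketch of what the checker now accepts and rejects, reusing the `HashDict` definition from the hunk above with placeholder hash values:

    from typing import TypedDict

    class HashDict(TypedDict, total=False):
        sha1: bytes
        sha1_git: bytes
        sha256: bytes
        blake2s256: bytes

    # Any subset of the four known algorithms type-checks:
    query: HashDict = {"sha1_git": b"\x00" * 20}

    # A key outside the schema is now a type error; this is why the
    # migration tests further down must silence mypy on their `byte_xor`
    # lookups with `# type: ignore`.
    bad_query: HashDict = {"byte_xor": b"\x00" * 20}  # type: ignore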
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py
--- a/swh/storage/postgresql/storage.py
+++ b/swh/storage/postgresql/storage.py
@@ -48,6 +48,7 @@
 from swh.storage.exc import HashCollision, StorageArgumentException, StorageDBError
 from swh.storage.interface import (
     VISIT_STATUSES,
+    HashDict,
     ListOrder,
     OriginVisitWithStatuses,
     PagedResult,
@@ -402,7 +403,7 @@
     @db_transaction_generator()
     def content_missing(
         self,
-        contents: List[Dict[str, Any]],
+        contents: List[HashDict],
         key_hash: str = "sha1",
         *,
         db: Db,
@@ -434,9 +435,7 @@
             yield obj[0]

     @db_transaction()
-    def content_find(
-        self, content: Dict[str, Any], *, db: Db, cur=None
-    ) -> List[Content]:
+    def content_find(self, content: HashDict, *, db: Db, cur=None) -> List[Content]:
         if not set(content).intersection(DEFAULT_ALGORITHMS):
             raise StorageArgumentException(
                 "content keys must contain at least one "
diff --git a/swh/storage/proxies/filter.py b/swh/storage/proxies/filter.py
--- a/swh/storage/proxies/filter.py
+++ b/swh/storage/proxies/filter.py
@@ -4,7 +4,7 @@
 # See top-level LICENSE file for more information


-from typing import Dict, Iterable, List, Set
+from typing import Dict, Iterable, List, Set, cast

 from swh.model.model import (
     Content,
@@ -15,7 +15,7 @@
     SkippedContent,
 )
 from swh.storage import get_storage
-from swh.storage.interface import StorageInterface
+from swh.storage.interface import HashDict, StorageInterface


 class FilteringProxyStorage:
@@ -108,7 +108,7 @@
         """
         missing_contents = []
         for content in contents:
-            missing_contents.append(content.hashes())
+            missing_contents.append(cast(HashDict, content.hashes()))

         return set(
             self.storage.content_missing(
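The cast added to the filtering proxy is needed because `Content.hashes()` in swh.model is annotated as a plain `Dict[str, bytes]`, which is too loose for the now-typed `content_missing()`. A standalone sketch of the situation, with a hypothetical stub standing in for the real `Content` class (and an abridged two-key `HashDict`):

    from typing import Dict, TypedDict, cast

    class HashDict(TypedDict, total=False):  # abridged: two of the four keys
        sha1: bytes
        sha256: bytes

    def hashes() -> Dict[str, bytes]:
        # Stub for Content.hashes(), which returns the object's computed
        # hashes typed only as Dict[str, bytes].
        return {"sha1": b"\x01" * 20, "sha256": b"\x02" * 32}

    # cast() performs no runtime conversion or validation; it only tells the
    # type checker to treat the value as a HashDict.
    typed: HashDict = cast(HashDict, hashes())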
diff --git a/swh/storage/tests/test_cassandra_migration.py b/swh/storage/tests/test_cassandra_migration.py
--- a/swh/storage/tests/test_cassandra_migration.py
+++ b/swh/storage/tests/test_cassandra_migration.py
@@ -123,7 +123,7 @@
         attr.evolve(StorageData.content, data=None)
     ]
     with pytest.raises(StorageArgumentException):
-        swh_storage.content_find({"byte_xor": content_xor_hash})
+        swh_storage.content_find({"byte_xor": content_xor_hash})  # type: ignore

     # Then update the running code:
     new_hash_algos = HASH_ALGORITHMS + ["byte_xor"]
@@ -158,12 +158,12 @@
     ]

     # The new algo does not work, we did not backfill it yet:
-    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []
+    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []  # type: ignore

     # A normal storage would not overwrite, because the object already exists,
     # as it is not aware it is missing a field:
     swh_storage.content_add([new_content, new_content2])
-    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []
+    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []  # type: ignore

     # Backfill (in production this would be done with a replayer reading from
     # the journal):
@@ -173,7 +173,7 @@
     overwriting_swh_storage.content_add([new_content, new_content2])

     # Now, the object can be found:
-    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [
+    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [  # type: ignore
         attr.evolve(new_content, data=None)
     ]

@@ -270,7 +270,7 @@
         attr.evolve(StorageData.content, data=None)
     ]
     with pytest.raises(StorageArgumentException):
-        swh_storage.content_find({"byte_xor": content_xor_hash})
+        swh_storage.content_find({"byte_xor": content_xor_hash})  # type: ignore

     # Then update the running code:
     new_hash_algos = HASH_ALGORITHMS + ["byte_xor"]
@@ -319,7 +319,7 @@
     swh_storage._set_cql_runner()

     # Now, the object can be found with the new hash:
-    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [
+    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [  # type: ignore
         attr.evolve(new_content, data=None)
     ]

@@ -327,7 +327,7 @@
     swh_storage._cql_runner._session.execute("DROP TABLE content")

     # Object is still available, because we don't use it anymore
-    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [
+    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [  # type: ignore
         attr.evolve(new_content, data=None)
     ]
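From the caller's side, the retyping is what forces the `# type: ignore` comments on every `byte_xor` lookup in the tests above: that key is not part of `HashDict`. A sketch against the in-memory backend, assuming an empty storage (the hash value is an arbitrary placeholder):

    from swh.storage import get_storage
    from swh.storage.exc import StorageArgumentException

    storage = get_storage("memory")
    sha1 = bytes.fromhex("34973274ccef6ab4dfaaf86599792fa9c3fe4689")

    # Known algorithms type-check against HashDict; on an empty storage this
    # simply finds nothing:
    assert storage.content_find({"sha1": sha1}) == []

    # An out-of-schema key is rejected by mypy (hence the ignore) and,
    # independently of the type system, by runtime argument validation:
    try:
        storage.content_find({"byte_xor": sha1})  # type: ignore
    except StorageArgumentException:
        pass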