diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -9,7 +9,7 @@ import json import random import re -from typing import Any, Dict, List, Iterable, Optional, Tuple, Union +from typing import Any, Dict, List, Iterable, Optional, Sequence, Tuple, Union import attr @@ -215,14 +215,14 @@ result[content_metadata["sha1"]].append(content_metadata) return result - def content_find(self, content): + def content_find(self, content: Dict[str, Any]) -> Sequence[Content]: # Find an algorithm that is common to all the requested contents. # It will be used to do an initial filtering efficiently. filter_algos = list(set(content).intersection(HASH_ALGORITHMS)) if not filter_algos: raise StorageArgumentException( - "content keys must contain at least one of: " - "%s" % ", ".join(sorted(HASH_ALGORITHMS)) + "content keys must contain at least one " + f"of: {', '.join(sorted(HASH_ALGORITHMS))}" ) common_algo = filter_algos[0] @@ -237,12 +237,9 @@ break else: # All hashes match, keep this row. 
- results.append( - { - **row._asdict(), - "ctime": row.ctime.replace(tzinfo=datetime.timezone.utc), - } - ) + row_d = row._asdict() + row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc) + results.append(Content(**row_d)) return results def content_missing(self, content, key_hash="sha1"): @@ -250,8 +247,6 @@ res = self.content_find(cont) if not res: yield cont[key_hash] - if any(c["status"] == "missing" for c in res): - yield cont[key_hash] def content_missing_per_sha1(self, contents): return self.content_missing([{"sha1": c for c in contents}]) @@ -341,7 +336,7 @@ def directory_missing(self, directories): return self._cql_runner.directory_missing(directories) - def _join_dentry_to_content(self, dentry): + def _join_dentry_to_content(self, dentry: DirectoryEntry) -> Dict[str, Any]: keys = ( "status", "sha1", @@ -352,11 +347,11 @@ ret = dict.fromkeys(keys) ret.update(dentry.to_dict()) if ret["type"] == "file": - content = self.content_find({"sha1_git": ret["target"]}) - if content: - content = content[0] + contents = self.content_find({"sha1_git": ret["target"]}) + if contents: + content = contents[0] for key in keys: - ret[key] = content[key] + ret[key] = getattr(content, key) return ret def _directory_ls(self, directory_id, recursive, prefix=b""): diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -293,7 +293,12 @@ ] def content_find( - self, sha1=None, sha1_git=None, sha256=None, blake2s256=None, cur=None + self, + sha1: Optional[bytes] = None, + sha1_git: Optional[bytes] = None, + sha256: Optional[bytes] = None, + blake2s256: Optional[bytes] = None, + cur=None, ): """Find the content optionally on a combination of the following checksums sha1, sha1_git, sha256 or blake2s256. 
@@ -316,21 +321,21 @@ "sha256": sha256, "blake2s256": blake2s256, } + + query_parts = [ + f"SELECT {','.join(self.content_find_cols)} " "FROM content " "WHERE " + ] + query_params = [] where_parts = [] - args = [] - # Adds only those keys which have value other than None + # Add only those keys whose value is not None for algorithm in checksum_dict: if checksum_dict[algorithm] is not None: - args.append(checksum_dict[algorithm]) - where_parts.append(algorithm + "= %s") - query = " AND ".join(where_parts) - cur.execute( - """SELECT %s - FROM content WHERE %s - """ - % (",".join(self.content_find_cols), query), - args, - ) + where_parts.append(f"{algorithm} = %s") + query_params.append(checksum_dict[algorithm]) + + query_parts.append(" AND ".join(where_parts)) + query = "\n".join(query_parts) + cur.execute(query, query_params) content = cur.fetchall() return content diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -24,6 +24,7 @@ Iterator, List, Optional, + Sequence, Tuple, TypeVar, Union, @@ -321,11 +322,11 @@ result[sha1].append(d) return result - def content_find(self, content): + def content_find(self, content: Dict[str, Any]) -> Sequence[Content]: if not set(content).intersection(DEFAULT_ALGORITHMS): raise StorageArgumentException( - "content keys must contain at least one of: %s" - % ", ".join(sorted(DEFAULT_ALGORITHMS)) + "content keys must contain at least one " + f"of: {', '.join(sorted(DEFAULT_ALGORITHMS))}" ) found = [] for algo in DEFAULT_ALGORITHMS: @@ -337,7 +338,7 @@ return [] keys = list(set.intersection(*found)) - return [self._contents[key].to_dict() for key in keys] + return [self._contents[key] for key in keys] def content_missing(self, content, key_hash="sha1"): for cont in content: @@ -347,10 +348,6 @@ if hash_ not in self._content_indexes.get(algo, []): yield cont[key_hash] break - else: - for result in self.content_find(cont): - if result["status"] == "missing": - 
yield cont[key_hash] def content_missing_per_sha1(self, contents): for content in contents: @@ -418,7 +415,7 @@ if id not in self._directories: yield id - def _join_dentry_to_content(self, dentry): + def _join_dentry_to_content(self, dentry: Dict[str, Any]) -> Dict[str, Any]: keys = ( "status", "sha1", @@ -430,11 +427,11 @@ ret.update(dentry) if ret["type"] == "file": # TODO: Make it able to handle more than one content - content = self.content_find({"sha1_git": ret["target"]}) - if content: - content = content[0] + contents = self.content_find({"sha1_git": ret["target"]}) + if contents: + content = contents[0] for key in keys: - ret[key] = content[key] + ret[key] = getattr(content, key) return ret def _directory_ls(self, directory_id, recursive, prefix=b""): diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -6,7 +6,7 @@ import datetime from enum import Enum -from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union from swh.core.api import remote_api_endpoint @@ -291,7 +291,7 @@ ... @remote_api_endpoint("content/present") - def content_find(self, content): + def content_find(self, content: Dict[str, Any]) -> Sequence[Content]: """Find a content hash in db. Args: @@ -299,14 +299,14 @@ checksum algorithm names (see swh.model.hashutil.ALGORITHMS) to checksum values - Returns: - a triplet (sha1, sha1_git, sha256) if the content exist - or None otherwise. - Raises: ValueError: in case the key of the dictionary is not sha1, sha1_git nor sha256. + Returns: + an iterable of Content objects matching the search criteria if the + content exists. Empty iterable otherwise. + """ ... 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -11,11 +11,13 @@ from collections import defaultdict from contextlib import contextmanager from typing import ( + Any, Counter, Dict, Iterable, List, Optional, + Sequence, Tuple, Union, ) @@ -359,21 +361,27 @@ @timed @db_transaction() - def content_find(self, content, db=None, cur=None): + def content_find( + self, content: Dict[str, Any], db=None, cur=None + ) -> Sequence[Content]: if not set(content).intersection(DEFAULT_ALGORITHMS): raise StorageArgumentException( - "content keys must contain at least one of: " - "sha1, sha1_git, sha256, blake2s256" + "content keys must contain at least one " + f"of: {', '.join(sorted(DEFAULT_ALGORITHMS))}" ) - contents = db.content_find( + rows = db.content_find( sha1=content.get("sha1"), sha1_git=content.get("sha1_git"), sha256=content.get("sha256"), blake2s256=content.get("blake2s256"), cur=cur, ) - return [dict(zip(db.content_find_cols, content)) for content in contents] + contents = [] + for row in rows: + row_d = dict(zip(db.content_find_cols, row)) + contents.append(Content(**row_d)) + return contents @timed @db_transaction() diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py --- a/swh/storage/tests/test_cassandra.py +++ b/swh/storage/tests/test_cassandra.py @@ -313,14 +313,14 @@ swh_storage._cql_runner, "content_get_from_token", mock_cgft ) - expected_cont = attr.evolve(cont, data=None).to_dict() + expected_content = attr.evolve(cont, data=None) actual_result = swh_storage.content_find({"sha1": cont.sha1}) assert called == 2 # but cont2 should be filtered out - assert actual_result == [expected_cont] + assert actual_result == [expected_content] @pytest.mark.skip("content_update is not yet implemented for Cassandra") def test_content_update(self): diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- 
a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -3049,45 +3049,38 @@ swh_storage.content_add_metadata([content]) actually_present = swh_storage.content_find({"sha1": content.sha1}) - assert actually_present[0] == content.to_dict() + assert actually_present[0] == content def test_content_find_with_present_content(self, swh_storage, sample_data): content = sample_data.content - expected_content = content.to_dict() - del expected_content["data"] - del expected_content["ctime"] + expected_content = attr.evolve(content, data=None) # 1. with something to find swh_storage.content_add([content]) actually_present = swh_storage.content_find({"sha1": content.sha1}) assert 1 == len(actually_present) - actually_present[0].pop("ctime") assert actually_present[0] == expected_content # 2. with something to find actually_present = swh_storage.content_find({"sha1_git": content.sha1_git}) assert 1 == len(actually_present) - actually_present[0].pop("ctime") assert actually_present[0] == expected_content # 3. with something to find actually_present = swh_storage.content_find({"sha256": content.sha256}) assert 1 == len(actually_present) - actually_present[0].pop("ctime") assert actually_present[0] == expected_content # 4. with something to find actually_present = swh_storage.content_find(content.hashes()) assert 1 == len(actually_present) - actually_present[0].pop("ctime") assert actually_present[0] == expected_content def test_content_find_with_non_present_content(self, swh_storage, sample_data): missing_content = sample_data.skipped_content # 1. with something that does not exist actually_present = swh_storage.content_find({"sha1": missing_content.sha1}) - assert actually_present == [] # 2. 
with something that does not exist @@ -3115,30 +3108,18 @@ # Inject the data swh_storage.content_add([content, duplicated_content]) - actual_result = list( - swh_storage.content_find( - { - "blake2s256": duplicated_content.blake2s256, - "sha256": duplicated_content.sha256, - } - ) + actual_result = swh_storage.content_find( + { + "blake2s256": duplicated_content.blake2s256, + "sha256": duplicated_content.sha256, + } ) - expected_content = content.to_dict() - expected_duplicated_content = duplicated_content.to_dict() + expected_content = attr.evolve(content, data=None) + expected_duplicated_content = attr.evolve(duplicated_content, data=None) - for key in ["data", "ctime"]: # so we can compare - for dict_ in [ - expected_content, - expected_duplicated_content, - actual_result[0], - actual_result[1], - ]: - dict_.pop(key, None) - - expected_result = [expected_content, expected_duplicated_content] - for result in expected_result: - assert result in actual_result + for result in actual_result: + assert result in [expected_content, expected_duplicated_content] def test_content_find_with_duplicate_sha256(self, swh_storage, sample_data): content = sample_data.content @@ -3158,42 +3139,24 @@ ) swh_storage.content_add([content, duplicated_content]) - actual_result = list( - swh_storage.content_find({"sha256": duplicated_content.sha256}) - ) - + actual_result = swh_storage.content_find({"sha256": duplicated_content.sha256}) assert len(actual_result) == 2 - expected_content = content.to_dict() - expected_duplicated_content = duplicated_content.to_dict() - - for key in ["data", "ctime"]: # so we can compare - for dict_ in [ - expected_content, - expected_duplicated_content, - actual_result[0], - actual_result[1], - ]: - dict_.pop(key, None) - - assert sorted(actual_result, key=lambda x: x["sha1"]) == [ - expected_content, - expected_duplicated_content, - ] + expected_content = attr.evolve(content, data=None) + expected_duplicated_content = attr.evolve(duplicated_content, 
data=None) + + for result in actual_result: + assert result in [expected_content, expected_duplicated_content] # Find with both sha256 and blake2s256 - actual_result = list( - swh_storage.content_find( - { - "sha256": duplicated_content.sha256, - "blake2s256": duplicated_content.blake2s256, - } - ) + actual_result = swh_storage.content_find( + { + "sha256": duplicated_content.sha256, + "blake2s256": duplicated_content.blake2s256, + } ) assert len(actual_result) == 1 - actual_result[0].pop("ctime") - assert actual_result == [expected_duplicated_content] def test_content_find_with_duplicate_blake2s256(self, swh_storage, sample_data): @@ -3216,45 +3179,32 @@ swh_storage.content_add([content, duplicated_content]) - actual_result = list( - swh_storage.content_find({"blake2s256": duplicated_content.blake2s256}) + actual_result = swh_storage.content_find( + {"blake2s256": duplicated_content.blake2s256} ) - expected_content = content.to_dict() - expected_duplicated_content = duplicated_content.to_dict() + expected_content = attr.evolve(content, data=None) + expected_duplicated_content = attr.evolve(duplicated_content, data=None) - for key in ["data", "ctime"]: # so we can compare - for dict_ in [ - expected_content, - expected_duplicated_content, - actual_result[0], - actual_result[1], - ]: - dict_.pop(key, None) - - expected_result = [expected_content, expected_duplicated_content] - for result in expected_result: - assert result in actual_result + for result in actual_result: + assert result in [expected_content, expected_duplicated_content] # Find with both sha256 and blake2s256 - actual_result = list( - swh_storage.content_find( - { - "sha256": duplicated_content.sha256, - "blake2s256": duplicated_content.blake2s256, - } - ) + actual_result = swh_storage.content_find( + { + "sha256": duplicated_content.sha256, + "blake2s256": duplicated_content.blake2s256, + } ) - actual_result[0].pop("ctime") assert actual_result == [expected_duplicated_content] def 
test_content_find_bad_input(self, swh_storage): - # 1. with bad input + # 1. with no hash to lookup with pytest.raises(StorageArgumentException): - swh_storage.content_find({}) # empty is bad + swh_storage.content_find({}) # need at least one hash - # 2. with bad input + # 2. with bad hash with pytest.raises(StorageArgumentException): swh_storage.content_find({"unknown-sha1": "something"}) # not the right key