diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -6,6 +6,7 @@
 import base64
 import datetime
 import itertools
+import operator
 import random
 import re
 from typing import (
@@ -293,17 +294,25 @@
         assert len(contents) <= limit
         return PagedResult(results=contents, next_page_token=next_page_token)
 
-    def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]:
-        contents_by_sha1: Dict[Sha1, Optional[Content]] = {}
-        for sha1 in contents:
-            # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1
-            # matches the argument, from the index table ('content_by_sha1')
-            for row in self._content_get_from_hash("sha1", sha1):
+    def content_get(
+        self, contents: List[bytes], algo: str = "sha1"
+    ) -> List[Optional[Content]]:
+        if algo not in DEFAULT_ALGORITHMS:
+            raise StorageArgumentException(
+                f"algo should be one of {','.join(DEFAULT_ALGORITHMS)}"
+            )
+
+        key = operator.attrgetter(algo)
+        contents_by_hash: Dict[bytes, Optional[Content]] = {}
+        for hash_ in contents:
+            # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1/sha1_git
+            # matches the argument, from the index table ('content_by_*')
+            for row in self._content_get_from_hash(algo, hash_):
                 row_d = row.to_dict()
                 row_d.pop("ctime")
                 content = Content(**row_d)
-                contents_by_sha1[content.sha1] = content
-        return [contents_by_sha1.get(sha1) for sha1 in contents]
+                contents_by_hash[key(content)] = content
+        return [contents_by_hash.get(hash_) for hash_ in contents]
 
     def content_find(self, content: Dict[str, Any]) -> List[Content]:
         # Find an algorithm that is common to all the requested contents.
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -203,11 +203,15 @@
         ...
 
@remote_api_endpoint("content/metadata") - def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]: + def content_get( + self, contents: List[bytes], algo: str = "sha1" + ) -> List[Optional[Content]]: """Retrieve content metadata in bulk Args: content: List of content identifiers + algo: one of the checksum algorithm in + :data:`swh.model.hashutil.DEFAULT_ALGORITHMS` Returns: List of contents model objects when they exist, None otherwise. diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py --- a/swh/storage/postgresql/db.py +++ b/swh/storage/postgresql/db.py @@ -13,6 +13,7 @@ from swh.core.db.db_utils import execute_values_generator from swh.core.db.db_utils import jsonize as _jsonize from swh.core.db.db_utils import stored_procedure +from swh.model.hashutil import DEFAULT_ALGORITHMS from swh.model.identifiers import ObjectType from swh.model.model import SHA1_SIZE, OriginVisit, OriginVisitStatus, Sha1Git from swh.storage.interface import ListOrder @@ -113,16 +114,18 @@ "origin", ] - def content_get_metadata_from_sha1s(self, sha1s, cur=None): + def content_get_metadata_from_hashes( + self, hashes: List[bytes], algo: str, cur=None + ): cur = self._cursor(cur) + assert algo in DEFAULT_ALGORITHMS + query = f""" + select {", ".join(self.content_get_metadata_keys)} + from (values %s) as t (hash) + inner join content on (content.{algo}=hash) + """ yield from execute_values_generator( - cur, - """ - select t.sha1, %s from (values %%s) as t (sha1) - inner join content using (sha1) - """ - % ", ".join(self.content_get_metadata_keys[1:]), - ((sha1,) for sha1 in sha1s), + cur, query, ((hash_,) for hash_ in hashes), ) def content_get_range(self, start, end, limit=None, cur=None): diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -9,6 +9,7 @@ from contextlib import contextmanager import datetime import itertools +import operator 
from typing import Any, Counter, Dict, Iterable, List, Optional, Sequence, Tuple import attr @@ -319,15 +320,23 @@ @timed @db_transaction(statement_timeout=500) def content_get( - self, contents: List[Sha1], db=None, cur=None + self, contents: List[bytes], algo: str = "sha1", db=None, cur=None ) -> List[Optional[Content]]: - contents_by_sha1: Dict[Sha1, Optional[Content]] = {} - for row in db.content_get_metadata_from_sha1s(contents, cur): + contents_by_hash: Dict[bytes, Optional[Content]] = {} + if algo not in DEFAULT_ALGORITHMS: + raise StorageArgumentException( + "algo should be one of {','.join(DEFAULT_ALGORITHMS)}" + ) + + rows = db.content_get_metadata_from_hashes(contents, algo, cur) + key = operator.attrgetter(algo) + + for row in rows: row_d = dict(zip(db.content_get_metadata_keys, row)) content = Content(**row_d) - contents_by_sha1[content.sha1] = content + contents_by_hash[key(content)] = content - return [contents_by_sha1.get(sha1) for sha1 in contents] + return [contents_by_hash.get(sha1) for sha1 in contents] @timed @db_transaction_generator() diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -19,7 +19,7 @@ from swh.core.api.classes import stream_results from swh.model import from_disk -from swh.model.hashutil import hash_to_bytes +from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes from swh.model.hypothesis_strategies import objects from swh.model.identifiers import CoreSWHID, ObjectType from swh.model.model import ( @@ -635,12 +635,15 @@ for content in actual_contents: assert content in expected_contents - def test_content_get(self, swh_storage, sample_data): + @pytest.mark.parametrize("algo", DEFAULT_ALGORITHMS) + def test_content_get(self, swh_storage, sample_data, algo): cont1, cont2 = sample_data.contents[:2] swh_storage.content_add([cont1, cont2]) - actual_contents = swh_storage.content_get([cont1.sha1, 
cont2.sha1]) + actual_contents = swh_storage.content_get( + [getattr(cont1, algo), getattr(cont2, algo)], algo + ) # we only retrieve the metadata so no data nor ctime within expected_contents = [attr.evolve(c, data=None) for c in [cont1, cont2]] @@ -649,7 +652,8 @@ for content in actual_contents: assert content.ctime is None - def test_content_get_missing_sha1(self, swh_storage, sample_data): + @pytest.mark.parametrize("algo", DEFAULT_ALGORITHMS) + def test_content_get_missing(self, swh_storage, sample_data, algo): cont1, cont2 = sample_data.contents[:2] assert cont1.sha1 != cont2.sha1 missing_cont = sample_data.skipped_content @@ -657,7 +661,8 @@ swh_storage.content_add([cont1, cont2]) actual_contents = swh_storage.content_get( - [cont1.sha1, cont2.sha1, missing_cont.sha1] + [getattr(cont1, algo), getattr(cont2, algo), getattr(missing_cont, algo)], + algo, ) expected_contents = [