diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -6,6 +6,7 @@
 import base64
 import datetime
 import itertools
+import operator
 import random
 import re
 from typing import (
@@ -293,17 +294,26 @@
         assert len(contents) <= limit
         return PagedResult(results=contents, next_page_token=next_page_token)
 
-    def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]:
-        contents_by_sha1: Dict[Sha1, Optional[Content]] = {}
-        for sha1 in contents:
-            # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1
-            # matches the argument, from the index table ('content_by_sha1')
-            for row in self._content_get_from_hash("sha1", sha1):
+    def content_get(
+        self, contents: List[bytes], algo: str = "sha1"
+    ) -> List[Optional[Content]]:
+        # ``algo`` selects both the index table to query and the Content
+        # attribute used as the result key, so reject unsupported values early.
+        if algo not in ("sha1", "sha1_git"):
+            raise StorageArgumentException(
+                f"Invalid algo: {algo} (expected only 'sha1' or 'sha1_git')"
+            )
+        key = operator.attrgetter(algo)
+        contents_by_hash: Dict[bytes, Optional[Content]] = {}
+        for hash_ in contents:
+            # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1/sha1_git
+            # matches the argument, from the index table ('content_by_*')
+            for row in self._content_get_from_hash(algo, hash_):
                 row_d = row.to_dict()
                 row_d.pop("ctime")
                 content = Content(**row_d)
-                contents_by_sha1[content.sha1] = content
-        return [contents_by_sha1.get(sha1) for sha1 in contents]
+                contents_by_hash[key(content)] = content
+        return [contents_by_hash.get(hash_) for hash_ in contents]
 
     def content_find(self, content: Dict[str, Any]) -> List[Content]:
         # Find an algorithm that is common to all the requested contents.
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -203,11 +203,14 @@
         ...
 
     @remote_api_endpoint("content/metadata")
-    def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]:
+    def content_get(
+        self, contents: List[bytes], algo: str = "sha1"
+    ) -> List[Optional[Content]]:
         """Retrieve content metadata in bulk
 
         Args:
-            content: List of content identifiers
+            contents: List of content identifiers
+            algo: "sha1" or "sha1_git", the hash index used to look up the given hashes
 
         Returns:
             List of contents model objects when they exist, None otherwise.
diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py
--- a/swh/storage/postgresql/db.py
+++ b/swh/storage/postgresql/db.py
@@ -125,6 +125,23 @@
             ((sha1,) for sha1 in sha1s),
         )
 
+    def content_get_metadata_from_sha1_gits(self, sha1_gits, cur=None):
+        """Yield the metadata rows of the contents matching the given sha1_gits.
+
+        Columns follow the order of ``content_get_metadata_keys``; because of
+        the inner join, a sha1_git with no matching content yields no row.
+        """
+        cur = self._cursor(cur)
+        yield from execute_values_generator(
+            cur,
+            """
+            select %s from (values %%s) as t (sha1_git)
+            inner join content using (sha1_git)
+            """
+            % ", ".join(self.content_get_metadata_keys),
+            ((sha1_git,) for sha1_git in sha1_gits),
+        )
+
     def content_get_range(self, start, end, limit=None, cur=None):
         """Retrieve contents within range [start, end].
 
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py
--- a/swh/storage/postgresql/storage.py
+++ b/swh/storage/postgresql/storage.py
@@ -9,6 +9,7 @@
 from contextlib import contextmanager
 import datetime
 import itertools
+import operator
 from typing import Any, Counter, Dict, Iterable, List, Optional, Sequence, Tuple
 
 import attr
@@ -319,15 +320,29 @@
     @timed
     @db_transaction(statement_timeout=500)
     def content_get(
-        self, contents: List[Sha1], db=None, cur=None
+        self, contents: List[bytes], algo: str = "sha1", db=None, cur=None
     ) -> List[Optional[Content]]:
-        contents_by_sha1: Dict[Sha1, Optional[Content]] = {}
-        for row in db.content_get_metadata_from_sha1s(contents, cur):
+        contents_by_hash: Dict[bytes, Optional[Content]] = {}
+        # Pick the bulk query matching the requested hash index.
+        if algo == "sha1":
+            rows = db.content_get_metadata_from_sha1s(contents, cur)
+        elif algo == "sha1_git":
+            rows = db.content_get_metadata_from_sha1_gits(contents, cur)
+        else:
+            raise StorageArgumentException(
+                f"Invalid algo: {algo} (expected only 'sha1' or 'sha1_git')"
+            )
+
+        # Index results by the hash that was queried, so they can be returned
+        # in input order, with None for hashes that matched nothing.
+        key = operator.attrgetter(algo)
+
+        for row in rows:
             row_d = dict(zip(db.content_get_metadata_keys, row))
             content = Content(**row_d)
-            contents_by_sha1[content.sha1] = content
+            contents_by_hash[key(content)] = content
 
-        return [contents_by_sha1.get(sha1) for sha1 in contents]
+        return [contents_by_hash.get(hash_) for hash_ in contents]
 
     @timed
     @db_transaction_generator()
diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py
--- a/swh/storage/tests/storage_tests.py
+++ b/swh/storage/tests/storage_tests.py
@@ -665,6 +665,22 @@
         ]
         assert actual_contents == expected_contents
 
+    def test_content_get__sha1_git(self, swh_storage, sample_data):
+        cont1, cont2 = sample_data.contents[:2]
+
+        swh_storage.content_add([cont1, cont2])
+
+        actual_contents = swh_storage.content_get(
+            [cont1.sha1_git, cont2.sha1_git], algo="sha1_git"
+        )
+
+        # we only retrieve the metadata so no data nor ctime within
+        expected_contents = [attr.evolve(c, data=None) for c in [cont1, cont2]]
+
+        assert actual_contents == expected_contents
+        for content in actual_contents:
+            assert content.ctime is None
+
     def test_content_get_random(self, swh_storage, sample_data):
         cont, cont2, cont3 = sample_data.contents[:3]
         swh_storage.content_add([cont, cont2, cont3])