storage: content_get_metadata returns a mapping sha1 -> list of metadata

content_get_metadata now returns a plain dict keyed by the requested sha1s.
Each value is a list holding the metadata dict(s) of the matching content,
or an empty list when no content with that sha1 exists. This applies to
both the PostgreSQL-backed and the in-memory storage. The db helper yields
(sha1, row) pairs, where row is None for an unknown sha1.

diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -75,13 +75,17 @@
 
     def content_get_metadata_from_sha1s(self, sha1s, cur=None):
         cur = self._cursor(cur)
-        yield from execute_values_generator(
+        for data in execute_values_generator(
             cur, """
             select t.sha1, %s from (values %%s) as t (sha1)
             left join content using (sha1)
             """ % ', '.join(self.content_get_metadata_keys[1:]),
             ((sha1,) for sha1 in sha1s),
-        )
+        ):
+            if set(data[1:]) == {None}:  # nonexistent content
+                yield data[0], None
+            else:
+                yield data[0], data
 
     def content_get_range(self, start, end, limit=None, cur=None):
         """Retrieve contents within range [start, end].
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2019 The Software Heritage developers
+# Copyright (C) 2015-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -14,7 +14,7 @@
 from collections import defaultdict
 from datetime import timedelta
 
-from typing import Any, Dict, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional
 
 import attr
 
@@ -347,18 +347,21 @@
             result2['next_page_token'] = hash_to_hex(result['next'])
         return result2
 
-    def content_get_metadata(self, content):
+    def content_get_metadata(
+            self, contents: List[bytes]) -> Dict[bytes, List[Dict]]:
         """Retrieve content metadata in bulk
 
         Args:
-            content: iterable of content identifiers (sha1)
+            contents: iterable of content identifiers (sha1)
 
         Returns:
-            an iterable with content metadata corresponding to the given ids
+            a dict with keys the content's sha1 and the associated value
+            either a list with the existing content's metadata, or an empty
+            list if the content does not exist.
+
         """
-        # FIXME: the return value should be a mapping from search key to found
-        # content*s*
-        for sha1 in content:
+        result: Dict[bytes, List[Dict]] = defaultdict(list)
+        for sha1 in contents:
             if sha1 in self._content_indexes['sha1']:
                 objs = self._content_indexes['sha1'][sha1]
                 # FIXME: rather than selecting one of the objects with that
@@ -367,17 +370,10 @@
                 key = random.sample(objs, 1)[0]
                 d = self._contents[key].to_dict()
                 del d['ctime']
-                yield d
+                result[sha1].append(d)
             else:
-                # FIXME: should really be None
-                yield {
-                    'sha1': sha1,
-                    'sha1_git': None,
-                    'sha256': None,
-                    'blake2s256': None,
-                    'length': None,
-                    'status': None,
-                }
+                result[sha1] = []
+        return dict(result)
 
     def content_find(self, content):
         if not set(content).intersection(DEFAULT_ALGORITHMS):
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2019 The Software Heritage developers
+# Copyright (C) 2015-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -11,7 +11,7 @@
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
-from typing import Any, Dict, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional
 
 import dateutil.parser
 import psycopg2
@@ -537,18 +537,29 @@
 
     @remote_api_endpoint('content/metadata')
     @timed
-    @db_transaction_generator(statement_timeout=500)
-    def content_get_metadata(self, content, db=None, cur=None):
+    @db_transaction(statement_timeout=500)
+    def content_get_metadata(
+            self, contents: List[bytes],
+            db=None, cur=None) -> Dict[bytes, List[Dict]]:
         """Retrieve content metadata in bulk
 
         Args:
-            content: iterable of content identifiers (sha1)
+            contents: iterable of content identifiers (sha1)
 
         Returns:
-            an iterable with content metadata corresponding to the given ids
+            a dict with keys the content's sha1 and the associated value
+            either a list with the existing content's metadata, or an empty
+            list if the content does not exist.
+
         """
-        for metadata in db.content_get_metadata_from_sha1s(content, cur):
-            yield dict(zip(db.content_get_metadata_keys, metadata))
+        result: Dict[bytes, List[Dict]] = defaultdict(list)
+        for sha1, row in db.content_get_metadata_from_sha1s(contents, cur):
+            if row:
+                content_meta = dict(zip(db.content_get_metadata_keys, row))
+                result[sha1].append(content_meta)
+            else:
+                result[sha1] = []
+        return dict(result)
 
     @remote_api_endpoint('content/missing')
     @timed
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -236,8 +236,9 @@
         expected_cont = cont.copy()
         del expected_cont['ctime']
 
-        assert list(swh_storage.content_get_metadata([cont['sha1']])) == \
-            [expected_cont]
+        assert swh_storage.content_get_metadata([cont['sha1']]) == {
+            cont['sha1']: [expected_cont]
+        }
 
         assert list(swh_storage.journal_writer.objects) == [('content', cont)]
 
@@ -446,14 +447,16 @@
 
         swh_storage.content_add([cont1, cont2])
 
-        actual_md = list(swh_storage.content_get_metadata(
-            [cont1['sha1'], cont2['sha1']]))
+        actual_md = swh_storage.content_get_metadata(
+            [cont1['sha1'], cont2['sha1']])
 
         # we only retrieve the metadata
         cont1.pop('data')
         cont2.pop('data')
 
-        assert actual_md in ([cont1, cont2], [cont2, cont1])
+        assert actual_md[cont1['sha1']] == [cont1]
+        assert actual_md[cont2['sha1']] == [cont2]
+        assert len(actual_md) == 2
 
     def test_content_get_metadata_missing_sha1(self, swh_storage):
         cont1 = data.cont
@@ -462,15 +465,10 @@
 
         swh_storage.content_add([cont1, cont2])
 
-        gen = swh_storage.content_get_metadata([missing_cont['sha1']])
-
-        # All the metadata keys are None
-        missing_cont.pop('data')
-        for key in missing_cont:
-            if key != 'sha1':
-                missing_cont[key] = None
+        actual_contents = swh_storage.content_get_metadata(
+            [missing_cont['sha1']])
 
-        assert list(gen) == [missing_cont]
+        assert actual_contents == {missing_cont['sha1']: []}
 
     def test_content_get_random(self, swh_storage):
         swh_storage.content_add([data.cont, data.cont2, data.cont3])
@@ -3163,12 +3161,17 @@
         get_sha1s = [c['sha1'] for c in expected_contents]
 
         # retrieve contents
-        actual_contents = list(swh_storage.content_get_metadata(get_sha1s))
+        meta_contents = swh_storage.content_get_metadata(get_sha1s)
 
-        assert len(actual_contents) == len(get_sha1s)
+        assert len(meta_contents) == len(get_sha1s)
+
+        actual_contents = []
+        for contents in meta_contents.values():
+            actual_contents.extend(contents)
 
         keys_to_check = {'length', 'status', 'sha1', 'sha1_git',
                          'sha256', 'blake2s256'}
+
         assert_contents_ok(expected_contents, actual_contents,
                            keys_to_check=keys_to_check)
 