diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -139,7 +139,7 @@ self._consistency_level = consistency_level self._set_cql_runner() self.journal_writer: JournalWriter = JournalWriter(journal_writer) - self.objstorage: ObjStorage = ObjStorage(objstorage) + self.objstorage: ObjStorage = ObjStorage(self, objstorage) self._allow_overwrite = allow_overwrite if directory_entries_insert_algo not in DIRECTORY_ENTRIES_INSERT_ALGOS: @@ -275,7 +275,7 @@ def content_add_metadata(self, content: List[Content]) -> Dict[str, int]: return self._content_add(content, with_data=False) - def content_get_data(self, content: Sha1) -> Optional[bytes]: + def content_get_data(self, content: Union[Sha1, HashDict]) -> Optional[bytes]: # FIXME: Make this method support slicing the `data` return self.objstorage.content_get(content) diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -836,7 +836,7 @@ def reset(self): self._cql_runner = InMemoryCqlRunner() - self.objstorage = ObjStorage({"cls": "memory"}) + self.objstorage = ObjStorage(self, {"cls": "memory"}) def check_config(self, *, check_write: bool) -> bool: return True diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -5,7 +5,7 @@ import datetime from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union import attr from typing_extensions import Protocol, TypedDict, runtime_checkable @@ -181,11 +181,11 @@ ... @remote_api_endpoint("content/data") - def content_get_data(self, content: Sha1) -> Optional[bytes]: + def content_get_data(self, content: Union[HashDict, Sha1]) -> Optional[bytes]: """Given a content identifier, returns its associated data if any. Args: - content: sha1 identifier + content: dict of hashes (or just sha1 identifier) Returns: raw content data (bytes) diff --git a/swh/storage/objstorage.py b/swh/storage/objstorage.py --- a/swh/storage/objstorage.py +++ b/swh/storage/objstorage.py @@ -1,14 +1,15 @@ -# Copyright (C) 2020 The Software Heritage developers +# Copyright (C) 2020-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Dict, Iterable, Optional +from typing import Dict, Iterable, Optional, Union, cast +from swh.model.hashutil import DEFAULT_ALGORITHMS from swh.model.model import Content, MissingData from swh.objstorage.exc import ObjNotFoundError from swh.objstorage.factory import get_objstorage -from swh.storage.interface import Sha1 +from swh.storage.interface import HashDict, Sha1 from .exc import StorageArgumentException @@ -19,7 +20,8 @@ """ - def __init__(self, objstorage_config: Dict): + def __init__(self, storage, objstorage_config: Dict): + self.storage = storage self.objstorage = get_objstorage(**objstorage_config) def __getattr__(self, key): @@ -27,7 +29,7 @@ raise AttributeError(key) return getattr(self.objstorage, key) - def content_get(self, obj_id: Sha1) -> Optional[bytes]: + def content_get(self, obj_id: Union[Sha1, HashDict]) -> Optional[bytes]: """Retrieve data associated to the content from the objstorage Args: @@ -37,8 +39,26 @@ associated content's data if any, None otherwise. """ + hashes: HashDict + if isinstance(obj_id, bytes): + hashes = {"sha1": obj_id} + else: + hashes = obj_id + if set(hashes) < DEFAULT_ALGORITHMS: + # If some hashes are missing, query the database to fill blanks + candidates = self.storage.content_find(hashes) + if candidates: + # There may be more than one in case of collision; but we cannot + # do anything about it here + hashes = cast(HashDict, candidates[0].hashes()) + else: + # we will pass the partial hash dict to the objstorage, which + # will do the best it can with it. Usually, this will return None, + # as objects missing from the storage DB are unlikely to be present in the + # objstorage + pass try: - data = self.objstorage.get(obj_id) + data = self.objstorage.get(hashes) except ObjNotFoundError: data = None diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -11,7 +11,7 @@ import itertools import logging import operator -from typing import Any, Counter, Dict, Iterable, List, Optional, Sequence, Tuple +from typing import Any, Counter, Dict, Iterable, List, Optional, Sequence, Tuple, Union import attr import psycopg2 @@ -163,7 +163,7 @@ except psycopg2.OperationalError as e: raise StorageDBError(e) self.journal_writer = JournalWriter(journal_writer) - self.objstorage = ObjStorage(objstorage) + self.objstorage = ObjStorage(self, objstorage) self.query_options = query_options self._flavor: Optional[str] = None @@ -341,7 +341,7 @@ "content:add": len(contents), } - def content_get_data(self, content: Sha1) -> Optional[bytes]: + def content_get_data(self, content: Union[HashDict, Sha1]) -> Optional[bytes]: # FIXME: Make this method support slicing the `data` return self.objstorage.content_get(content) diff --git a/swh/storage/pytest_plugin.py b/swh/storage/pytest_plugin.py --- a/swh/storage/pytest_plugin.py +++ b/swh/storage/pytest_plugin.py @@ -48,10 +48,21 @@ @pytest.fixture -def swh_storage(swh_storage_backend_config): +def swh_storage_backend(swh_storage_backend_config): + """ + By default, this fixture aliases ``swh_storage``. However, when ``swh_storage`` + is overridden to be a proxy storage, this fixture returns the storage instance + behind all proxies. + + This is useful to introspect the state of backends from proxy tests""" return get_storage(**swh_storage_backend_config) +@pytest.fixture +def swh_storage(swh_storage_backend): + return swh_storage_backend + + @pytest.fixture def sample_data() -> StorageData: """Pre-defined sample storage object data to manipulate diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -238,6 +238,58 @@ swh_storage.refresh_stat_counters() assert swh_storage.stat_counters()["content"] == 1 + @pytest.mark.parametrize("algo", sorted(DEFAULT_ALGORITHMS)) + def test_content_get_data_single_hash_dict( + self, swh_storage, swh_storage_backend, sample_data, mocker, algo + ): + cont = sample_data.content + + swh_storage.content_add([cont]) + + content_find = mocker.patch.object( + swh_storage_backend, "content_find", wraps=swh_storage_backend.content_find + ) + assert swh_storage.content_get_data({algo: cont.get_hash(algo)}) == cont.data + + assert len(content_find.mock_calls) == 1 + + def test_content_get_data_two_hash_dict( + self, swh_storage, swh_storage_backend, sample_data, mocker + ): + cont = sample_data.content + + swh_storage.content_add([cont]) + + content_find = mocker.patch.object( + swh_storage_backend, "content_find", wraps=swh_storage_backend.content_find + ) + + combinations = list(itertools.combinations(sorted(DEFAULT_ALGORITHMS), 2)) + for (algo1, algo2) in combinations: + assert ( + swh_storage.content_get_data( + {algo1: cont.get_hash(algo1), algo2: cont.get_hash(algo2)} + ) + == cont.data + ) + assert len(content_find.mock_calls) == len(combinations) + + def test_content_get_data_full_dict( + self, swh_storage, swh_storage_backend, sample_data, mocker + ): + cont = sample_data.content + + swh_storage.content_add([cont]) + + content_find = mocker.patch.object( + swh_storage_backend, "content_find", wraps=swh_storage_backend.content_find + ) + assert swh_storage.content_get_data(cont.hashes()) == cont.data + assert len(content_find.mock_calls) == 0, ( + "content_get_data() needlessly called content_find(), " + "as all hashes were provided as argument" + ) + def test_content_get_data_missing(self, swh_storage, sample_data): cont, cont2 = sample_data.contents[:2] diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py --- a/swh/storage/tests/test_api_client.py +++ b/swh/storage/tests/test_api_client.py @@ -63,6 +63,11 @@ storage.journal_writer = journal_writer +@pytest.fixture +def swh_storage_backend(app_server, swh_storage): + return app_server.storage + + class TestStorageApi(_TestStorage): @pytest.mark.skip( 'The "person" table of the pgsql is a legacy thing, and not ' diff --git a/swh/storage/tests/test_tenacious.py b/swh/storage/tests/test_tenacious.py --- a/swh/storage/tests/test_tenacious.py +++ b/swh/storage/tests/test_tenacious.py @@ -36,7 +36,7 @@ @pytest.fixture -def swh_storage_backend_config2(): +def swh_storage_backend_config(): yield { "cls": "memory", "journal_writer": { @@ -46,21 +46,17 @@ @pytest.fixture -def swh_storage(): +def swh_storage(swh_storage_backend, swh_storage_backend_config): storage_config = { "cls": "pipeline", "steps": [ {"cls": "tenacious"}, - { - "cls": "memory", - "journal_writer": { - "cls": "memory", - }, - }, + swh_storage_backend_config, ], } storage = get_storage(**storage_config) + storage.storage = swh_storage_backend # use the same instance of the in-mem backend storage.journal_writer = storage.storage.journal_writer return storage