diff --git a/swh/storage/__init__.py b/swh/storage/__init__.py
--- a/swh/storage/__init__.py
+++ b/swh/storage/__init__.py
@@ -22,6 +22,7 @@
     "buffer": ".proxies.buffer.BufferingProxyStorage",
     "counter": ".proxies.counter.CountingProxyStorage",
     "filter": ".proxies.filter.FilteringProxyStorage",
+    "overlay": ".proxies.overlay.OverlayProxyStorage",
     "retry": ".proxies.retry.RetryingProxyStorage",
     "tenacious": ".proxies.tenacious.TenaciousProxyStorage",
     "validate": ".proxies.validate.ValidatingProxyStorage",
diff --git a/swh/storage/proxies/overlay.py b/swh/storage/proxies/overlay.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/proxies/overlay.py
@@ -0,0 +1,381 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import datetime
+import functools
+import random
+from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, TypeVar
+import warnings
+
+from swh.core.api.classes import PagedResult
+from swh.model.model import OriginVisit, Sha1Git
+from swh.storage import get_storage
+from swh.storage.exc import StorageArgumentException
+from swh.storage.interface import StorageInterface
+
+OBJECT_TYPES = [
+    "content",
+    "directory",
+    "snapshot",
+    "origin_visit_status",
+    "origin_visit",
+    "origin",
+]
+
+
+TKey = TypeVar("TKey", bound=Hashable)
+TValue = TypeVar("TValue")
+
+
+class UntestedCodeWarning(UserWarning):
+    pass
+
+
+class OverlayProxyStorage:
+    """Overlay storage proxy
+
+    This storage proxy is in front of several backends (or other proxies).
+
+    It always writes to the first backend.
+    When reading, it returns aggregated results for all backends (or from the
+    first backend to have a result, for endpoints which provide a single result).
+
+    Sample configuration use case for overlay storage:
+
+    .. code-block:: yaml
+
+        storage:
+          cls: overlay
+          storages:
+          - cls: remote
+            url: http://storage-rw.internal.staging.swh.network:5002/
+          - cls: remote
+            url: http://storage-ro2.internal.staging.swh.network:5002/
+          - cls: remote
+            url: http://storage-ro1.internal.staging.swh.network:5002/
+
+    """
+
+    def __init__(self, storages):
+        warnings.warn(
+            "OverlayProxyStorage is not well-tested and should not be used "
+            "in production.",
+            UntestedCodeWarning,
+        )
+        self.storages: List[StorageInterface] = [
+            get_storage(**storage) for storage in storages
+        ]
+
+    def __getattr__(self, key):
+        if key in ("storage", "storages") or (
+            key.startswith("__") and key.endswith("__")
+        ):
+            # Not proxied. "storages" only gets here when __init__ did not set
+            # it, and Python protocols (copy, pickle, ...) probe optional dunder
+            # methods with getattr(obj, name, default), which requires an
+            # AttributeError instead of the NotImplementedError raised below.
+            raise AttributeError(key)
+        elif key == "journal_writer":
+            # Useful for tests
+            return self.storages[0].journal_writer
+        elif key.endswith("_add") or key in ("content_update", "content_add_metadata"):
+            return getattr(self.storages[0], key)
+        elif key in (
+            "content_get_data",
+            "directory_get_entries",
+            "directory_entry_get_by_path",
+            "snapshot_get",
+            "snapshot_get_branches",
+            "snapshot_count_branches",
+            "origin_visit_get_by",
+            "origin_visit_status_get",
+            "origin_visit_status_get_latest",
+            "metadata_authority_get",
+            "metadata_fetcher_get",
+        ):
+            return self._getter_optional(key)
+        elif key in (
+            "origin_list",
+            "origin_visit_get",
+            "origin_search",
+            "raw_extrinsic_metadata_get",
+            "origin_visit_get_with_statuses",
+        ):
+            return self._getter_pagedresult(key)
+        elif key.endswith("_get") or key in ("origin_get_by_sha1",):
+            return self._getter_list_optional(key)
+        elif key in (
+            "content_missing",  # TODO: could be optimized
+            "content_missing_per_sha1_git",  # TODO: could be optimized
+        ):
+            return self._getter_intersection(key)
+        elif key.endswith("_missing") or key in ("content_missing_per_sha1",):
+            return self._missing(key)
+        elif key in ("refresh_stat_counters", "stat_counters"):
+            return getattr(self.storages[0], key)
+        elif key.endswith("_get_random"):
+            return self._getter_random(key)
+        elif key in (
+            "content_find",
+            "origin_snapshot_get_all",
+            "extid_get_from_extid",
+            "extid_get_from_target",
+            "raw_extrinsic_metadata_get_by_ids",  # TODO: could be optimized
+            "raw_extrinsic_metadata_get_authorities",
+        ):
+            return self._getter_union(key)
+        else:
+            raise NotImplementedError(key)
+
+    def _getter_optional(self, method_name: str) -> Callable[[TKey], Optional[TValue]]:
+        """Generates a function which takes an id and queries the underlying
+        storages in order, returning the first non-None result (or None)."""
+
+        @functools.wraps(getattr(self.storages[0], method_name))
+        def newf(id_: TKey, *args, **kwargs) -> Optional[TValue]:
+            method: Callable[[TKey], Optional[TValue]]
+
+            for storage in self.storages:
+                method = getattr(storage, method_name)
+                result = method(id_, *args, **kwargs)
+                if result is not None:
+                    return result
+
+            return None
+
+        return newf
+
+    def _getter_list_optional(
+        self,
+        method_name: str,
+    ) -> Callable[[List[TKey]], List[Optional[TValue]]]:
+        """Generates a function which takes a list of ids and returns a list of
+        optional objects in the same order. Underlying storages are queried in
+        order, each one only for the ids not found by the previous ones."""
+
+        @functools.wraps(getattr(self.storages[0], method_name))
+        def newf(ids: List[TKey], *args, **kwargs) -> List[Optional[TValue]]:
+            method: Callable[[List[TKey]], List[Optional[TValue]]]
+
+            missing_ids = list(ids)
+            results = {}
+            for storage in self.storages:
+                if not missing_ids:
+                    # Everything was found; skip querying the remaining storages
+                    break
+                method = getattr(storage, method_name)
+                new_results = dict(
+                    zip(missing_ids, method(missing_ids, *args, **kwargs))
+                )
+                results.update(new_results)
+                missing_ids = [id_ for id_ in missing_ids if new_results[id_] is None]
+
+            return [results[id_] for id_ in ids]
+
+        return newf
+
+    def _missing(self, method_name: str) -> Callable[[List[TKey]], Iterable[TKey]]:
+        """Generates a function which returns the objects missing from all the
+        underlying storages, by filtering them through each storage in turn."""
+
+        @functools.wraps(getattr(self.storages[0], method_name))
+        def newf(ids: List[TKey]) -> List[TKey]:
+            method: Callable[[List[TKey]], Iterable[TKey]]
+
+            missing_ids = list(ids)
+            for storage in self.storages:
+                if not missing_ids:
+                    # Nothing left to look for in the remaining storages
+                    break
+                method = getattr(storage, method_name)
+                missing_ids = list(method(missing_ids))
+            return missing_ids
+
+        return newf
+
+    def _getter_random(self, method_name: str) -> Callable[[], Optional[TValue]]:
+        """Generates a function which queries the underlying storages in random
+        order and returns the first non-None result (or None)."""
+
+        @functools.wraps(getattr(self.storages[0], method_name))
+        def newf(*args, **kwargs) -> Optional[TValue]:
+            method: Callable[[], Optional[TValue]]
+
+            # Not uniform sampling, but we don't care.
+            storages = list(self.storages)
+            random.shuffle(storages)
+
+            for storage in storages:
+                method = getattr(storage, method_name)
+                try:
+                    result = method(*args, **kwargs)
+                except IndexError:
+                    # in-memory storage when empty
+                    result = None
+                if result is not None:
+                    return result
+
+            return None
+
+        return newf
+
+    def _getter_intersection(self, method_name) -> Callable[..., List[TKey]]:
+        """Generates a function which returns the results common to all the
+        underlying storages, in no particular order."""
+
+        @functools.wraps(getattr(self.storages[0], method_name))
+        def newf(*args, **kwargs) -> List[TKey]:
+            (head, *tail) = self.storages
+            results = set(getattr(head, method_name)(*args, **kwargs))
+            for storage in tail:
+                method = getattr(storage, method_name)
+                results.intersection_update(method(*args, **kwargs))
+            return list(results)
+
+        return newf
+
+    def _getter_union(self, method_name) -> Callable[..., List[TKey]]:
+        """Generates a function which returns the deduplicated results of all the
+        underlying storages, in no particular order."""
+
+        @functools.wraps(getattr(self.storages[0], method_name))
+        def newf(*args, **kwargs) -> List[TKey]:
+            results = set()
+            for storage in self.storages:
+                method = getattr(storage, method_name)
+                results.update(method(*args, **kwargs))
+            return list(results)
+
+        return newf
+
+    def _getter_pagedresult(self, method_name: str) -> Callable[..., PagedResult]:
+        """Generates a function which chains the paginated results of all the
+        underlying storages, in order.
+
+        Page tokens it returns are the token of the underlying storage, prefixed
+        with the index of that storage and a space (as str or bytes, matching the
+        type of the underlying token)."""
+
+        @functools.wraps(getattr(self.storages[0], method_name))
+        def newf(*args, page_token: Optional[bytes] = None, **kwargs) -> PagedResult:
+            if page_token is None:
+                storage_id = 0
+            elif not isinstance(page_token, (str, bytes)):
+                raise StorageArgumentException("page_token must be a string or bytes")
+            else:
+                raw_page_token = page_token
+                try:
+                    if isinstance(page_token, str):
+                        (storage_id_str, page_token) = page_token.split(" ", 1)
+                    else:
+                        (storage_id_bytes, page_token) = page_token.split(b" ", 1)
+                        storage_id_str = storage_id_bytes.decode()
+                    storage_id = int(storage_id_str)
+                except ValueError:
+                    # No separator, non-numeric prefix, or undecodable bytes
+                    raise StorageArgumentException(
+                        f"Invalid page_token: {raw_page_token!r}"
+                    ) from None
+                if not 0 <= storage_id < len(self.storages):
+                    raise StorageArgumentException(
+                        f"Invalid page_token: {raw_page_token!r}"
+                    )
+                page_token = page_token or None
+
+            # Results of the storages fully consumed during this call.
+            # NOTE: they are prepended to the next storage's page, so the returned
+            # page may hold more results than a ``limit`` argument allows.
+            prepend_results: List[Any] = []
+
+            for storage in self.storages[storage_id:]:
+                method = getattr(storage, method_name)
+                results = method(*args, page_token=page_token, **kwargs)
+                if results.results and results.next_page_token is not None:
+                    # This storage has more pages: stop here, and tag the token
+                    # with this storage's index so the next call resumes from it.
+                    if isinstance(results.next_page_token, str):
+                        next_page_token = f"{storage_id} {results.next_page_token}"
+                    else:
+                        next_page_token = f"{storage_id} ".encode() + (
+                            results.next_page_token
+                        )
+                    return PagedResult(
+                        next_page_token=next_page_token,
+                        results=prepend_results + results.results,
+                    )
+
+                # This storage is exhausted (or returned an empty page): keep its
+                # results and continue with the next storage, from its first page.
+                prepend_results.extend(results.results)
+                storage_id += 1
+                page_token = None
+
+            return PagedResult(
+                next_page_token=None,
+                results=prepend_results,
+            )
+
+        return newf
+
+    def check_config(self, *, check_write: bool) -> bool:
+        (rw_storage, *ro_storages) = self.storages
+        return rw_storage.check_config(check_write=check_write) and all(
+            storage.check_config(check_write=False) for storage in ro_storages
+        )
+
+    def directory_ls(
+        self, directory: Sha1Git, recursive: bool = False
+    ) -> Iterable[Dict[str, Any]]:
+        for storage in self.storages:
+            it = iter(storage.directory_ls(directory, recursive=recursive))
+            try:
+                yield next(it)
+            except StopIteration:
+                # Note: this is slightly wasteful for the empty directory
+                continue
+            else:
+                yield from it
+                return
+
+    def directory_get_raw_manifest(
+        self, directory_ids: List[Sha1Git]
+    ) -> Dict[Sha1Git, Optional[bytes]]:
+        results = {}
+        missing_ids = set(directory_ids)
+        for storage in self.storages:
+            new_results = storage.directory_get_raw_manifest(list(missing_ids))
+            missing_ids.difference_update(set(new_results))
+            results.update(new_results)
+        return results
+
+    def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]:
+        results: Dict[Sha1Git, List[Dict]] = {}
+        for storage in self.storages:
+            for (id_, objects) in storage.object_find_by_sha1_git(ids).items():
+                results.setdefault(id_, []).extend(objects)
+
+        return results
+
+    def origin_visit_get_latest(self, *args, **kwargs) -> Optional[OriginVisit]:
+        return max(
+            (
+                storage.origin_visit_get_latest(*args, **kwargs)
+                for storage in self.storages
+            ),
+            key=lambda ov: (-1000, None) if ov is None else (ov.visit, ov.date),
+        )
+
+    def origin_visit_find_by_date(
+        self, origin: str, visit_date: datetime.datetime
+    ) -> Optional[OriginVisit]:
+        return min(
+            (
+                storage.origin_visit_find_by_date(origin, visit_date)
+                for storage in self.storages
+            ),
+            key=lambda ov: (datetime.timedelta.max, None)
+            if ov is None
+            else (abs(visit_date - ov.date), -(ov.visit or 0)),
+        )
diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py
--- a/swh/storage/tests/storage_tests.py
+++ b/swh/storage/tests/storage_tests.py
@@ -560,7 +560,7 @@
         ]
 
         missing_contents = swh_storage.content_missing_per_sha1_git(contents)
-        assert list(missing_contents) == [missing_cont.sha1_git, missing_cont2.sha1_git]
+        assert set(missing_contents) == {missing_cont.sha1_git, missing_cont2.sha1_git}
 
         missing_contents = swh_storage.content_missing_per_sha1_git([])
         assert list(missing_contents) == []
@@ -1382,8 +1382,10 @@
         summary = swh_storage.extid_add(extids)
         assert summary == {"extid:add": len(gitids)}
 
-        assert swh_storage.extid_get_from_extid("git", gitids) == extids
-        assert swh_storage.extid_get_from_target(ObjectType.REVISION, gitids) == extids
+        assert set(swh_storage.extid_get_from_extid("git", gitids)) == set(extids)
+        assert set(
+            swh_storage.extid_get_from_target(ObjectType.REVISION, gitids)
+        ) == set(extids)
 
         assert swh_storage.extid_get_from_extid("hg", gitids) == []
         assert swh_storage.extid_get_from_target(ObjectType.RELEASE, gitids) == []
@@ -1434,10 +1436,10 @@
         summary = swh_storage.extid_add(extid_objs)
         assert summary == {"extid:add": len(swhids)}
 
-        assert swh_storage.extid_get_from_extid("hg", extids) == extid_objs
-        assert (
-            swh_storage.extid_get_from_target(ObjectType.REVISION, swhids) == extid_objs
-        )
+        assert set(swh_storage.extid_get_from_extid("hg", extids)) == set(extid_objs)
+        assert set(
+            swh_storage.extid_get_from_target(ObjectType.REVISION, swhids)
+        ) == set(extid_objs)
 
         assert swh_storage.extid_get_from_extid("git", extids) == []
         assert swh_storage.extid_get_from_target(ObjectType.RELEASE, swhids) == []
@@ -1475,8 +1477,10 @@
         # add them again, should be noop
         summary = swh_storage.extid_add(extids)
         # assert summary == {"extid:add": 0}
-        assert swh_storage.extid_get_from_extid("git", gitids) == extids
-        assert swh_storage.extid_get_from_target(ObjectType.REVISION, gitids) == extids
+        assert set(swh_storage.extid_get_from_extid("git", gitids)) == set(extids)
+        assert set(
+            swh_storage.extid_get_from_target(ObjectType.REVISION, gitids)
+        ) == set(extids)
 
 
     def test_extid_add_extid_multicity(self, swh_storage, sample_data):
@@ -1515,8 +1519,8 @@
         ]
         swh_storage.extid_add(extids2)
 
-        assert swh_storage.extid_get_from_extid("git", ids) == extids
-        assert swh_storage.extid_get_from_extid("hg", ids) == extids2
+        assert set(swh_storage.extid_get_from_extid("git", ids)) == set(extids)
+        assert set(swh_storage.extid_get_from_extid("hg", ids)) == set(extids2)
         assert set(swh_storage.extid_get_from_target(ObjectType.REVISION, ids)) == {
             *extids,
             *extids2,
@@ -1558,8 +1562,12 @@
         swh_storage.extid_add(extids2)
 
         assert set(swh_storage.extid_get_from_extid("git", ids)) == {*extids, *extids2}
-        assert swh_storage.extid_get_from_target(ObjectType.REVISION, ids) == extids
-        assert swh_storage.extid_get_from_target(ObjectType.RELEASE, ids) == extids2
+        assert set(swh_storage.extid_get_from_target(ObjectType.REVISION, ids)) == set(
+            extids
+        )
+        assert set(swh_storage.extid_get_from_target(ObjectType.RELEASE, ids)) == set(
+            extids2
+        )
 
     def test_extid_version_behavior(self, swh_storage, sample_data):
         ids = [
diff --git a/swh/storage/tests/test_overlay.py b/swh/storage/tests/test_overlay.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_overlay.py
@@ -0,0 +1,85 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+from swh.storage.tests.test_in_memory import (
+    TestInMemoryStorageGeneratedData as _TestInMemoryStorageGeneratedData,
+)
+from swh.storage.tests.test_in_memory import TestInMemoryStorage as _TestInMemoryStorage
+
+
+@pytest.fixture(params=[1, 2])
+def swh_storage_backend_config(request):
+    yield {
+        "cls": "overlay",
+        "storages": [
+            {
+                "cls": "memory",
+                "journal_writer": {
+                    "cls": "memory",
+                },
+            }
+            for _ in range(request.param)
+        ],
+    }
+
+
+class TestOverlayProxy(_TestInMemoryStorage):
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_types(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_content_get_partition(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_content_get_partition_full(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_content_get_partition_empty(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_content_get_partition_limit_none(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_content_get_partition_pagination_generate(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_revision_log(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_revision_log_with_limit(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_revision_log_unknown_revision(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_revision_shortlog(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_revision_shortlog_with_limit(self):
+        pass
+
+    @pytest.mark.skip("TODO: rewrite this test without hardcoded page_token")
+    def test_origin_visit_get_with_statuses(self):
+        pass
+
+    @pytest.mark.skip("Not supported by the overlay proxy")
+    def test_content_add_objstorage_first(self):
+        pass
+
+
+class TestOverlayProxyGeneratedData(_TestInMemoryStorageGeneratedData):
+    pass