diff --git a/swh/storage/buffer.py b/swh/storage/buffer.py --- a/swh/storage/buffer.py +++ b/swh/storage/buffer.py @@ -4,20 +4,23 @@ # See top-level LICENSE file for more information from functools import partial -from typing import Collection, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Collection, Dict, Iterable, Mapping, Tuple + +from typing_extensions import Literal from swh.core.utils import grouper from swh.model.model import BaseModel, Content, SkippedContent from swh.storage import get_storage from swh.storage.interface import StorageInterface -OBJECT_TYPES: List[str] = [ +LObjectType = Literal["content", "skipped_content", "directory", "revision", "release"] +OBJECT_TYPES: Tuple[LObjectType, ...] = ( "content", "skipped_content", "directory", "revision", "release", -] +) DEFAULT_BUFFER_THRESHOLDS: Dict[str, int] = { "content": 10000, @@ -63,7 +66,7 @@ if min_batch_size is not DEFAULT_BUFFER_THRESHOLDS: self._buffer_thresholds = {**DEFAULT_BUFFER_THRESHOLDS, **min_batch_size} - self._objects: Dict[str, Dict[Tuple[str, ...], BaseModel]] = { + self._objects: Dict[LObjectType, Dict[Tuple[str, ...], BaseModel]] = { k: {} for k in OBJECT_TYPES } self._contents_size: int = 0 @@ -106,7 +109,11 @@ ) def object_add( - self, objects: Collection[BaseModel], *, object_type: str, keys: Iterable[str], + self, + objects: Collection[BaseModel], + *, + object_type: LObjectType, + keys: Iterable[str], ) -> Dict[str, int]: """Push objects to write to the storage in the buffer. Flushes the buffer to the storage if the threshold is hit. @@ -121,11 +128,10 @@ return {} - def flush(self, object_types: Optional[List[str]] = None) -> Dict[str, int]: + def flush( + self, object_types: Collection[LObjectType] = OBJECT_TYPES + ) -> Dict[str, int]: summary: Dict[str, int] = self.storage.flush(object_types) - if object_types is None: - object_types = OBJECT_TYPES - for object_type in object_types: buffer_ = self._objects[object_type] batches = grouper(buffer_.values(), n=self._buffer_thresholds[object_type]) @@ -138,7 +144,9 @@ return summary - def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: + def clear_buffers( + self, object_types: Collection[LObjectType] = OBJECT_TYPES + ) -> None: """Clear objects from current buffer. WARNING: @@ -148,13 +156,10 @@ you want to continue your processing. """ - if object_types is None: - object_types = OBJECT_TYPES - for object_type in object_types: buffer_ = self._objects[object_type] buffer_.clear() if object_type == "content": self._contents_size = 0 - return self.storage.clear_buffers(object_types) + self.storage.clear_buffers(object_types) diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -9,7 +9,18 @@ import json import random import re -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import attr @@ -1303,11 +1314,11 @@ else: return None - def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: + def clear_buffers(self, object_types: Collection[str]) -> None: """Do nothing """ return None - def flush(self, object_types: Optional[List[str]] = None) -> Dict: + def flush(self, object_types: Collection[str]) -> Dict[str, int]: return {} diff --git a/swh/storage/filter.py b/swh/storage/filter.py --- a/swh/storage/filter.py +++ b/swh/storage/filter.py @@ -63,7 +63,7 @@ """Return only the content keys missing from swh Args: - content_hashes: List of sha256 to check for existence in swh + content_hashes: list of sha256 to check for existence in swh storage """ @@ -79,7 +79,7 @@ """Return only the content keys missing from swh Args: - content_hashes: List of sha1_git to check for existence in swh + content_hashes: list of sha1_git to check for existence in swh storage """ @@ -97,7 +97,7 @@ Args: object_type: object type to use {revision, directory} - ids: List of object_type ids + ids: list of object_type ids Returns: Missing ids from the storage for object_type @@ -114,3 +114,9 @@ fn = fn_by_object_type[object_type] return set(fn(missing_ids)) + + def clear_buffers(self, *args) -> None: + self.storage.clear_buffers(*args) + + def flush(self, *args) -> Dict[str, int]: + return self.storage.flush(*args) diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -5,7 +5,17 @@ import datetime from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, +) from typing_extensions import Protocol, TypedDict, runtime_checkable @@ -1185,7 +1195,7 @@ ... @remote_api_endpoint("clear/buffer") - def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: + def clear_buffers(self, object_types: Collection[str]) -> None: """For backend storages (pg, storage, in-memory), this is a noop operation. For proxy storages (especially filter, buffer), this is an operation which cleans internal state. @@ -1193,7 +1203,7 @@ """ @remote_api_endpoint("flush") - def flush(self, object_types: Optional[List[str]] = None) -> Dict: + def flush(self, object_types: Collection[str]) -> Dict[str, int]: """For backend storages (pg, storage, in-memory), this is expected to be a noop operation. For proxy storages (especially buffer), this is expected to trigger actual writes to the backend. diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -9,7 +9,17 @@ from contextlib import contextmanager import datetime import itertools -from typing import Any, Counter, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + Any, + Collection, + Counter, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) import attr import psycopg2 @@ -1396,13 +1406,13 @@ return None return MetadataAuthority.from_dict(dict(zip(db.metadata_authority_cols, row))) - def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: + def clear_buffers(self, object_types: Collection[str]) -> None: """Do nothing """ return None - def flush(self, object_types: Optional[List[str]] = None) -> Dict: + def flush(self, object_types: Collection[str]) -> Dict[str, int]: return {} def _get_authority_id(self, authority: MetadataAuthority, db, cur): diff --git a/swh/storage/retry.py b/swh/storage/retry.py --- a/swh/storage/retry.py +++ b/swh/storage/retry.py @@ -5,7 +5,7 @@ import logging import traceback -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -130,11 +130,11 @@ def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: return self.storage.snapshot_add(snapshots) - def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: - return self.storage.clear_buffers(object_types) + def clear_buffers(self, *args) -> None: + return self.storage.clear_buffers(*args) - def flush(self, object_types: Optional[List[str]] = None) -> Dict: + def flush(self, *args) -> Dict[str, int]: """Specific case for buffer proxy storage failing to flush data """ - return self.storage.flush(object_types) + return self.storage.flush(*args) diff --git a/swh/storage/tests/test_postgresql.py b/swh/storage/tests/test_postgresql.py --- a/swh/storage/tests/test_postgresql.py +++ b/swh/storage/tests/test_postgresql.py @@ -248,13 +248,13 @@ """Calling clear buffers on real storage does nothing """ - assert swh_storage.clear_buffers() is None + assert swh_storage.clear_buffers([]) is None def test_flush(self, swh_storage): """Calling clear buffers on real storage does nothing """ - assert swh_storage.flush() == {} + assert swh_storage.flush([]) == {} def test_dbversion(self, swh_storage): with swh_storage.db() as db: diff --git a/swh/storage/validate.py b/swh/storage/validate.py --- a/swh/storage/validate.py +++ b/swh/storage/validate.py @@ -69,3 +69,9 @@ def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: self._check_hashes(snapshots) return self.storage.snapshot_add(snapshots) + + def clear_buffers(self, *args) -> None: + self.storage.clear_buffers(*args) + + def flush(self, *args) -> Dict[str, int]: + return self.storage.flush(*args)