diff --git a/swh/storage/buffer.py b/swh/storage/buffer.py --- a/swh/storage/buffer.py +++ b/swh/storage/buffer.py @@ -6,12 +6,15 @@ from functools import partial from typing import Dict, Iterable, Mapping, Sequence, Tuple +from typing_extensions import Literal + from swh.core.utils import grouper from swh.model.model import BaseModel, Content, SkippedContent from swh.storage import get_storage from swh.storage.interface import StorageInterface -OBJECT_TYPES: Tuple[str, ...] = ( +LObjectType = Literal["content", "skipped_content", "directory", "revision", "release"] +OBJECT_TYPES: Tuple[LObjectType, ...] = ( "content", "skipped_content", "directory", @@ -60,7 +63,7 @@ self._buffer_thresholds = {**DEFAULT_BUFFER_THRESHOLDS, **min_batch_size} - self._objects: Dict[str, Dict[Tuple[str, ...], BaseModel]] = { + self._objects: Dict[LObjectType, Dict[Tuple[str, ...], BaseModel]] = { k: {} for k in OBJECT_TYPES } self._contents_size: int = 0 @@ -103,7 +106,11 @@ ) def object_add( - self, objects: Sequence[BaseModel], *, object_type: str, keys: Iterable[str], + self, + objects: Sequence[BaseModel], + *, + object_type: LObjectType, + keys: Iterable[str], ) -> Dict[str, int]: """Push objects to write to the storage in the buffer. Flushes the buffer to the storage if the threshold is hit. @@ -118,7 +125,9 @@ return {} - def flush(self, object_types: Sequence[str] = OBJECT_TYPES) -> Dict[str, int]: + def flush( + self, object_types: Sequence[LObjectType] = OBJECT_TYPES + ) -> Dict[str, int]: summary: Dict[str, int] = self.storage.flush(object_types) for object_type in object_types: buffer_ = self._objects[object_type] @@ -132,7 +141,7 @@ return summary - def clear_buffers(self, object_types: Sequence[str] = OBJECT_TYPES) -> None: + def clear_buffers(self, object_types: Sequence[LObjectType] = OBJECT_TYPES) -> None: """Clear objects from current buffer. WARNING: