diff --git a/swh/objstorage/api/client.py b/swh/objstorage/api/client.py
--- a/swh/objstorage/api/client.py
+++ b/swh/objstorage/api/client.py
@@ -32,7 +32,7 @@
     reraise_exceptions = [ObjNotFoundError, Error]
     backend_class = ObjStorageInterface
 
-    def restore(self, content, obj_id=None, *args, **kwargs):
+    def restore(self, content, obj_id=None):
         return self.add(content, obj_id, check_presence=False)
 
     def add_stream(self, content_iter, obj_id, check_presence=True):
diff --git a/swh/objstorage/backends/in_memory.py b/swh/objstorage/backends/in_memory.py
--- a/swh/objstorage/backends/in_memory.py
+++ b/swh/objstorage/backends/in_memory.py
@@ -24,13 +24,13 @@
     def check_config(self, *, check_write):
         return True
 
-    def __contains__(self, obj_id, *args, **kwargs):
+    def __contains__(self, obj_id):
         return obj_id in self.state
 
     def __iter__(self):
         return iter(sorted(self.state))
 
-    def add(self, content, obj_id=None, check_presence=True, *args, **kwargs):
+    def add(self, content, obj_id=None, check_presence=True):
         if obj_id is None:
             obj_id = compute_hash(content)
 
@@ -41,20 +41,20 @@
 
         return obj_id
 
-    def get(self, obj_id, *args, **kwargs):
+    def get(self, obj_id):
         if obj_id not in self:
             raise ObjNotFoundError(obj_id)
 
         return self.state[obj_id]
 
-    def check(self, obj_id, *args, **kwargs):
+    def check(self, obj_id):
         if obj_id not in self:
             raise ObjNotFoundError(obj_id)
         if compute_hash(self.state[obj_id]) != obj_id:
             raise Error("Corrupt object %s" % obj_id)
         return True
 
-    def delete(self, obj_id, *args, **kwargs):
+    def delete(self, obj_id):
         super().delete(obj_id)  # Check delete permission
         if obj_id not in self:
             raise ObjNotFoundError(obj_id)
diff --git a/swh/objstorage/interface.py b/swh/objstorage/interface.py
--- a/swh/objstorage/interface.py
+++ b/swh/objstorage/interface.py
@@ -3,11 +3,14 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from typing import Dict, Protocol, runtime_checkable
+
 from swh.core.api import remote_api_endpoint
 from swh.objstorage.objstorage import DEFAULT_CHUNK_SIZE, DEFAULT_LIMIT
 
 
-class ObjStorageInterface:
+@runtime_checkable
+class ObjStorageInterface(Protocol):
     """ High-level API to manipulate the Software Heritage object storage.
 
     Conceptually, the object storage offers the following methods:
@@ -80,7 +83,7 @@
         ...
 
     @remote_api_endpoint("content/add/batch")
-    def add_batch(self, contents, check_presence=True):
+    def add_batch(self, contents, check_presence=True) -> Dict:
         """Add a batch of new objects to the object storage.
 
         Args:
diff --git a/swh/objstorage/multiplexer/multiplexer_objstorage.py b/swh/objstorage/multiplexer/multiplexer_objstorage.py
--- a/swh/objstorage/multiplexer/multiplexer_objstorage.py
+++ b/swh/objstorage/multiplexer/multiplexer_objstorage.py
@@ -6,6 +6,7 @@
 import queue
 import random
 import threading
+from typing import Dict
 
 from swh.objstorage.exc import ObjNotFoundError
 from swh.objstorage.objstorage import ObjStorage
@@ -251,7 +252,7 @@
                 continue
             return result
 
-    def add_batch(self, contents, check_presence=True):
+    def add_batch(self, contents, check_presence=True) -> Dict:
         """Add a batch of new objects to the object storage.
 
         """
diff --git a/swh/objstorage/multiplexer/striping_objstorage.py b/swh/objstorage/multiplexer/striping_objstorage.py
--- a/swh/objstorage/multiplexer/striping_objstorage.py
+++ b/swh/objstorage/multiplexer/striping_objstorage.py
@@ -5,6 +5,7 @@
 
 from collections import defaultdict
 import queue
+from typing import Dict
 
 from swh.objstorage.multiplexer.multiplexer_objstorage import (
     MultiplexerObjStorage,
@@ -49,16 +50,16 @@
         for i in range(self.num_storages):
             yield self.storage_threads[(idx + i) % self.num_storages]
 
-    def add_batch(self, contents, check_presence=True):
+    def add_batch(self, contents, check_presence=True) -> Dict:
         """Add a batch of new objects to the object storage.
 
         """
-        content_by_storage_index = defaultdict(dict)
+        content_by_storage_index: Dict[int, Dict] = defaultdict(dict)
         for obj_id, content in contents.items():
             storage_index = self.get_storage_index(obj_id)
             content_by_storage_index[storage_index][obj_id] = content
 
-        mailbox = queue.Queue()
+        mailbox: queue.Queue[Dict] = queue.Queue()
         for storage_index, contents in content_by_storage_index.items():
             self.storage_threads[storage_index].queue_command(
                 "add_batch", contents, check_presence=check_presence, mailbox=mailbox,
diff --git a/swh/objstorage/objstorage.py b/swh/objstorage/objstorage.py
--- a/swh/objstorage/objstorage.py
+++ b/swh/objstorage/objstorage.py
@@ -83,11 +83,11 @@
         pass
 
     @abc.abstractmethod
-    def __contains__(self, obj_id, *args, **kwargs):
+    def __contains__(self, obj_id):
         pass
 
     @abc.abstractmethod
-    def add(self, content, obj_id=None, check_presence=True, *args, **kwargs):
+    def add(self, content, obj_id=None, check_presence=True):
         pass
 
     def add_batch(self, contents, check_presence=True) -> Dict:
@@ -100,15 +100,15 @@
             summary["object:add:bytes"] += len(content)
         return summary
 
-    def restore(self, content, obj_id=None, *args, **kwargs):
+    def restore(self, content, obj_id=None):
         # check_presence to false will erase the potential previous content.
         return self.add(content, obj_id, check_presence=False)
 
     @abc.abstractmethod
-    def get(self, obj_id, *args, **kwargs):
+    def get(self, obj_id):
         pass
 
-    def get_batch(self, obj_ids, *args, **kwargs):
+    def get_batch(self, obj_ids):
         for obj_id in obj_ids:
             try:
                 yield self.get(obj_id)
@@ -116,17 +116,17 @@
                 yield None
 
     @abc.abstractmethod
-    def check(self, obj_id, *args, **kwargs):
+    def check(self, obj_id):
         pass
 
     @abc.abstractmethod
-    def delete(self, obj_id, *args, **kwargs):
+    def delete(self, obj_id):
         if not self.allow_delete:
             raise PermissionError("Delete is not allowed.")
 
     # Management methods
 
-    def get_random(self, batch_size, *args, **kwargs):
+    def get_random(self, batch_size):
         pass
 
     # Streaming methods
diff --git a/swh/objstorage/tests/objstorage_testing.py b/swh/objstorage/tests/objstorage_testing.py
--- a/swh/objstorage/tests/objstorage_testing.py
+++ b/swh/objstorage/tests/objstorage_testing.py
@@ -4,13 +4,53 @@
 # See top-level LICENSE file for more information
 
 from collections.abc import Iterator
+import inspect
 import time
 
 from swh.objstorage import exc
+from swh.objstorage.interface import ObjStorageInterface
 from swh.objstorage.objstorage import compute_hash
 
 
 class ObjStorageTestFixture:
+    def test_types(self):
+        """Checks all methods of ObjStorageInterface are implemented by this
+        backend, and that they have the same signature."""
+        # Create an instance of the protocol (which cannot be instantiated
+        # directly, so this creates a subclass, then instantiates it)
+        interface = type("_", (ObjStorageInterface,), {})()
+
+        # Sanity check: if dir() did not list the protocol's methods, the
+        # loop below would check nothing and the test would pass vacuously.
+        assert "get_batch" in dir(interface)
+
+        missing_methods = []
+
+        for meth_name in dir(interface):
+            if meth_name.startswith("_"):
+                continue
+            interface_meth = getattr(interface, meth_name)
+            try:
+                concrete_meth = getattr(self.storage, meth_name)
+            except AttributeError:
+                # Collect every missing method instead of failing on the
+                # first one, so the final assertion reports all of them.
+                missing_methods.append(meth_name)
+                continue
+
+            expected_signature = inspect.signature(interface_meth)
+            actual_signature = inspect.signature(concrete_meth)
+
+            assert expected_signature == actual_signature, meth_name
+
+        assert missing_methods == []
+
+        # If all the assertions above succeed, then this one should too.
+        # But there's no harm in double-checking.
+        # And we could replace the assertions above by this one, but unlike
+        # the assertions above, it doesn't explain what is missing.
+        assert isinstance(self.storage, ObjStorageInterface)
+
     def hash_content(self, content):
         obj_id = compute_hash(content)
         return content, obj_id