diff --git a/swh/objstorage/api/client.py b/swh/objstorage/api/client.py --- a/swh/objstorage/api/client.py +++ b/swh/objstorage/api/client.py @@ -3,12 +3,14 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Iterator, Optional + from swh.core.api import RPCClient from swh.core.utils import iter_chunks from swh.model import hashutil +from swh.objstorage.constants import DEFAULT_LIMIT, ID_DIGEST_LENGTH from swh.objstorage.exc import Error, ObjNotFoundError, ObjStorageAPIError -from swh.objstorage.interface import ObjStorageInterface -from swh.objstorage.objstorage import DEFAULT_LIMIT, ID_DIGEST_LENGTH +from swh.objstorage.interface import ObjId, ObjStorageInterface class RemoteObjStorage(RPCClient): @@ -28,13 +30,17 @@ reraise_exceptions = [ObjNotFoundError, Error] backend_class = ObjStorageInterface - def restore(self, content, obj_id): + def restore(self: ObjStorageInterface, content: bytes, obj_id: ObjId): return self.add(content, obj_id, check_presence=False) def __iter__(self): yield from self.list_content() - def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT): + def list_content( + self, + last_obj_id: Optional[ObjId] = None, + limit: int = DEFAULT_LIMIT, + ) -> Iterator[ObjId]: params = {"limit": limit} if last_obj_id: params["last_obj_id"] = hashutil.hash_to_hex(last_obj_id) diff --git a/swh/objstorage/backends/azure.py b/swh/objstorage/backends/azure.py --- a/swh/objstorage/backends/azure.py +++ b/swh/objstorage/backends/azure.py @@ -8,7 +8,7 @@ import datetime from itertools import product import string -from typing import Dict, Optional, Union +from typing import Dict, Iterator, List, Optional, Union import warnings from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError @@ -21,6 +21,7 @@ from swh.model import hashutil from swh.objstorage.exc import Error, ObjNotFoundError +from swh.objstorage.interface import ObjId from swh.objstorage.objstorage import ( ObjStorage, compressors, @@ -205,7 +206,7 @@ """ return sum(1 for i in self) - def add(self, content, obj_id, check_presence=True): + def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> ObjId: """Add an obj in storage if it's not there already.""" if check_presence and obj_id in self: return obj_id @@ -229,14 +230,14 @@ return obj_id - def restore(self, content, obj_id): + def restore(self, content: bytes, obj_id: ObjId): """Restore a content.""" if obj_id in self: self.delete(obj_id) return self.add(content, obj_id, check_presence=False) - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: """retrieve blob's content if found.""" return call_async(self._get_async, obj_id) @@ -286,18 +287,18 @@ ] ) - def get_batch(self, obj_ids): + def get_batch(self, obj_ids: List[ObjId]) -> Iterator[Optional[bytes]]: """Retrieve objects' raw content in bulk from storage, concurrently.""" return call_async(self._get_batch_async, obj_ids) - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: """Check the content integrity.""" obj_content = self.get(obj_id) content_obj_id = compute_hash(obj_content) if content_obj_id != obj_id: raise Error(obj_id) - def delete(self, obj_id): + def delete(self, obj_id: ObjId): """Delete an object.""" super().delete(obj_id) # Check delete permission hex_obj_id = self._internal_id(obj_id) diff --git a/swh/objstorage/backends/generator.py b/swh/objstorage/backends/generator.py --- a/swh/objstorage/backends/generator.py +++ b/swh/objstorage/backends/generator.py @@ -1,7 +1,9 @@ from itertools import count, islice, repeat import logging import random +from typing import Iterator, Optional +from swh.objstorage.interface import ObjId from swh.objstorage.objstorage import DEFAULT_LIMIT, ObjStorage logger = logging.getLogger(__name__) @@ -211,7 +213,11 @@ def delete(self, obj_id, *args, **kwargs): return True - def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT): + def list_content( + self, + last_obj_id: Optional[ObjId] = None, + limit: int = DEFAULT_LIMIT, + ) -> Iterator[ObjId]: it = iter(self) if last_obj_id: next(it) diff --git a/swh/objstorage/backends/http.py b/swh/objstorage/backends/http.py --- a/swh/objstorage/backends/http.py +++ b/swh/objstorage/backends/http.py @@ -4,12 +4,14 @@ # See top-level LICENSE file for more information import logging +from typing import Iterator, Optional from urllib.parse import urljoin import requests from swh.model import hashutil from swh.objstorage import exc +from swh.objstorage.interface import ObjId from swh.objstorage.objstorage import ( DEFAULT_LIMIT, ObjStorage, @@ -53,19 +55,23 @@ def __len__(self): raise exc.NonIterableObjStorage("__len__") - def add(self, content, obj_id, check_presence=True): + def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> ObjId: raise exc.ReadOnlyObjStorage("add") - def delete(self, obj_id): + def delete(self, obj_id: ObjId): raise exc.ReadOnlyObjStorage("delete") - def restore(self, content, obj_id): + def restore(self, content: bytes, obj_id: ObjId): raise exc.ReadOnlyObjStorage("restore") - def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT): + def list_content( + self, + last_obj_id: Optional[ObjId] = None, + limit: int = DEFAULT_LIMIT, + ) -> Iterator[ObjId]: raise exc.NonIterableObjStorage("__len__") - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: try: resp = self.session.get(self._path(obj_id)) resp.raise_for_status() @@ -81,7 +87,7 @@ raise exc.Error("Corrupt object %s: trailing data found" % hex_obj_id) return ret - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: # Check the content integrity obj_content = self.get(obj_id) content_obj_id = compute_hash(obj_content) diff --git a/swh/objstorage/backends/in_memory.py b/swh/objstorage/backends/in_memory.py --- a/swh/objstorage/backends/in_memory.py +++ b/swh/objstorage/backends/in_memory.py @@ -4,6 +4,7 @@ # See top-level LICENSE file for more information from swh.objstorage.exc import Error, ObjNotFoundError +from swh.objstorage.interface import ObjId from swh.objstorage.objstorage import ObjStorage, compute_hash @@ -27,7 +28,7 @@ def __iter__(self): return iter(sorted(self.state)) - def add(self, content, obj_id, check_presence=True): + def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> ObjId: if check_presence and obj_id in self: return obj_id @@ -35,20 +36,19 @@ return obj_id - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: if obj_id not in self: raise ObjNotFoundError(obj_id) return self.state[obj_id] - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: if obj_id not in self: raise ObjNotFoundError(obj_id) if compute_hash(self.state[obj_id]) != obj_id: - raise Error("Corrupt object %s" % obj_id) - return True + raise Error("Corrupt object %s" % obj_id.hex()) - def delete(self, obj_id): + def delete(self, obj_id: ObjId): super().delete(obj_id) # Check delete permission if obj_id not in self: raise ObjNotFoundError(obj_id) diff --git a/swh/objstorage/backends/libcloud.py b/swh/objstorage/backends/libcloud.py --- a/swh/objstorage/backends/libcloud.py +++ b/swh/objstorage/backends/libcloud.py @@ -15,6 +15,7 @@ from swh.model import hashutil from swh.objstorage.exc import Error, ObjNotFoundError +from swh.objstorage.interface import ObjId from swh.objstorage.objstorage import ( ObjStorage, compressors, @@ -61,7 +62,7 @@ def __init__( self, container_name: str, - compression: Optional[str] = None, + compression: str = "gzip", path_prefix: Optional[str] = None, **kwargs, ): @@ -156,17 +157,17 @@ """ return sum(1 for i in self) - def add(self, content, obj_id, check_presence=True): + def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> ObjId: if check_presence and obj_id in self: return obj_id self._put_object(content, obj_id) return obj_id - def restore(self, content, obj_id): + def restore(self, content: bytes, obj_id: ObjId): return self.add(content, obj_id, check_presence=False) - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: obj = b"".join(self._get_object(obj_id).as_stream()) d = decompressors[self.compression]() ret = d.decompress(obj) @@ -175,7 +176,7 @@ raise Error("Corrupt object %s: trailing data found" % hex_obj_id) return ret - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: # Check that the file exists, as _get_object raises ObjNotFoundError self._get_object(obj_id) # Check the content integrity @@ -184,7 +185,7 @@ if content_obj_id != obj_id: raise Error(obj_id) - def delete(self, obj_id): + def delete(self, obj_id: ObjId): super().delete(obj_id) # Check delete permission obj = self._get_object(obj_id) return self.driver.delete_object(obj) diff --git a/swh/objstorage/backends/pathslicing.py b/swh/objstorage/backends/pathslicing.py --- a/swh/objstorage/backends/pathslicing.py +++ b/swh/objstorage/backends/pathslicing.py @@ -1,26 +1,20 @@ -# Copyright (C) 2015-2019 The Software Heritage developers +# Copyright (C) 2015-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from collections.abc import Iterator from contextlib import contextmanager from itertools import islice import os import random import tempfile -from typing import List +from typing import Iterable, Iterator, List, Optional from swh.model import hashutil +from swh.objstorage.constants import DEFAULT_LIMIT, ID_HASH_ALGO, ID_HEXDIGEST_LENGTH from swh.objstorage.exc import Error, ObjNotFoundError -from swh.objstorage.objstorage import ( - DEFAULT_LIMIT, - ID_HASH_ALGO, - ID_HEXDIGEST_LENGTH, - ObjStorage, - compressors, - decompressors, -) +from swh.objstorage.interface import ObjId +from swh.objstorage.objstorage import ObjStorage, compressors, decompressors BUFSIZ = 1048576 @@ -188,11 +182,11 @@ return True - def __contains__(self, obj_id): + def __contains__(self, obj_id: ObjId) -> bool: hex_obj_id = hashutil.hash_to_hex(obj_id) return os.path.isfile(self.slicer.get_path(hex_obj_id)) - def __iter__(self): + def __iter__(self) -> Iterator[bytes]: """Iterate over the object identifiers currently available in the storage. @@ -217,7 +211,7 @@ return obj_iterator() - def __len__(self): + def __len__(self) -> int: """Compute the number of objects available in the storage. Warning: this currently uses `__iter__`, its warning about bad @@ -228,23 +222,25 @@ """ return sum(1 for i in self) - def add(self, content, obj_id, check_presence=True): + def add( + self, + content: bytes, + obj_id: ObjId, + check_presence: bool = True, + ) -> ObjId: if check_presence and obj_id in self: # If the object is already present, return immediately. return obj_id hex_obj_id = hashutil.hash_to_hex(obj_id) - if not isinstance(content, Iterator): - content = [content] compressor = compressors[self.compression]() with self._write_obj_file(hex_obj_id) as f: - for chunk in content: - f.write(compressor.compress(chunk)) + f.write(compressor.compress(content)) f.write(compressor.flush()) return obj_id - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: if obj_id not in self: raise ObjNotFoundError(obj_id) @@ -260,7 +256,7 @@ return out - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: try: data = self.get(obj_id) except OSError: @@ -282,7 +278,7 @@ % (hashutil.hash_to_hex(obj_id), hashutil.hash_to_hex(actual_obj_id)) ) - def delete(self, obj_id): + def delete(self, obj_id: ObjId): super().delete(obj_id) # Check delete permission if obj_id not in self: raise ObjNotFoundError(obj_id) @@ -296,7 +292,7 @@ # Management methods - def get_random(self, batch_size): + def get_random(self, batch_size: int) -> Iterable[ObjId]: def get_random_content(self, batch_size): """Get a batch of content inside a single directory. @@ -334,7 +330,9 @@ yield lambda c: f.write(compressor.compress(c)) f.write(compressor.flush()) - def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT): + def list_content( + self, last_obj_id: Optional[ObjId] = None, limit: int = DEFAULT_LIMIT + ) -> Iterator[ObjId]: if last_obj_id: it = self.iter_from(last_obj_id) else: diff --git a/swh/objstorage/backends/seaweedfs/objstorage.py b/swh/objstorage/backends/seaweedfs/objstorage.py --- a/swh/objstorage/backends/seaweedfs/objstorage.py +++ b/swh/objstorage/backends/seaweedfs/objstorage.py @@ -7,9 +7,11 @@ from itertools import islice import logging import os +from typing import Iterator, Optional from swh.model import hashutil from swh.objstorage.exc import Error, ObjNotFoundError +from swh.objstorage.interface import ObjId from swh.objstorage.objstorage import ( DEFAULT_LIMIT, ObjStorage, @@ -72,27 +74,26 @@ """ return sum(1 for i in self) - def add(self, content, obj_id, check_presence=True): + def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> ObjId: if check_presence and obj_id in self: return obj_id def compressor(data): comp = compressors[self.compression]() - for chunk in data: - yield comp.compress(chunk) + yield comp.compress(data) yield comp.flush() - if isinstance(content, bytes): - content = [content] + assert isinstance( + content, bytes + ), "list of content chunks is not supported anymore" - # XXX should handle streaming correctly... self.wf.put(io.BytesIO(b"".join(compressor(content))), self._path(obj_id)) return obj_id - def restore(self, content, obj_id): + def restore(self, content: bytes, obj_id: ObjId): return self.add(content, obj_id, check_presence=False) - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: try: obj = self.wf.get(self._path(obj_id)) except Exception: @@ -105,21 +106,25 @@ raise Error("Corrupt object %s: trailing data found" % hex_obj_id) return ret - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: # Check the content integrity obj_content = self.get(obj_id) content_obj_id = compute_hash(obj_content) if content_obj_id != obj_id: raise Error(obj_id) - def delete(self, obj_id): + def delete(self, obj_id: ObjId): super().delete(obj_id) # Check delete permission if obj_id not in self: raise ObjNotFoundError(obj_id) self.wf.delete(self._path(obj_id)) return True - def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT): + def list_content( + self, + last_obj_id: Optional[ObjId] = None, + limit: int = DEFAULT_LIMIT, + ) -> Iterator[ObjId]: if last_obj_id: objid = hashutil.hash_to_hex(last_obj_id) lastfilename = objid diff --git a/swh/objstorage/backends/winery/objstorage.py b/swh/objstorage/backends/winery/objstorage.py --- a/swh/objstorage/backends/winery/objstorage.py +++ b/swh/objstorage/backends/winery/objstorage.py @@ -7,6 +7,7 @@ from multiprocessing import Process from swh.objstorage import exc +from swh.objstorage.interface import ObjId from swh.objstorage.objstorage import ObjStorage from .roshard import ROShard @@ -28,7 +29,7 @@ def uninit(self): self.winery.uninit() - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: return self.winery.get(obj_id) def check_config(self, *, check_write): @@ -37,13 +38,13 @@ def __contains__(self, obj_id): return obj_id in self.winery - def add(self, content, obj_id, check_presence=True): + def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> ObjId: return self.winery.add(content, obj_id, check_presence) - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: return self.winery.check(obj_id) - def delete(self, obj_id): + def delete(self, obj_id: ObjId): raise PermissionError("Delete is not allowed.") @@ -74,7 +75,7 @@ self.shards[name] = shard return self.shards[name] - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: shard_info = self.base.get(obj_id) if shard_info is None: raise exc.ObjNotFoundError(obj_id) @@ -140,7 +141,7 @@ self.shard.uninit() super().uninit() - def add(self, content, obj_id, check_presence=True): + def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> ObjId: if check_presence and obj_id in self: return obj_id @@ -157,7 +158,7 @@ return obj_id - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: # load all shards packing == True and not locked (i.e. packer # was interrupted for whatever reason) run pack for each of them pass diff --git a/swh/objstorage/constants.py b/swh/objstorage/constants.py new file mode 100644 --- /dev/null +++ b/swh/objstorage/constants.py @@ -0,0 +1,17 @@ +# Copyright (C) 2015-2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from typing_extensions import Literal + +ID_HASH_ALGO: Literal["sha1"] = "sha1" + +ID_HEXDIGEST_LENGTH = 40 +"""Size in bytes of the hash hexadecimal representation.""" + +ID_DIGEST_LENGTH = 20 +"""Size in bytes of the hash""" + +DEFAULT_LIMIT = 10000 +"""Default number of results of ``list_content``.""" diff --git a/swh/objstorage/factory.py b/swh/objstorage/factory.py --- a/swh/objstorage/factory.py +++ b/swh/objstorage/factory.py @@ -15,7 +15,7 @@ from swh.objstorage.backends.seaweedfs import SeaweedFilerObjStorage from swh.objstorage.multiplexer import MultiplexerObjStorage, StripingObjStorage from swh.objstorage.multiplexer.filter import add_filters -from swh.objstorage.objstorage import ID_HEXDIGEST_LENGTH, ObjStorage # noqa +from swh.objstorage.objstorage import ObjStorage __all__ = ["get_objstorage", "ObjStorage"] diff --git a/swh/objstorage/interface.py b/swh/objstorage/interface.py --- a/swh/objstorage/interface.py +++ b/swh/objstorage/interface.py @@ -3,12 +3,15 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Dict +from typing import Dict, Iterable, Iterator, List, Optional from typing_extensions import Protocol, runtime_checkable from swh.core.api import remote_api_endpoint -from swh.objstorage.objstorage import DEFAULT_LIMIT +from swh.objstorage.constants import DEFAULT_LIMIT + +ObjId = bytes +"""Type of object ids, which should be a sha1 hash.""" @runtime_checkable @@ -48,11 +51,11 @@ ... @remote_api_endpoint("content/contains") - def __contains__(self, obj_id): + def __contains__(self, obj_id: ObjId) -> bool: """Indicate if the given object is present in the storage. Args: - obj_id (bytes): object identifier. + obj_id: object identifier. Returns: True if and only if the object is present in the current object @@ -62,12 +65,12 @@ ... @remote_api_endpoint("content/add") - def add(self, content, obj_id, check_presence=True): + def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> ObjId: """Add a new object to the object storage. Args: - content (bytes): object's raw content to add in storage. - obj_id (bytes): checksum of [bytes] using [ID_HASH_ALGO] + content: object's raw content to add in storage. + obj_id: checksum of [bytes] using [ID_HASH_ALGO] algorithm. It is trusted to match the bytes. check_presence (bool): indicate if the presence of the content should be verified before adding the file. @@ -92,7 +95,7 @@ """ ... - def restore(self, content, obj_id): + def restore(self, content: bytes, obj_id: ObjId): """Restore a content that have been corrupted. This function is identical to add but does not check if @@ -101,21 +104,17 @@ suitable for most cases. Args: - content (bytes): object's raw content to add in storage - obj_id (bytes): checksum of `bytes` as computed by - ID_HASH_ALGO. When given, obj_id will be trusted to - match bytes. If missing, obj_id will be computed on - the fly. - + content: object's raw content to add in storage + obj_id: dict of hashes of the content (or only the sha1, for legacy clients) """ ... @remote_api_endpoint("content/get") - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: """Retrieve the content of a given object. Args: - obj_id (bytes): object id. + obj_id: object id. Returns: the content of the requested object as bytes. @@ -127,7 +126,7 @@ ... @remote_api_endpoint("content/get/batch") - def get_batch(self, obj_ids): + def get_batch(self, obj_ids: List[ObjId]) -> Iterator[Optional[bytes]]: """Retrieve objects' raw content in bulk from storage. Note: This function does have a default implementation in @@ -138,7 +137,7 @@ can be overridden to perform a more efficient operation. Args: - obj_ids ([bytes]: list of object ids. + obj_ids: list of object ids. Returns: list of resulting contents, or None if the content could @@ -149,14 +148,14 @@ ... @remote_api_endpoint("content/check") - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: """Perform an integrity check for a given object. Verify that the file object is in place and that the content matches the object id. Args: - obj_id (bytes): object identifier. + obj_id: object identifier. Raises: ObjNotFoundError: if the requested object is missing. @@ -166,11 +165,11 @@ ... @remote_api_endpoint("content/delete") - def delete(self, obj_id): + def delete(self, obj_id: ObjId): """Delete an object. Args: - obj_id (bytes): object identifier. + obj_id: object identifier. Raises: ObjNotFoundError: if the requested object is missing. @@ -181,34 +180,35 @@ # Management methods @remote_api_endpoint("content/get/random") - def get_random(self, batch_size): + def get_random(self, batch_size: int) -> Iterable[ObjId]: """Get random ids of existing contents. This method is used in order to get random ids to perform content integrity verifications on random contents. Args: - batch_size (int): Number of ids that will be given + batch_size: Number of ids that will be given Yields: - An iterable of ids (bytes) of contents that are in the - current object storage. + ids of contents that are in the current object storage. """ ... - def __iter__(self): + def __iter__(self) -> Iterator[ObjId]: ... - def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT): + def list_content( + self, last_obj_id: Optional[ObjId] = None, limit: int = DEFAULT_LIMIT + ) -> Iterator[ObjId]: """Generates known object ids. Args: - last_obj_id (bytes): object id from which to iterate from + last_obj_id: object id from which to iterate from (excluded). limit (int): max number of object ids to generate. Generates: - obj_id (bytes): object ids. + obj_id: object ids. """ ... diff --git a/swh/objstorage/multiplexer/multiplexer_objstorage.py b/swh/objstorage/multiplexer/multiplexer_objstorage.py --- a/swh/objstorage/multiplexer/multiplexer_objstorage.py +++ b/swh/objstorage/multiplexer/multiplexer_objstorage.py @@ -6,9 +6,10 @@ import queue import random import threading -from typing import Dict +from typing import Dict, Iterable from swh.objstorage.exc import ObjNotFoundError +from swh.objstorage.interface import ObjId from swh.objstorage.objstorage import ObjStorage @@ -222,7 +223,7 @@ return obj_iterator() - def add(self, content, obj_id, check_presence=True): + def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> ObjId: """Add a new object to the object storage. If the adding step works in all the storages that accept this content, @@ -255,6 +256,8 @@ continue return result + assert False, "No backend objstorage configured" + def add_batch(self, contents, check_presence=True) -> Dict: """Add a batch of new objects to the object storage.""" write_threads = list(self.get_write_threads()) @@ -275,7 +278,7 @@ "object:add:bytes": summed["object:add:bytes"] // len(results), } - def restore(self, content, obj_id): + def restore(self, content: bytes, obj_id: ObjId): return self.wrap_call( self.get_write_threads(obj_id), "restore", @@ -283,7 +286,7 @@ obj_id=obj_id, ).pop() - def get(self, obj_id): + def get(self, obj_id: ObjId) -> bytes: for storage in self.get_read_threads(obj_id): try: return storage.get(obj_id) @@ -292,7 +295,7 @@ # If no storage contains this content, raise the error raise ObjNotFoundError(obj_id) - def check(self, obj_id): + def check(self, obj_id: ObjId) -> None: nb_present = 0 for storage in self.get_read_threads(obj_id): try: @@ -308,11 +311,11 @@ if nb_present == 0: raise ObjNotFoundError(obj_id) - def delete(self, obj_id): + def delete(self, obj_id: ObjId): super().delete(obj_id) # Check delete permission return all(self.wrap_call(self.get_write_threads(obj_id), "delete", obj_id)) - def get_random(self, batch_size): + def get_random(self, batch_size: int) -> Iterable[ObjId]: storages_set = [storage for storage in self.storages if len(storage) > 0] if len(storages_set) <= 0: return [] diff --git a/swh/objstorage/objstorage.py b/swh/objstorage/objstorage.py --- a/swh/objstorage/objstorage.py +++ b/swh/objstorage/objstorage.py @@ -7,23 +7,14 @@ import bz2 from itertools import dropwhile, islice import lzma -from typing import Dict +from typing import Callable, Dict, Iterable, Iterator, List, Optional import zlib from swh.model import hashutil +from .constants import DEFAULT_LIMIT, ID_HASH_ALGO from .exc import ObjNotFoundError - -ID_HASH_ALGO = "sha1" - -ID_HEXDIGEST_LENGTH = 40 -"""Size in bytes of the hash hexadecimal representation.""" - -ID_DIGEST_LENGTH = 20 -"""Size in bytes of the hash""" - -DEFAULT_LIMIT = 10000 -"""Default number of results of ``list_content``.""" +from .interface import ObjId, ObjStorageInterface def compute_hash(content, algo=ID_HASH_ALGO): @@ -56,28 +47,43 @@ class NullDecompressor: - def decompress(self, data): + def decompress(self, data: bytes) -> bytes: return data @property - def unused_data(self): + def unused_data(self) -> bytes: return b"" -decompressors = { - "bz2": bz2.BZ2Decompressor, - "lzma": lzma.LZMADecompressor, - "gzip": lambda: zlib.decompressobj(wbits=31), - "zlib": zlib.decompressobj, - "none": NullDecompressor, +class _CompressorProtocol: + def compress(self, data: bytes) -> bytes: + ... + + def flush(self) -> bytes: + ... + + +class _DecompressorProtocol: + def decompress(self, data: bytes) -> bytes: + ... + + unused_data: bytes + + +decompressors: Dict[str, Callable[[], _DecompressorProtocol]] = { + "bz2": bz2.BZ2Decompressor, # type: ignore + "lzma": lzma.LZMADecompressor, # type: ignore + "gzip": lambda: zlib.decompressobj(wbits=31), # type: ignore + "zlib": zlib.decompressobj, # type: ignore + "none": NullDecompressor, # type: ignore } -compressors = { - "bz2": bz2.BZ2Compressor, - "lzma": lzma.LZMACompressor, - "gzip": lambda: zlib.compressobj(wbits=31), - "zlib": zlib.compressobj, - "none": NullCompressor, +compressors: Dict[str, Callable[[], _CompressorProtocol]] = { + "bz2": bz2.BZ2Compressor, # type: ignore + "lzma": lzma.LZMACompressor, # type: ignore + "gzip": lambda: zlib.compressobj(wbits=31), # type: ignore + "zlib": zlib.compressobj, # type: ignore + "none": NullCompressor, # type: ignore } @@ -87,19 +93,7 @@ # it becomes needed self.allow_delete = allow_delete - @abc.abstractmethod - def check_config(self, *, check_write): - pass - - @abc.abstractmethod - def __contains__(self, obj_id): - pass - - @abc.abstractmethod - def add(self, content, obj_id, check_presence=True): - pass - - def add_batch(self, contents, check_presence=True) -> Dict: + def add_batch(self: ObjStorageInterface, contents, check_presence=True) -> Dict: summary = {"object:add": 0, "object:add:bytes": 0} for obj_id, content in contents.items(): if check_presence and obj_id in self: @@ -109,15 +103,13 @@ summary["object:add:bytes"] += len(content) return summary - def restore(self, content, obj_id): + def restore(self: ObjStorageInterface, content: bytes, obj_id: ObjId): # check_presence to false will erase the potential previous content. return self.add(content, obj_id, check_presence=False) - @abc.abstractmethod - def get(self, obj_id): - pass - - def get_batch(self, obj_ids): + def get_batch( + self: ObjStorageInterface, obj_ids: List[ObjId] + ) -> Iterator[Optional[bytes]]: for obj_id in obj_ids: try: yield self.get(obj_id) @@ -125,23 +117,19 @@ yield None @abc.abstractmethod - def check(self, obj_id): - pass - - @abc.abstractmethod - def delete(self, obj_id): + def delete(self, obj_id: ObjId): if not self.allow_delete: raise PermissionError("Delete is not allowed.") - # Management methods - - def get_random(self, batch_size): + def get_random(self, batch_size: int) -> Iterable[ObjId]: pass - # Streaming methods - - def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT): + def list_content( + self: ObjStorageInterface, + last_obj_id: Optional[ObjId] = None, + limit: int = DEFAULT_LIMIT, + ) -> Iterator[ObjId]: it = iter(self) - if last_obj_id: - it = dropwhile(lambda x: x <= last_obj_id, it) + if last_obj_id is not None: + it = dropwhile(last_obj_id.__ge__, it) return islice(it, limit) diff --git a/swh/objstorage/tests/test_objstorage_pathslicing.py b/swh/objstorage/tests/test_objstorage_pathslicing.py --- a/swh/objstorage/tests/test_objstorage_pathslicing.py +++ b/swh/objstorage/tests/test_objstorage_pathslicing.py @@ -10,8 +10,8 @@ from swh.model import hashutil from swh.objstorage import exc +from swh.objstorage.constants import ID_DIGEST_LENGTH from swh.objstorage.factory import get_objstorage -from swh.objstorage.objstorage import ID_DIGEST_LENGTH from .objstorage_testing import ObjStorageTestFixture