diff --git a/swh/objstorage/api/client.py b/swh/objstorage/api/client.py --- a/swh/objstorage/api/client.py +++ b/swh/objstorage/api/client.py @@ -3,15 +3,15 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Iterator, Optional +from typing import Any, Dict, Iterator, Optional import msgpack from swh.core.api import RPCClient -from swh.model import hashutil from swh.objstorage.constants import DEFAULT_LIMIT from swh.objstorage.exc import Error, ObjNotFoundError, ObjStorageAPIError from swh.objstorage.interface import CompositeObjId, ObjId, ObjStorageInterface +from swh.objstorage.objstorage import objid_to_default_hex class RemoteObjStorage(RPCClient): @@ -42,9 +42,9 @@ last_obj_id: Optional[ObjId] = None, limit: int = DEFAULT_LIMIT, ) -> Iterator[CompositeObjId]: - params = {"limit": limit} + params: Dict[str, Any] = {"limit": limit} if last_obj_id: - params["last_obj_id"] = hashutil.hash_to_hex(last_obj_id) + params["last_obj_id"] = objid_to_default_hex(last_obj_id) response = self.raw_verb( "get", "content", @@ -52,4 +52,4 @@ params=params, stream=True, ) - yield from msgpack.Unpacker(response.raw, raw=True) + yield from msgpack.Unpacker(response.raw, raw=False) diff --git a/swh/objstorage/backends/azure.py b/swh/objstorage/backends/azure.py --- a/swh/objstorage/backends/azure.py +++ b/swh/objstorage/backends/azure.py @@ -18,6 +18,7 @@ generate_container_sas, ) from azure.storage.blob.aio import ContainerClient as AsyncContainerClient +from typing_extensions import Literal from swh.model import hashutil from swh.objstorage.exc import Error, ObjNotFoundError @@ -97,6 +98,8 @@ ``api_secret_key`` and ``container_name`` arguments are deprecated. """ + PRIMARY_HASH: Literal["sha1"] = "sha1" + def __init__( self, container_url: Optional[str] = None, @@ -195,7 +198,7 @@ """Iterate over the objects present in the storage.""" for client in self.get_all_container_clients(): for obj in client.list_blobs(): - yield hashutil.hash_to_bytes(obj.name) + yield {self.PRIMARY_HASH: hashutil.hash_to_bytes(obj.name)} def __len__(self): """Compute the number of objects in the current object storage. diff --git a/swh/objstorage/backends/http.py b/swh/objstorage/backends/http.py --- a/swh/objstorage/backends/http.py +++ b/swh/objstorage/backends/http.py @@ -17,6 +17,7 @@ ObjStorage, compute_hash, decompressors, + objid_to_default_hex, ) LOGGER = logging.getLogger(__name__) @@ -83,7 +84,7 @@ d = decompressors[self.compression]() ret = d.decompress(ret) if d.unused_data: - hex_obj_id = hashutil.hash_to_hex(obj_id) + hex_obj_id = objid_to_default_hex(obj_id) raise exc.Error("Corrupt object %s: trailing data found" % hex_obj_id) return ret diff --git a/swh/objstorage/backends/in_memory.py b/swh/objstorage/backends/in_memory.py --- a/swh/objstorage/backends/in_memory.py +++ b/swh/objstorage/backends/in_memory.py @@ -5,6 +5,8 @@ from typing import Iterator +from typing_extensions import Literal + from swh.objstorage.exc import Error, ObjNotFoundError from swh.objstorage.interface import CompositeObjId, ObjId from swh.objstorage.objstorage import ObjStorage, compute_hash, objid_to_default_hex @@ -17,6 +19,8 @@ """ + PRIMARY_HASH: Literal["sha1"] = "sha1" + def __init__(self, **args): super().__init__() self.state = {} @@ -28,7 +32,8 @@ return obj_id in self.state def __iter__(self) -> Iterator[CompositeObjId]: - return iter(sorted(self.state)) + for id_ in sorted(self.state): + yield {self.PRIMARY_HASH: id_} def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> None: if check_presence and obj_id in self: diff --git a/swh/objstorage/backends/libcloud.py b/swh/objstorage/backends/libcloud.py --- a/swh/objstorage/backends/libcloud.py +++ b/swh/objstorage/backends/libcloud.py @@ -1,4 +1,4 @@ -# Copyright (C) 2016-2017 The Software Heritage developers +# Copyright (C) 2016-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -11,6 +11,7 @@ from libcloud.storage import providers import libcloud.storage.drivers.s3 from libcloud.storage.types import ObjectDoesNotExistError, Provider +from typing_extensions import Literal from swh.model import hashutil from swh.objstorage.exc import Error, ObjNotFoundError @@ -20,6 +21,7 @@ compressors, compute_hash, decompressors, + objid_to_default_hex, ) @@ -58,6 +60,8 @@ kwargs: extra arguments are passed through to the LibCloud driver """ + PRIMARY_HASH: Literal["sha1"] = "sha1" + def __init__( self, container_name: str, @@ -142,7 +146,7 @@ if self.path_prefix: name = name[len(self.path_prefix) :] - yield hashutil.hash_to_bytes(name) + yield {self.PRIMARY_HASH: hashutil.hash_to_bytes(name)} def __len__(self): """Compute the number of objects in the current object storage. @@ -170,7 +174,7 @@ d = decompressors[self.compression]() ret = d.decompress(obj) if d.unused_data: - hex_obj_id = hashutil.hash_to_hex(obj_id) + hex_obj_id = objid_to_default_hex(obj_id) raise Error("Corrupt object %s: trailing data found" % hex_obj_id) return ret diff --git a/swh/objstorage/backends/pathslicing.py b/swh/objstorage/backends/pathslicing.py --- a/swh/objstorage/backends/pathslicing.py +++ b/swh/objstorage/backends/pathslicing.py @@ -9,6 +9,8 @@ import tempfile from typing import Iterator, List, Optional +from typing_extensions import Literal + from swh.model import hashutil from swh.objstorage.constants import DEFAULT_LIMIT, ID_HASH_ALGO, ID_HEXDIGEST_LENGTH from swh.objstorage.exc import Error, ObjNotFoundError @@ -152,6 +154,8 @@ """ + PRIMARY_HASH: Literal["sha1"] = "sha1" + def __init__(self, root, slicing, compression="gzip", **kwargs): super().__init__(**kwargs) self.root = root @@ -205,15 +209,12 @@ """ - def obj_iterator(): - # XXX hackish: it does not verify that the depth of found files - # matches the slicing depth of the storage - for root, _dirs, files in os.walk(self.root): - _dirs.sort() - for f in sorted(files): - yield bytes.fromhex(f) - - return obj_iterator() + # XXX hackish: it does not verify that the depth of found files + # matches the slicing depth of the storage + for root, _dirs, files in os.walk(self.root): + _dirs.sort() + for f in sorted(files): + yield {self.PRIMARY_HASH: bytes.fromhex(f)} def __len__(self) -> int: """Compute the number of objects available in the storage. @@ -329,7 +330,7 @@ dirs.remove(d) for f in sorted(files): if f > hex_obj_id: - yield bytes.fromhex(f) + yield {self.PRIMARY_HASH: bytes.fromhex(f)} if n_leaf: yield i diff --git a/swh/objstorage/backends/seaweedfs/objstorage.py b/swh/objstorage/backends/seaweedfs/objstorage.py --- a/swh/objstorage/backends/seaweedfs/objstorage.py +++ b/swh/objstorage/backends/seaweedfs/objstorage.py @@ -9,6 +9,8 @@ import os from typing import Iterator, Optional +from typing_extensions import Literal + from swh.model import hashutil from swh.objstorage.exc import Error, ObjNotFoundError from swh.objstorage.interface import CompositeObjId, ObjId @@ -32,6 +34,8 @@ https://github.com/chrislusf/seaweedfs/wiki/Filer-Server-API """ + PRIMARY_HASH: Literal["sha1"] = "sha1" + def __init__(self, url, compression=None, **kwargs): super().__init__(**kwargs) self.wf = HttpFiler(url) @@ -102,7 +106,7 @@ d = decompressors[self.compression]() ret = d.decompress(obj) if d.unused_data: - hex_obj_id = hashutil.hash_to_hex(obj_id) + hex_obj_id = objid_to_default_hex(obj_id) raise Error("Corrupt object %s: trailing data found" % hex_obj_id) return ret @@ -132,7 +136,7 @@ lastfilename = None for fname in islice(self.wf.iterfiles(last_file_name=lastfilename), limit): bytehex = fname.rsplit("/", 1)[-1] - yield hashutil.bytehex_to_hash(bytehex.encode()) + yield {self.PRIMARY_HASH: hashutil.bytehex_to_hash(bytehex.encode())} # internal methods def _put_object(self, content, obj_id): @@ -153,5 +157,5 @@ content = [content] self.wf.put(io.BytesIO(b"".join(compressor(content))), self._path(obj_id)) - def _path(self, obj_id): - return os.path.join(self.wf.basepath, hashutil.hash_to_hex(obj_id)) + def _path(self, obj_id: ObjId): + return os.path.join(self.wf.basepath, objid_to_default_hex(obj_id)) diff --git a/swh/objstorage/tests/objstorage_testing.py b/swh/objstorage/tests/objstorage_testing.py --- a/swh/objstorage/tests/objstorage_testing.py +++ b/swh/objstorage/tests/objstorage_testing.py @@ -181,14 +181,14 @@ sto_obj_ids = list(sto_obj_ids) self.assertFalse(sto_obj_ids) - obj_ids = set() + obj_ids = [] for i in range(100): content, obj_id = self.hash_content(b"content %d" % i) self.storage.add(content, obj_id=obj_id) - obj_ids.add(obj_id) + obj_ids.append({"sha1": obj_id}) - sto_obj_ids = set(self.storage) - self.assertEqual(sto_obj_ids, obj_ids) + sto_obj_ids = list(self.storage) + self.assertCountEqual(sto_obj_ids, obj_ids) def test_list_content(self): all_ids = [] @@ -196,8 +196,8 @@ content = b"example %d" % i obj_id = compute_hash(content) self.storage.add(content, obj_id) - all_ids.append(obj_id) - all_ids.sort() + all_ids.append({"sha1": obj_id}) + all_ids.sort(key=lambda d: d["sha1"]) ids = list(self.storage.list_content()) self.assertEqual(len(ids), 1200) diff --git a/swh/objstorage/tests/test_objstorage_pathslicing.py b/swh/objstorage/tests/test_objstorage_pathslicing.py --- a/swh/objstorage/tests/test_objstorage_pathslicing.py +++ b/swh/objstorage/tests/test_objstorage_pathslicing.py @@ -42,7 +42,9 @@ content, obj_id = self.hash_content(b"iter") self.assertEqual(list(iter(self.storage)), []) self.storage.add(content, obj_id=obj_id) - self.assertEqual(list(iter(self.storage)), [obj_id]) + self.assertEqual( + list(iter(self.storage)), [{self.storage.PRIMARY_HASH: obj_id}] + ) def test_len(self): content, obj_id = self.hash_content(b"len") @@ -74,8 +76,8 @@ for i in range(100): content, obj_id = self.hash_content(b"content %d" % i) self.storage.add(content, obj_id=obj_id) - all_ids.append(obj_id) - all_ids.sort() + all_ids.append({self.storage.PRIMARY_HASH: obj_id}) + all_ids.sort(key=lambda d: d[self.storage.PRIMARY_HASH]) ids = list(self.storage.iter_from(b"\x00" * ID_DIGEST_LENGTH)) self.assertEqual(len(ids), len(all_ids)) diff --git a/swh/objstorage/tests/test_readonly_filter.py b/swh/objstorage/tests/test_readonly_filter.py --- a/swh/objstorage/tests/test_readonly_filter.py +++ b/swh/objstorage/tests/test_readonly_filter.py @@ -57,8 +57,8 @@ self.assertFalse(self.absent_id in self.storage) def test_can_iter(self): - self.assertIn(self.valid_id, iter(self.storage)) - self.assertIn(self.invalid_id, iter(self.storage)) + self.assertIn({"sha1": self.valid_id}, iter(self.storage)) + self.assertIn({"sha1": self.invalid_id}, iter(self.storage)) def test_can_len(self): self.assertEqual(2, len(self.storage))