diff --git a/swh/objstorage/__init__.py b/swh/objstorage/__init__.py index 498e16e..912edc9 100644 --- a/swh/objstorage/__init__.py +++ b/swh/objstorage/__init__.py @@ -1,107 +1,12 @@ -# Copyright (C) 2016 The Software Heritage developers +# Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from swh.objstorage.objstorage import ObjStorage, ID_HASH_LENGTH # noqa -from swh.objstorage.backends.pathslicing import PathSlicingObjStorage -from swh.objstorage.backends.in_memory import InMemoryObjStorage -from swh.objstorage.api.client import RemoteObjStorage -from swh.objstorage.multiplexer import MultiplexerObjStorage, StripingObjStorage -from swh.objstorage.multiplexer.filter import add_filters -from swh.objstorage.backends.seaweed import WeedObjStorage -from swh.objstorage.backends.generator import RandomGeneratorObjStorage +from typing import Iterable +from pkgutil import extend_path -from typing import Callable, Dict, Union +__path__: Iterable[str] = extend_path(__path__, __name__) -__all__ = ["get_objstorage", "ObjStorage"] - - -_STORAGE_CLASSES: Dict[str, Union[type, Callable[..., type]]] = { - "pathslicing": PathSlicingObjStorage, - "remote": RemoteObjStorage, - "memory": InMemoryObjStorage, - "weed": WeedObjStorage, - "random": RandomGeneratorObjStorage, -} - -_STORAGE_CLASSES_MISSING = {} - -try: - from swh.objstorage.backends.azure import ( - AzureCloudObjStorage, - PrefixedAzureCloudObjStorage, - ) - - _STORAGE_CLASSES["azure"] = AzureCloudObjStorage - _STORAGE_CLASSES["azure-prefixed"] = PrefixedAzureCloudObjStorage -except ImportError as e: - _STORAGE_CLASSES_MISSING["azure"] = e.args[0] - _STORAGE_CLASSES_MISSING["azure-prefixed"] = e.args[0] - -try: - from swh.objstorage.backends.rados import RADOSObjStorage - - _STORAGE_CLASSES["rados"] = RADOSObjStorage -except ImportError as e: - _STORAGE_CLASSES_MISSING["rados"] = e.args[0] - -try: - from swh.objstorage.backends.libcloud import ( - AwsCloudObjStorage, - OpenStackCloudObjStorage, - ) - - _STORAGE_CLASSES["s3"] = AwsCloudObjStorage - _STORAGE_CLASSES["swift"] = OpenStackCloudObjStorage -except ImportError as e: - _STORAGE_CLASSES_MISSING["s3"] = e.args[0] - _STORAGE_CLASSES_MISSING["swift"] = e.args[0] - - -def get_objstorage(cls, args): - """ Create an ObjStorage using the given implementation class. - - Args: - cls (str): objstorage class unique key contained in the - _STORAGE_CLASSES dict. - args (dict): arguments for the required class of objstorage - that must match exactly the one in the `__init__` method of the - class. - Returns: - subclass of ObjStorage that match the given `storage_class` argument. - Raises: - ValueError: if the given storage class is not a valid objstorage - key. - """ - if cls in _STORAGE_CLASSES: - return _STORAGE_CLASSES[cls](**args) - else: - raise ValueError( - "Storage class {} is not available: {}".format( - cls, _STORAGE_CLASSES_MISSING.get(cls, "unknown name") - ) - ) - - -def _construct_filtered_objstorage(storage_conf, filters_conf): - return add_filters(get_objstorage(**storage_conf), filters_conf) - - -_STORAGE_CLASSES["filtered"] = _construct_filtered_objstorage - - -def _construct_multiplexer_objstorage(objstorages): - storages = [get_objstorage(**conf) for conf in objstorages] - return MultiplexerObjStorage(storages) - - -_STORAGE_CLASSES["multiplexer"] = _construct_multiplexer_objstorage - - -def _construct_striping_objstorage(objstorages): - storages = [get_objstorage(**conf) for conf in objstorages] - return StripingObjStorage(storages) - - -_STORAGE_CLASSES["striping"] = _construct_striping_objstorage +# for BW compat +from swh.objstorage.factory import * # noqa diff --git a/swh/objstorage/api/server.py b/swh/objstorage/api/server.py index 517b004..30cffe6 100644 --- a/swh/objstorage/api/server.py +++ b/swh/objstorage/api/server.py @@ -1,270 +1,270 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import aiohttp.web import json from swh.core.config import read as config_read from swh.core.api.asynchronous import ( RPCServerApp, decode_request, encode_data_server as encode_data, ) from swh.core.api.serializers import msgpack_loads, SWHJSONDecoder from swh.model import hashutil -from swh.objstorage import get_objstorage +from swh.objstorage.factory import get_objstorage from swh.objstorage.objstorage import DEFAULT_LIMIT from swh.objstorage.exc import Error, ObjNotFoundError from swh.core.statsd import statsd def timed(f): async def w(*a, **kw): with statsd.timed( "swh_objstorage_request_duration_seconds", tags={"endpoint": f.__name__} ): return await f(*a, **kw) return w @timed async def index(request): return aiohttp.web.Response(body="SWH Objstorage API server") @timed async def check_config(request): req = await decode_request(request) return encode_data(request.app["objstorage"].check_config(**req)) @timed async def contains(request): req = await decode_request(request) return encode_data(request.app["objstorage"].__contains__(**req)) @timed async def add_bytes(request): req = await decode_request(request) statsd.increment( "swh_objstorage_in_bytes_total", len(req["content"]), tags={"endpoint": "add_bytes"}, ) return encode_data(request.app["objstorage"].add(**req)) @timed async def add_batch(request): req = await decode_request(request) return encode_data(request.app["objstorage"].add_batch(**req)) @timed async def get_bytes(request): req = await decode_request(request) ret = request.app["objstorage"].get(**req) statsd.increment( "swh_objstorage_out_bytes_total", len(ret), tags={"endpoint": "get_bytes"} ) return encode_data(ret) @timed async def get_batch(request): req = await decode_request(request) return encode_data(request.app["objstorage"].get_batch(**req)) @timed async def check(request): req = await decode_request(request) return encode_data(request.app["objstorage"].check(**req)) @timed async def delete(request): req = await decode_request(request) return encode_data(request.app["objstorage"].delete(**req)) # Management methods @timed async def get_random_contents(request): req = await decode_request(request) return encode_data(request.app["objstorage"].get_random(**req)) # Streaming methods @timed async def add_stream(request): hex_id = request.match_info["hex_id"] obj_id = hashutil.hash_to_bytes(hex_id) check_pres = request.query.get("check_presence", "").lower() == "true" objstorage = request.app["objstorage"] if check_pres and obj_id in objstorage: return encode_data(obj_id) # XXX this really should go in a decode_stream_request coroutine in # swh.core, but since py35 does not support async generators, it cannot # easily be made for now content_type = request.headers.get("Content-Type") if content_type == "application/x-msgpack": decode = msgpack_loads elif content_type == "application/json": decode = lambda x: json.loads(x, cls=SWHJSONDecoder) # noqa else: raise ValueError("Wrong content type `%s` for API request" % content_type) buffer = b"" with objstorage.chunk_writer(obj_id) as write: while not request.content.at_eof(): data, eot = await request.content.readchunk() buffer += data if eot: write(decode(buffer)) buffer = b"" return encode_data(obj_id) @timed async def get_stream(request): hex_id = request.match_info["hex_id"] obj_id = hashutil.hash_to_bytes(hex_id) response = aiohttp.web.StreamResponse() await response.prepare(request) for chunk in request.app["objstorage"].get_stream(obj_id, 2 << 20): await response.write(chunk) await response.write_eof() return response @timed async def list_content(request): last_obj_id = request.query.get("last_obj_id") if last_obj_id: last_obj_id = bytes.fromhex(last_obj_id) limit = int(request.query.get("limit", DEFAULT_LIMIT)) response = aiohttp.web.StreamResponse() response.enable_chunked_encoding() await response.prepare(request) for obj_id in request.app["objstorage"].list_content(last_obj_id, limit=limit): await response.write(obj_id) await response.write_eof() return response def make_app(config): """Initialize the remote api application. """ client_max_size = config.get("client_max_size", 1024 * 1024 * 1024) app = RPCServerApp(client_max_size=client_max_size) app.client_exception_classes = (ObjNotFoundError, Error) # retro compatibility configuration settings app["config"] = config _cfg = config["objstorage"] app["objstorage"] = get_objstorage(_cfg["cls"], _cfg["args"]) app.router.add_route("GET", "/", index) app.router.add_route("POST", "/check_config", check_config) app.router.add_route("POST", "/content/contains", contains) app.router.add_route("POST", "/content/add", add_bytes) app.router.add_route("POST", "/content/add/batch", add_batch) app.router.add_route("POST", "/content/get", get_bytes) app.router.add_route("POST", "/content/get/batch", get_batch) app.router.add_route("POST", "/content/get/random", get_random_contents) app.router.add_route("POST", "/content/check", check) app.router.add_route("POST", "/content/delete", delete) app.router.add_route("GET", "/content", list_content) app.router.add_route("POST", "/content/add_stream/{hex_id}", add_stream) app.router.add_route("GET", "/content/get_stream/{hex_id}", get_stream) return app def load_and_check_config(config_file): """Check the minimal configuration is set to run the api or raise an error explanation. Args: config_file (str): Path to the configuration file to load type (str): configuration type. For 'local' type, more checks are done. Raises: Error if the setup is not as expected Returns: configuration as a dict """ if not config_file: raise EnvironmentError("Configuration file must be defined") if not os.path.exists(config_file): raise FileNotFoundError("Configuration file %s does not exist" % (config_file,)) cfg = config_read(config_file) if "objstorage" not in cfg: raise KeyError("Invalid configuration; missing objstorage config entry") missing_keys = [] vcfg = cfg["objstorage"] for key in ("cls", "args"): v = vcfg.get(key) if v is None: missing_keys.append(key) if missing_keys: raise KeyError( "Invalid configuration; missing %s config entry" % (", ".join(missing_keys),) ) cls = vcfg.get("cls") if cls == "pathslicing": args = vcfg["args"] for key in ("root", "slicing"): v = args.get(key) if v is None: missing_keys.append(key) if missing_keys: raise KeyError( "Invalid configuration; missing args.%s config entry" % (", ".join(missing_keys),) ) return cfg def make_app_from_configfile(): """Load configuration and then build application to run """ config_file = os.environ.get("SWH_CONFIG_FILENAME") config = load_and_check_config(config_file) return make_app(config=config) if __name__ == "__main__": print("Deprecated. Use swh-objstorage") diff --git a/swh/objstorage/cli.py b/swh/objstorage/cli.py index 3aa3514..dac97ae 100644 --- a/swh/objstorage/cli.py +++ b/swh/objstorage/cli.py @@ -1,107 +1,107 @@ -# Copyright (C) 2015-2019 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import logging import time import click import aiohttp.web from swh.core.cli import CONTEXT_SETTINGS -from swh.objstorage import get_objstorage from swh.objstorage.api.server import load_and_check_config, make_app +from swh.objstorage.factory import get_objstorage @click.group(name="objstorage", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False,), help="Configuration file.", ) @click.pass_context def cli(ctx, config_file): """Software Heritage Objstorage tools. """ ctx.ensure_object(dict) cfg = load_and_check_config(config_file) ctx.obj["config"] = cfg @cli.command("rpc-serve") @click.option( "--host", default="0.0.0.0", metavar="IP", show_default=True, help="Host ip address to bind the server on", ) @click.option( "--port", "-p", default=5003, type=click.INT, metavar="PORT", show_default=True, help="Binding port of the server", ) @click.pass_context def serve(ctx, host, port): """Run a standalone objstorage server. This is not meant to be run on production systems. """ app = make_app(ctx.obj["config"]) if ctx.obj["log_level"] == "DEBUG": app.update(debug=True) aiohttp.web.run_app(app, host=host, port=int(port)) @cli.command("import") @click.argument("directory", required=True, nargs=-1) @click.pass_context def import_directories(ctx, directory): """Import a local directory in an existing objstorage. """ objstorage = get_objstorage(**ctx.obj["config"]["objstorage"]) nobj = 0 volume = 0 t0 = time.time() for dirname in directory: for root, _dirs, files in os.walk(dirname): for name in files: path = os.path.join(root, name) with open(path, "rb") as f: objstorage.add(f.read()) volume += os.stat(path).st_size nobj += 1 click.echo( "Imported %d files for a volume of %s bytes in %d seconds" % (nobj, volume, time.time() - t0) ) @cli.command("fsck") @click.pass_context def fsck(ctx): """Check the objstorage is not corrupted. """ objstorage = get_objstorage(**ctx.obj["config"]["objstorage"]) for obj_id in objstorage: try: objstorage.check(obj_id) except objstorage.Error as err: logging.error(err) def main(): return cli(auto_envvar_prefix="SWH_OBJSTORAGE") if __name__ == "__main__": main() diff --git a/swh/objstorage/__init__.py b/swh/objstorage/factory.py similarity index 98% copy from swh/objstorage/__init__.py copy to swh/objstorage/factory.py index 498e16e..c431b0f 100644 --- a/swh/objstorage/__init__.py +++ b/swh/objstorage/factory.py @@ -1,107 +1,108 @@ -# Copyright (C) 2016 The Software Heritage developers +# Copyright (C) 2016-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Callable, Dict, Union + from swh.objstorage.objstorage import ObjStorage, ID_HASH_LENGTH # noqa from swh.objstorage.backends.pathslicing import PathSlicingObjStorage from swh.objstorage.backends.in_memory import InMemoryObjStorage from swh.objstorage.api.client import RemoteObjStorage from swh.objstorage.multiplexer import MultiplexerObjStorage, StripingObjStorage from swh.objstorage.multiplexer.filter import add_filters from swh.objstorage.backends.seaweed import WeedObjStorage from swh.objstorage.backends.generator import RandomGeneratorObjStorage -from typing import Callable, Dict, Union __all__ = ["get_objstorage", "ObjStorage"] _STORAGE_CLASSES: Dict[str, Union[type, Callable[..., type]]] = { "pathslicing": PathSlicingObjStorage, "remote": RemoteObjStorage, "memory": InMemoryObjStorage, "weed": WeedObjStorage, "random": RandomGeneratorObjStorage, } _STORAGE_CLASSES_MISSING = {} try: from swh.objstorage.backends.azure import ( AzureCloudObjStorage, PrefixedAzureCloudObjStorage, ) _STORAGE_CLASSES["azure"] = AzureCloudObjStorage _STORAGE_CLASSES["azure-prefixed"] = PrefixedAzureCloudObjStorage except ImportError as e: _STORAGE_CLASSES_MISSING["azure"] = e.args[0] _STORAGE_CLASSES_MISSING["azure-prefixed"] = e.args[0] try: from swh.objstorage.backends.rados import RADOSObjStorage _STORAGE_CLASSES["rados"] = RADOSObjStorage except ImportError as e: _STORAGE_CLASSES_MISSING["rados"] = e.args[0] try: from swh.objstorage.backends.libcloud import ( AwsCloudObjStorage, OpenStackCloudObjStorage, ) _STORAGE_CLASSES["s3"] = AwsCloudObjStorage _STORAGE_CLASSES["swift"] = OpenStackCloudObjStorage except ImportError as e: _STORAGE_CLASSES_MISSING["s3"] = e.args[0] _STORAGE_CLASSES_MISSING["swift"] = e.args[0] def get_objstorage(cls, args): """ Create an ObjStorage using the given implementation class. Args: cls (str): objstorage class unique key contained in the _STORAGE_CLASSES dict. args (dict): arguments for the required class of objstorage that must match exactly the one in the `__init__` method of the class. Returns: subclass of ObjStorage that match the given `storage_class` argument. Raises: ValueError: if the given storage class is not a valid objstorage key. """ if cls in _STORAGE_CLASSES: return _STORAGE_CLASSES[cls](**args) else: raise ValueError( "Storage class {} is not available: {}".format( cls, _STORAGE_CLASSES_MISSING.get(cls, "unknown name") ) ) def _construct_filtered_objstorage(storage_conf, filters_conf): return add_filters(get_objstorage(**storage_conf), filters_conf) _STORAGE_CLASSES["filtered"] = _construct_filtered_objstorage def _construct_multiplexer_objstorage(objstorages): storages = [get_objstorage(**conf) for conf in objstorages] return MultiplexerObjStorage(storages) _STORAGE_CLASSES["multiplexer"] = _construct_multiplexer_objstorage def _construct_striping_objstorage(objstorages): storages = [get_objstorage(**conf) for conf in objstorages] return StripingObjStorage(storages) _STORAGE_CLASSES["striping"] = _construct_striping_objstorage diff --git a/swh/objstorage/tests/test_multiplexer_filter.py b/swh/objstorage/tests/test_multiplexer_filter.py index fbeb23a..f66e976 100644 --- a/swh/objstorage/tests/test_multiplexer_filter.py +++ b/swh/objstorage/tests/test_multiplexer_filter.py @@ -1,331 +1,331 @@ -# Copyright (C) 2015-2018 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import random import shutil import tempfile import unittest from string import ascii_lowercase from swh.model import hashutil -from swh.objstorage import get_objstorage +from swh.objstorage.factory import get_objstorage from swh.objstorage.exc import Error, ObjNotFoundError from swh.objstorage.multiplexer.filter import id_prefix, id_regex, read_only from swh.objstorage.objstorage import compute_hash def get_random_content(): return bytes("".join(random.sample(ascii_lowercase, 10)), "utf8") class MixinTestReadFilter(unittest.TestCase): # Read only filter should not allow writing def setUp(self): super().setUp() self.tmpdir = tempfile.mkdtemp() pstorage = { "cls": "pathslicing", "args": {"root": self.tmpdir, "slicing": "0:5"}, } base_storage = get_objstorage(**pstorage) base_storage.id = compute_hash self.storage = get_objstorage( "filtered", {"storage_conf": pstorage, "filters_conf": [read_only()]} ) self.valid_content = b"pre-existing content" self.invalid_content = b"invalid_content" self.true_invalid_content = b"Anything that is not correct" self.absent_content = b"non-existent content" # Create a valid content. self.valid_id = base_storage.add(self.valid_content) # Create an invalid id and add a content with it. self.invalid_id = base_storage.id(self.true_invalid_content) base_storage.add(self.invalid_content, obj_id=self.invalid_id) # Compute an id for a non-existing content. self.absent_id = base_storage.id(self.absent_content) def tearDown(self): super().tearDown() shutil.rmtree(self.tmpdir) def test_can_contains(self): self.assertTrue(self.valid_id in self.storage) self.assertTrue(self.invalid_id in self.storage) self.assertFalse(self.absent_id in self.storage) def test_can_iter(self): self.assertIn(self.valid_id, iter(self.storage)) self.assertIn(self.invalid_id, iter(self.storage)) def test_can_len(self): self.assertEqual(2, len(self.storage)) def test_can_get(self): self.assertEqual(self.valid_content, self.storage.get(self.valid_id)) self.assertEqual(self.invalid_content, self.storage.get(self.invalid_id)) def test_can_check(self): with self.assertRaises(ObjNotFoundError): self.storage.check(self.absent_id) with self.assertRaises(Error): self.storage.check(self.invalid_id) self.storage.check(self.valid_id) def test_can_get_random(self): self.assertEqual(1, len(list(self.storage.get_random(1)))) self.assertEqual( len(list(self.storage)), len(set(self.storage.get_random(1000))) ) def test_cannot_add(self): new_id = self.storage.add(b"New content") result = self.storage.add(self.valid_content, self.valid_id) self.assertIsNone(new_id, self.storage) self.assertIsNone(result) def test_cannot_restore(self): result = self.storage.restore(self.valid_content, self.valid_id) self.assertIsNone(result) class MixinTestIdFilter: """ Mixin class that tests the filters based on filter.IdFilter Methods "make_valid", "make_invalid" and "filter_storage" must be implemented by subclasses. """ def setUp(self): super().setUp() # Use a hack here : as the mock uses the content as id, it is easy to # create contents that are filtered or not. self.prefix = "71" self.tmpdir = tempfile.mkdtemp() # Make the storage filtered self.sconf = { "cls": "pathslicing", "args": {"root": self.tmpdir, "slicing": "0:5"}, } storage = get_objstorage(**self.sconf) self.base_storage = storage self.storage = self.filter_storage(self.sconf) # Set the id calculators storage.id = compute_hash # Present content with valid id self.present_valid_content = self.ensure_valid(b"yroqdtotji") self.present_valid_id = storage.id(self.present_valid_content) # Present content with invalid id self.present_invalid_content = self.ensure_invalid(b"glxddlmmzb") self.present_invalid_id = storage.id(self.present_invalid_content) # Missing content with valid id self.missing_valid_content = self.ensure_valid(b"rmzkdclkez") self.missing_valid_id = storage.id(self.missing_valid_content) # Missing content with invalid id self.missing_invalid_content = self.ensure_invalid(b"hlejfuginh") self.missing_invalid_id = storage.id(self.missing_invalid_content) # Present corrupted content with valid id self.present_corrupted_valid_content = self.ensure_valid(b"cdsjwnpaij") self.true_present_corrupted_valid_content = self.ensure_valid(b"mgsdpawcrr") self.present_corrupted_valid_id = storage.id( self.true_present_corrupted_valid_content ) # Present corrupted content with invalid id self.present_corrupted_invalid_content = self.ensure_invalid(b"pspjljnrco") self.true_present_corrupted_invalid_content = self.ensure_invalid(b"rjocbnnbso") self.present_corrupted_invalid_id = storage.id( self.true_present_corrupted_invalid_content ) # Missing (potentially) corrupted content with valid id self.missing_corrupted_valid_content = self.ensure_valid(b"zxkokfgtou") self.true_missing_corrupted_valid_content = self.ensure_valid(b"royoncooqa") self.missing_corrupted_valid_id = storage.id( self.true_missing_corrupted_valid_content ) # Missing (potentially) corrupted content with invalid id self.missing_corrupted_invalid_content = self.ensure_invalid(b"hxaxnrmnyk") self.true_missing_corrupted_invalid_content = self.ensure_invalid(b"qhbolyuifr") self.missing_corrupted_invalid_id = storage.id( self.true_missing_corrupted_invalid_content ) # Add the content that are supposed to be present self.storage.add(self.present_valid_content) self.storage.add(self.present_invalid_content) self.storage.add( self.present_corrupted_valid_content, obj_id=self.present_corrupted_valid_id ) self.storage.add( self.present_corrupted_invalid_content, obj_id=self.present_corrupted_invalid_id, ) def tearDown(self): super().tearDown() shutil.rmtree(self.tmpdir) def filter_storage(self, sconf): raise NotImplementedError( "Id_filter test class must have a filter_storage method" ) def ensure_valid(self, content=None): if content is None: content = get_random_content() while not self.storage.is_valid(self.base_storage.id(content)): content = get_random_content() return content def ensure_invalid(self, content=None): if content is None: content = get_random_content() while self.storage.is_valid(self.base_storage.id(content)): content = get_random_content() return content def test_contains(self): # Both contents are present, but the invalid one should be ignored. self.assertTrue(self.present_valid_id in self.storage) self.assertFalse(self.present_invalid_id in self.storage) self.assertFalse(self.missing_valid_id in self.storage) self.assertFalse(self.missing_invalid_id in self.storage) self.assertTrue(self.present_corrupted_valid_id in self.storage) self.assertFalse(self.present_corrupted_invalid_id in self.storage) self.assertFalse(self.missing_corrupted_valid_id in self.storage) self.assertFalse(self.missing_corrupted_invalid_id in self.storage) def test_iter(self): self.assertIn(self.present_valid_id, iter(self.storage)) self.assertNotIn(self.present_invalid_id, iter(self.storage)) self.assertNotIn(self.missing_valid_id, iter(self.storage)) self.assertNotIn(self.missing_invalid_id, iter(self.storage)) self.assertIn(self.present_corrupted_valid_id, iter(self.storage)) self.assertNotIn(self.present_corrupted_invalid_id, iter(self.storage)) self.assertNotIn(self.missing_corrupted_valid_id, iter(self.storage)) self.assertNotIn(self.missing_corrupted_invalid_id, iter(self.storage)) def test_len(self): # Four contents are present, but only two should be valid. self.assertEqual(2, len(self.storage)) def test_get(self): self.assertEqual( self.present_valid_content, self.storage.get(self.present_valid_id) ) with self.assertRaises(ObjNotFoundError): self.storage.get(self.present_invalid_id) with self.assertRaises(ObjNotFoundError): self.storage.get(self.missing_valid_id) with self.assertRaises(ObjNotFoundError): self.storage.get(self.missing_invalid_id) self.assertEqual( self.present_corrupted_valid_content, self.storage.get(self.present_corrupted_valid_id), ) with self.assertRaises(ObjNotFoundError): self.storage.get(self.present_corrupted_invalid_id) with self.assertRaises(ObjNotFoundError): self.storage.get(self.missing_corrupted_valid_id) with self.assertRaises(ObjNotFoundError): self.storage.get(self.missing_corrupted_invalid_id) def test_check(self): self.storage.check(self.present_valid_id) with self.assertRaises(ObjNotFoundError): self.storage.check(self.present_invalid_id) with self.assertRaises(ObjNotFoundError): self.storage.check(self.missing_valid_id) with self.assertRaises(ObjNotFoundError): self.storage.check(self.missing_invalid_id) with self.assertRaises(Error): self.storage.check(self.present_corrupted_valid_id) with self.assertRaises(ObjNotFoundError): self.storage.check(self.present_corrupted_invalid_id) with self.assertRaises(ObjNotFoundError): self.storage.check(self.missing_corrupted_valid_id) with self.assertRaises(ObjNotFoundError): self.storage.check(self.missing_corrupted_invalid_id) def test_get_random(self): self.assertEqual(0, len(list(self.storage.get_random(0)))) random_content = list(self.storage.get_random(1000)) self.assertIn(self.present_valid_id, random_content) self.assertNotIn(self.present_invalid_id, random_content) self.assertNotIn(self.missing_valid_id, random_content) self.assertNotIn(self.missing_invalid_id, random_content) self.assertIn(self.present_corrupted_valid_id, random_content) self.assertNotIn(self.present_corrupted_invalid_id, random_content) self.assertNotIn(self.missing_corrupted_valid_id, random_content) self.assertNotIn(self.missing_corrupted_invalid_id, random_content) def test_add(self): # Add valid and invalid contents to the storage and check their # presence with the unfiltered storage. valid_content = self.ensure_valid(b"ulepsrjbgt") valid_id = self.base_storage.id(valid_content) invalid_content = self.ensure_invalid(b"znvghkjked") invalid_id = self.base_storage.id(invalid_content) self.storage.add(valid_content) self.storage.add(invalid_content) self.assertTrue(valid_id in self.base_storage) self.assertFalse(invalid_id in self.base_storage) def test_restore(self): # Add corrupted content to the storage and the try to restore it valid_content = self.ensure_valid(b"ulepsrjbgt") valid_id = self.base_storage.id(valid_content) corrupted_content = self.ensure_valid(b"ltjkjsloyb") corrupted_id = self.base_storage.id(corrupted_content) self.storage.add(corrupted_content, obj_id=valid_id) with self.assertRaises(ObjNotFoundError): self.storage.check(corrupted_id) with self.assertRaises(Error): self.storage.check(valid_id) self.storage.restore(valid_content) self.storage.check(valid_id) class TestPrefixFilter(MixinTestIdFilter, unittest.TestCase): def setUp(self): self.prefix = b"71" super().setUp() def ensure_valid(self, content): obj_id = compute_hash(content) hex_obj_id = hashutil.hash_to_hex(obj_id) self.assertTrue(hex_obj_id.startswith(self.prefix)) return content def ensure_invalid(self, content): obj_id = compute_hash(content) hex_obj_id = hashutil.hash_to_hex(obj_id) self.assertFalse(hex_obj_id.startswith(self.prefix)) return content def filter_storage(self, sconf): return get_objstorage( "filtered", {"storage_conf": sconf, "filters_conf": [id_prefix(self.prefix)]}, ) class TestRegexFilter(MixinTestIdFilter, unittest.TestCase): def setUp(self): self.regex = r"[a-f][0-9].*" super().setUp() def filter_storage(self, sconf): return get_objstorage( "filtered", {"storage_conf": sconf, "filters_conf": [id_regex(self.regex)]} ) diff --git a/swh/objstorage/tests/test_objstorage_api.py b/swh/objstorage/tests/test_objstorage_api.py index c77d2ee..2d2f17d 100644 --- a/swh/objstorage/tests/test_objstorage_api.py +++ b/swh/objstorage/tests/test_objstorage_api.py @@ -1,51 +1,51 @@ -# Copyright (C) 2015-2019 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import shutil import tempfile import unittest import pytest from swh.core.api.tests.server_testing import ServerTestFixtureAsync -from swh.objstorage import get_objstorage +from swh.objstorage.factory import get_objstorage from swh.objstorage.api.server import make_app from swh.objstorage.tests.objstorage_testing import ObjStorageTestFixture class TestRemoteObjStorage( ServerTestFixtureAsync, ObjStorageTestFixture, unittest.TestCase ): """ Test the remote archive API. """ def setUp(self): self.tmpdir = tempfile.mkdtemp() self.config = { "objstorage": { "cls": "pathslicing", "args": { "root": self.tmpdir, "slicing": "0:1/0:5", "allow_delete": True, }, }, "client_max_size": 8 * 1024 * 1024, } self.app = make_app(self.config) super().setUp() self.storage = get_objstorage("remote", {"url": self.url()}) def tearDown(self): super().tearDown() shutil.rmtree(self.tmpdir) @pytest.mark.skip("makes no sense to test this for the remote api") def test_delete_not_allowed(self): pass @pytest.mark.skip("makes no sense to test this for the remote api") def test_delete_not_allowed_by_default(self): pass diff --git a/swh/objstorage/tests/test_objstorage_azure.py b/swh/objstorage/tests/test_objstorage_azure.py index 629dc1a..a42ff30 100644 --- a/swh/objstorage/tests/test_objstorage_azure.py +++ b/swh/objstorage/tests/test_objstorage_azure.py @@ -1,179 +1,179 @@ -# Copyright (C) 2016-2018 The Software Heritage developers +# Copyright (C) 2016-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import unittest from collections import defaultdict from unittest.mock import patch from typing import Any, Dict from azure.common import AzureMissingResourceHttpError from swh.model.hashutil import hash_to_hex -from swh.objstorage import get_objstorage +from swh.objstorage.factory import get_objstorage from swh.objstorage.objstorage import decompressors from swh.objstorage.exc import Error from .objstorage_testing import ObjStorageTestFixture class MockBlob: """ Libcloud object mock that replicates its API """ def __init__(self, name, content): self.name = name self.content = content class MockBlockBlobService: """Mock internal azure library which AzureCloudObjStorage depends upon. """ _data: Dict[str, Any] = {} def __init__(self, account_name, account_key, **kwargs): # do not care for the account_name and the api_secret_key here self._data = defaultdict(dict) def get_container_properties(self, container_name): self._data[container_name] return container_name in self._data def create_blob_from_bytes(self, container_name, blob_name, blob): self._data[container_name][blob_name] = blob def get_blob_to_bytes(self, container_name, blob_name): if blob_name not in self._data[container_name]: raise AzureMissingResourceHttpError("Blob %s not found" % blob_name, 404) return MockBlob(name=blob_name, content=self._data[container_name][blob_name]) def delete_blob(self, container_name, blob_name): try: self._data[container_name].pop(blob_name) except KeyError: raise AzureMissingResourceHttpError("Blob %s not found" % blob_name, 404) return True def exists(self, container_name, blob_name): return blob_name in self._data[container_name] def list_blobs(self, container_name, marker=None, maxresults=None): for blob_name, content in sorted(self._data[container_name].items()): if marker is None or blob_name > marker: yield MockBlob(name=blob_name, content=content) class TestAzureCloudObjStorage(ObjStorageTestFixture, unittest.TestCase): compression = "none" def setUp(self): super().setUp() patcher = patch( "swh.objstorage.backends.azure.BlockBlobService", MockBlockBlobService, ) patcher.start() self.addCleanup(patcher.stop) self.storage = get_objstorage( "azure", { "account_name": "account-name", "api_secret_key": "api-secret-key", "container_name": "container-name", "compression": self.compression, }, ) def test_compression(self): content, obj_id = self.hash_content(b"test content is compressed") self.storage.add(content, obj_id=obj_id) blob_service, container = self.storage.get_blob_service(obj_id) internal_id = self.storage._internal_id(obj_id) raw_blob = blob_service.get_blob_to_bytes(container, internal_id) d = decompressors[self.compression]() assert d.decompress(raw_blob.content) == content assert d.unused_data == b"" def test_trailing_data_on_stored_blob(self): content, obj_id = self.hash_content(b"test content without garbage") self.storage.add(content, obj_id=obj_id) blob_service, container = self.storage.get_blob_service(obj_id) internal_id = self.storage._internal_id(obj_id) blob_service._data[container][internal_id] += b"trailing garbage" if self.compression == "none": with self.assertRaises(Error) as e: self.storage.check(obj_id) else: with self.assertRaises(Error) as e: self.storage.get(obj_id) assert "trailing data" in e.exception.args[0] class TestAzureCloudObjStorageGzip(TestAzureCloudObjStorage): compression = "gzip" class TestAzureCloudObjStorageZlib(TestAzureCloudObjStorage): compression = "zlib" class TestAzureCloudObjStorageLzma(TestAzureCloudObjStorage): compression = "lzma" class TestAzureCloudObjStorageBz2(TestAzureCloudObjStorage): compression = "bz2" class TestPrefixedAzureCloudObjStorage(ObjStorageTestFixture, unittest.TestCase): def setUp(self): super().setUp() patcher = patch( "swh.objstorage.backends.azure.BlockBlobService", MockBlockBlobService, ) patcher.start() self.addCleanup(patcher.stop) self.accounts = {} for prefix in "0123456789abcdef": self.accounts[prefix] = { "account_name": "account_%s" % prefix, "api_secret_key": "secret_key_%s" % prefix, "container_name": "container_%s" % prefix, } self.storage = get_objstorage("azure-prefixed", {"accounts": self.accounts}) def test_prefixedazure_instantiation_missing_prefixes(self): del self.accounts["d"] del self.accounts["e"] with self.assertRaisesRegex(ValueError, "Missing prefixes"): get_objstorage("azure-prefixed", {"accounts": self.accounts}) def test_prefixedazure_instantiation_inconsistent_prefixes(self): self.accounts["00"] = self.accounts["0"] with self.assertRaisesRegex(ValueError, "Inconsistent prefixes"): get_objstorage("azure-prefixed", {"accounts": self.accounts}) def test_prefixedazure_sharding_behavior(self): for i in range(100): content, obj_id = self.hash_content(b"test_content_%02d" % i) self.storage.add(content, obj_id=obj_id) hex_obj_id = hash_to_hex(obj_id) prefix = hex_obj_id[0] self.assertTrue( self.storage.prefixes[prefix][0].exists( self.accounts[prefix]["container_name"], hex_obj_id ) ) diff --git a/swh/objstorage/tests/test_objstorage_in_memory.py b/swh/objstorage/tests/test_objstorage_in_memory.py index 2931b9b..d152cf5 100644 --- a/swh/objstorage/tests/test_objstorage_in_memory.py +++ b/swh/objstorage/tests/test_objstorage_in_memory.py @@ -1,16 +1,16 @@ -# Copyright (C) 2015-2018 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import unittest -from swh.objstorage import get_objstorage +from swh.objstorage.factory import get_objstorage from .objstorage_testing import ObjStorageTestFixture class TestInMemoryObjStorage(ObjStorageTestFixture, unittest.TestCase): def setUp(self): super().setUp() self.storage = get_objstorage(cls="memory", args={}) diff --git a/swh/objstorage/tests/test_objstorage_instantiation.py b/swh/objstorage/tests/test_objstorage_instantiation.py index e15ee4e..c4fe9a1 100644 --- a/swh/objstorage/tests/test_objstorage_instantiation.py +++ b/swh/objstorage/tests/test_objstorage_instantiation.py @@ -1,40 +1,40 @@ -# Copyright (C) 2015-2016 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import shutil import tempfile import unittest -from swh.objstorage import get_objstorage +from swh.objstorage.factory import get_objstorage from swh.objstorage.api.client import RemoteObjStorage from swh.objstorage.backends.pathslicing import PathSlicingObjStorage class TestObjStorageInitialization(unittest.TestCase): """ Test that the methods for ObjStorage initializations with `get_objstorage` works properly. """ def setUp(self): self.path = tempfile.mkdtemp() self.path2 = tempfile.mkdtemp() # Server is launched at self.url() self.config = {"storage_base": self.path2, "storage_slicing": "0:1/0:5"} super().setUp() def tearDown(self): super().tearDown() shutil.rmtree(self.path) shutil.rmtree(self.path2) def test_pathslicing_objstorage(self): conf = {"cls": "pathslicing", "args": {"root": self.path, "slicing": "0:2/0:5"}} st = get_objstorage(**conf) self.assertTrue(isinstance(st, PathSlicingObjStorage)) def test_remote_objstorage(self): conf = {"cls": "remote", "args": {"url": "http://127.0.0.1:4242/"}} st = get_objstorage(**conf) self.assertTrue(isinstance(st, RemoteObjStorage)) diff --git a/swh/objstorage/tests/test_objstorage_multiplexer.py b/swh/objstorage/tests/test_objstorage_multiplexer.py index cec4beb..c3670a4 100644 --- a/swh/objstorage/tests/test_objstorage_multiplexer.py +++ b/swh/objstorage/tests/test_objstorage_multiplexer.py @@ -1,68 +1,68 @@ -# Copyright (C) 2015-2016 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import tempfile import unittest -from swh.objstorage import PathSlicingObjStorage +from swh.objstorage.backends.pathslicing import PathSlicingObjStorage from swh.objstorage.multiplexer import MultiplexerObjStorage from swh.objstorage.multiplexer.filter import add_filter, read_only from .objstorage_testing import ObjStorageTestFixture class TestMultiplexerObjStorage(ObjStorageTestFixture, unittest.TestCase): def setUp(self): super().setUp() self.tmpdir = tempfile.mkdtemp() os.mkdir(os.path.join(self.tmpdir, "root1")) os.mkdir(os.path.join(self.tmpdir, "root2")) self.storage_v1 = PathSlicingObjStorage( os.path.join(self.tmpdir, "root1"), "0:2/2:4" ) self.storage_v2 = PathSlicingObjStorage( os.path.join(self.tmpdir, "root2"), "0:1/0:5" ) self.r_storage = add_filter(self.storage_v1, read_only()) self.w_storage = self.storage_v2 self.storage = MultiplexerObjStorage([self.r_storage, self.w_storage]) def tearDown(self): super().tearDown() shutil.rmtree(self.tmpdir) def test_contains(self): content_p, obj_id_p = self.hash_content(b"contains_present") content_m, obj_id_m = self.hash_content(b"contains_missing") self.storage.add(content_p, obj_id=obj_id_p) self.assertIn(obj_id_p, self.storage) self.assertNotIn(obj_id_m, self.storage) def test_delete_missing(self): self.storage_v1.allow_delete = True self.storage_v2.allow_delete = True super().test_delete_missing() def test_delete_present(self): self.storage_v1.allow_delete = True self.storage_v2.allow_delete = True super().test_delete_present() def test_get_random_contents(self): content, obj_id = self.hash_content(b"get_random_content") self.storage.add(content) random_contents = list(self.storage.get_random(1)) self.assertEqual(1, len(random_contents)) self.assertIn(obj_id, random_contents) def test_access_readonly(self): # Add a content to the readonly storage content, obj_id = self.hash_content(b"content in read-only") self.storage_v1.add(content) # Try to retrieve it on the main storage self.assertIn(obj_id, self.storage) diff --git a/swh/objstorage/tests/test_objstorage_pathslicing.py b/swh/objstorage/tests/test_objstorage_pathslicing.py index ddcf01a..5925007 100644 --- a/swh/objstorage/tests/test_objstorage_pathslicing.py +++ b/swh/objstorage/tests/test_objstorage_pathslicing.py @@ -1,161 +1,163 @@ -# Copyright (C) 2015-2017 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import shutil import tempfile import unittest from unittest.mock import patch, DEFAULT from swh.model import hashutil -from swh.objstorage import exc, get_objstorage, ID_HASH_LENGTH +from swh.objstorage import exc +from swh.objstorage.factory import get_objstorage +from swh.objstorage.objstorage import ID_HASH_LENGTH from .objstorage_testing import ObjStorageTestFixture class TestPathSlicingObjStorage(ObjStorageTestFixture, unittest.TestCase): compression = "none" def setUp(self): super().setUp() self.slicing = "0:2/2:4/4:6" self.tmpdir = tempfile.mkdtemp() self.storage = get_objstorage( "pathslicing", { "root": self.tmpdir, "slicing": self.slicing, "compression": self.compression, }, ) def tearDown(self): super().tearDown() shutil.rmtree(self.tmpdir) def content_path(self, obj_id): hex_obj_id = hashutil.hash_to_hex(obj_id) return self.storage._obj_path(hex_obj_id) def test_iter(self): content, obj_id = self.hash_content(b"iter") self.assertEqual(list(iter(self.storage)), []) self.storage.add(content, obj_id=obj_id) self.assertEqual(list(iter(self.storage)), [obj_id]) def test_len(self): content, obj_id = self.hash_content(b"len") self.assertEqual(len(self.storage), 0) self.storage.add(content, obj_id=obj_id) self.assertEqual(len(self.storage), 1) def test_check_ok(self): content, obj_id = self.hash_content(b"check_ok") self.storage.add(content, obj_id=obj_id) assert self.storage.check(obj_id) is None assert self.storage.check(obj_id.hex()) is None def test_check_id_mismatch(self): content, obj_id = self.hash_content(b"check_id_mismatch") self.storage.add(b"unexpected content", obj_id=obj_id) with self.assertRaises(exc.Error) as error: self.storage.check(obj_id) self.assertEqual( ( "Corrupt object %s should have id " "12ebb2d6c81395bcc5cab965bdff640110cb67ff" % obj_id.hex(), ), error.exception.args, ) def test_get_random_contents(self): content, obj_id = self.hash_content(b"get_random_content") self.storage.add(content, obj_id=obj_id) random_contents = list(self.storage.get_random(1)) self.assertEqual(1, len(random_contents)) self.assertIn(obj_id, random_contents) def test_iterate_from(self): all_ids = [] for i in range(100): content, obj_id = self.hash_content(b"content %d" % i) self.storage.add(content, obj_id=obj_id) all_ids.append(obj_id) all_ids.sort() ids = list(self.storage.iter_from(b"\x00" * (ID_HASH_LENGTH // 2))) self.assertEqual(len(ids), len(all_ids)) self.assertEqual(ids, all_ids) ids = list(self.storage.iter_from(all_ids[0])) self.assertEqual(len(ids), len(all_ids) - 1) self.assertEqual(ids, all_ids[1:]) ids = list(self.storage.iter_from(all_ids[-1], n_leaf=True)) n_leaf = ids[-1] ids = ids[:-1] self.assertEqual(n_leaf, 1) self.assertEqual(len(ids), 0) ids = list(self.storage.iter_from(all_ids[-2], n_leaf=True)) n_leaf = ids[-1] ids = ids[:-1] self.assertEqual(n_leaf, 2) # beware, this depends on the hash algo self.assertEqual(len(ids), 1) self.assertEqual(ids, all_ids[-1:]) def test_fdatasync_default(self): content, obj_id = self.hash_content(b"check_fdatasync") with patch.multiple("os", fsync=DEFAULT, fdatasync=DEFAULT) as patched: self.storage.add(content, obj_id=obj_id) if self.storage.use_fdatasync: assert patched["fdatasync"].call_count == 1 assert patched["fsync"].call_count == 0 else: assert patched["fdatasync"].call_count == 0 assert patched["fsync"].call_count == 1 def test_fdatasync_forced_on(self): self.storage.use_fdatasync = True content, obj_id = self.hash_content(b"check_fdatasync") with patch.multiple("os", fsync=DEFAULT, fdatasync=DEFAULT) as patched: self.storage.add(content, obj_id=obj_id) assert patched["fdatasync"].call_count == 1 assert patched["fsync"].call_count == 0 def test_fdatasync_forced_off(self): self.storage.use_fdatasync = False content, obj_id = self.hash_content(b"check_fdatasync") with patch.multiple("os", fsync=DEFAULT, fdatasync=DEFAULT) as patched: self.storage.add(content, obj_id=obj_id) assert patched["fdatasync"].call_count == 0 assert patched["fsync"].call_count == 1 def test_check_not_compressed(self): content, obj_id = self.hash_content(b"check_not_compressed") self.storage.add(content, obj_id=obj_id) with open(self.content_path(obj_id), "ab") as f: # Add garbage. f.write(b"garbage") with self.assertRaises(exc.Error) as error: self.storage.check(obj_id) if self.compression == "none": self.assertIn("Corrupt object", error.exception.args[0]) else: self.assertIn("trailing data found", error.exception.args[0]) class TestPathSlicingObjStorageGzip(TestPathSlicingObjStorage): compression = "gzip" class TestPathSlicingObjStorageZlib(TestPathSlicingObjStorage): compression = "zlib" class TestPathSlicingObjStorageBz2(TestPathSlicingObjStorage): compression = "bz2" class TestPathSlicingObjStorageLzma(TestPathSlicingObjStorage): compression = "lzma" diff --git a/swh/objstorage/tests/test_objstorage_random_generator.py b/swh/objstorage/tests/test_objstorage_random_generator.py index 699ff4a..1ffc674 100644 --- a/swh/objstorage/tests/test_objstorage_random_generator.py +++ b/swh/objstorage/tests/test_objstorage_random_generator.py @@ -1,45 +1,45 @@ -# Copyright (C) 2019 The Software Heritage developers +# Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import Iterator -from swh.objstorage import get_objstorage +from swh.objstorage.factory import get_objstorage def test_random_generator_objstorage(): sto = get_objstorage("random", {}) assert sto blobs = [sto.get(None) for i in range(100)] lengths = [len(x) for x in blobs] assert max(lengths) <= 55056238 def test_random_generator_objstorage_get_stream(): sto = get_objstorage("random", {}) gen = sto.get_stream(None) assert isinstance(gen, Iterator) assert list(gen) # ensure the iterator can be consumed def test_random_generator_objstorage_list_content(): sto = get_objstorage("random", {"total": 100}) assert isinstance(sto.list_content(), Iterator) assert list(sto.list_content()) == [b"%d" % i for i in range(1, 101)] assert list(sto.list_content(limit=10)) == [b"%d" % i for i in range(1, 11)] assert list(sto.list_content(last_obj_id=b"10", limit=10)) == [ b"%d" % i for i in range(11, 21) ] def test_random_generator_objstorage_total(): sto = get_objstorage("random", {"total": 5}) assert len([x for x in sto]) == 5 def test_random_generator_objstorage_size(): sto = get_objstorage("random", {"filesize": 10}) for i in range(10): assert len(sto.get(None)) == 10 diff --git a/swh/objstorage/tests/test_objstorage_striping.py b/swh/objstorage/tests/test_objstorage_striping.py index 2e34550..7cc0ec2 100644 --- a/swh/objstorage/tests/test_objstorage_striping.py +++ b/swh/objstorage/tests/test_objstorage_striping.py @@ -1,80 +1,80 @@ -# Copyright (C) 2015-2016 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import tempfile import unittest -from swh.objstorage import get_objstorage +from swh.objstorage.factory import get_objstorage from .objstorage_testing import ObjStorageTestFixture class TestStripingObjStorage(ObjStorageTestFixture, unittest.TestCase): def setUp(self): super().setUp() self.base_dir = tempfile.mkdtemp() os.mkdir(os.path.join(self.base_dir, "root1")) os.mkdir(os.path.join(self.base_dir, "root2")) storage_config = { "cls": "striping", "args": { "objstorages": [ { "cls": "pathslicing", "args": { "root": os.path.join(self.base_dir, "root1"), "slicing": "0:2", "allow_delete": True, }, }, { "cls": "pathslicing", "args": { "root": os.path.join(self.base_dir, "root2"), "slicing": "0:2", "allow_delete": True, }, }, ] }, } self.storage = get_objstorage(**storage_config) def tearDown(self): shutil.rmtree(self.base_dir) def test_add_get_wo_id(self): self.skipTest("can't add without id in the multiplexer storage") def test_add_striping_behavior(self): exp_storage_counts = [0, 0] storage_counts = [0, 0] for i in range(100): content, obj_id = self.hash_content(b"striping_behavior_test%02d" % i) self.storage.add(content, obj_id) exp_storage_counts[self.storage.get_storage_index(obj_id)] += 1 count = 0 for i, storage in enumerate(self.storage.storages): if obj_id not in storage: continue count += 1 storage_counts[i] += 1 self.assertEqual(count, 1) self.assertEqual(storage_counts, exp_storage_counts) def test_get_striping_behavior(self): # Make sure we can read objects that are available in any backend # storage content, obj_id = self.hash_content(b"striping_behavior_test") for storage in self.storage.storages: storage.add(content, obj_id) self.assertIn(obj_id, self.storage) storage.delete(obj_id) self.assertNotIn(obj_id, self.storage) def test_list_content(self): self.skipTest("Quite a chellenge to make it work")