diff --git a/mypy.ini b/mypy.ini --- a/mypy.ini +++ b/mypy.ini @@ -19,3 +19,6 @@ [mypy-rados.*] ignore_missing_imports = True + +[mypy-requests_toolbelt.*] +ignore_missing_imports = True diff --git a/requirements-test.txt b/requirements-test.txt --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,5 +2,6 @@ azure-storage-blob >= 12.0, != 12.9.0 # version 12.9.0 breaks mypy https://github.com/Azure/azure-sdk-for-python/pull/20891 pytest python-cephlibs +requests_toolbelt types-pyyaml types-requests diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ # remote storage API server aiohttp >= 3 click +requests >= 1.9 # optional dependencies # apache-libcloud diff --git a/swh/objstorage/backends/seaweed.py b/swh/objstorage/backends/seaweed.py --- a/swh/objstorage/backends/seaweed.py +++ b/swh/objstorage/backends/seaweed.py @@ -4,12 +4,15 @@ # See top-level LICENSE file for more information import io +from itertools import islice import logging +from typing import Iterator, Optional from urllib.parse import urljoin, urlparse import requests from swh.model import hashutil +from swh.objstorage.backends.pathslicing import PathSlicer from swh.objstorage.exc import Error, ObjNotFoundError from swh.objstorage.objstorage import ( DEFAULT_LIMIT, @@ -30,47 +33,102 @@ """ def __init__(self, url): + if url.endswith("/"): + url = url[:-1] self.url = url + self.baseurl = urljoin(url, "/") + self.basepath = urlparse(url).path + + self.session = requests.Session() + self.session.headers["Accept"] = "application/json" + + self.batchsize = DEFAULT_LIMIT + + def build_url(self, path): + assert path == self.basepath or path.startswith(self.basepath) + return urljoin(self.baseurl, path) def get(self, remote_path): - url = urljoin(self.url, remote_path) + url = self.build_url(remote_path) LOGGER.debug("Get file %s", url) - return requests.get(url).content + resp = self.session.get(url) + resp.raise_for_status() + return 
resp.content def exists(self, remote_path): - url = urljoin(self.url, remote_path) + url = self.build_url(remote_path) LOGGER.debug("Check file %s", url) - return requests.head(url).status_code == 200 + return self.session.head(url).status_code == 200 def put(self, fp, remote_path): - url = urljoin(self.url, remote_path) + url = self.build_url(remote_path) LOGGER.debug("Put file %s", url) - return requests.post(url, files={"file": fp}) + return self.session.post(url, files={"file": fp}) def delete(self, remote_path): - url = urljoin(self.url, remote_path) + url = self.build_url(remote_path) LOGGER.debug("Delete file %s", url) - return requests.delete(url) + return self.session.delete(url) + + def iterfiles( + self, dir: str, last_file_name: Optional[str] = None + ) -> Iterator[str]: + """Recursively yield absolute file names - def list(self, dir, last_file_name=None, limit=DEFAULT_LIMIT): - """list sub folders and files of @dir. show a better look if you turn on + Args: + dir (str): retrieve file names starting from this directory; must + be an absolute path. + last_file_name (str): if given, start from the file just after it; must + be a basename.
- returns a dict of "sub-folders and files" + Yields: + absolute file names """ - d = dir if dir.endswith("/") else (dir + "/") - url = urljoin(self.url, d) - headers = {"Accept": "application/json"} - params = {"limit": limit} + if dir.endswith("/"): + dir = dir[:-1] + + # first, generate files going "down" + yield from self._iter_files(dir, last_file_name) + + # then, continue iterating while going up the tree + while True: + dir, last = dir.rsplit("/", 1) + if not (dir + "/").startswith(self.basepath): + # we are done + break + yield from self._iter_files(dir, last_file_name=last) + + def _iter_files( + self, dir: str, last_file_name: Optional[str] = None + ) -> Iterator[str]: + for entry in self._iter_one_dir(dir, last_file_name): + fullpath = entry["FullPath"] + if entry["Mode"] & 1 << 31: # it's a directory, recurse + yield from self._iter_files(fullpath) + else: + yield fullpath + + def _iter_one_dir(self, remote_path, last_file_name=None): + url = self.build_url(remote_path) + params = {"limit": self.batchsize} if last_file_name: params["lastFileName"] = last_file_name LOGGER.debug("List directory %s", url) - rsp = requests.get(url, params=params, headers=headers) - if rsp.ok: - return rsp.json() - else: - LOGGER.error('Error listing "%s". [HTTP %d]' % (url, rsp.status_code)) + while True: + rsp = self.session.get(url, params=params) + if rsp.ok: + dircontent = rsp.json() + if dircontent["Entries"]: + yield from dircontent["Entries"] + if not dircontent["ShouldDisplayLoadMore"]: + break + params["lastFileName"] = dircontent["LastFileName"] + + else: + LOGGER.error('Error listing "%s".
[HTTP %d]' % (url, rsp.status_code)) + break class WeedObjStorage(ObjStorage): @@ -79,10 +137,15 @@ https://github.com/chrislusf/seaweedfs/wiki/Filer-Server-API """ - def __init__(self, url="http://127.0.0.1:8888/swh", compression=None, **kwargs): + def __init__( + self, url="http://127.0.0.1:8888/swh", compression=None, slicing="", **kwargs + ): super().__init__(**kwargs) self.wf = WeedFiler(url) self.root_path = urlparse(url).path + if not self.root_path.endswith("/"): + self.root_path += "/" + self.slicer = PathSlicer(self.root_path, slicing) self.compression = compression def check_config(self, *, check_write): @@ -176,15 +239,19 @@ def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT): if last_obj_id: - last_obj_id = hashutil.hash_to_hex(last_obj_id) - resp = self.wf.list(self.root_path, last_obj_id, limit) - if resp is not None: - entries = resp["Entries"] - if entries: - for obj in entries: - if obj is not None: - bytehex = obj["FullPath"].rsplit("/", 1)[-1] - yield hashutil.bytehex_to_hash(bytehex.encode()) + objid = hashutil.hash_to_hex(last_obj_id) + objpath = self._path(objid) + startdir, lastfilename = objpath.rsplit("/", 1) + else: + startdir = self.root_path + lastfilename = None + # startdir = self.wf.build_url(startdir) + + for fname in islice( + self.wf.iterfiles(startdir, last_file_name=lastfilename), limit + ): + bytehex = fname.rsplit("/", 1)[-1] + yield hashutil.bytehex_to_hash(bytehex.encode()) # internal methods def _put_object(self, content, obj_id): @@ -206,4 +273,4 @@ self.wf.put(io.BytesIO(b"".join(compressor(content))), self._path(obj_id)) def _path(self, obj_id): - return hashutil.hash_to_hex(obj_id) + return self.slicer.get_path(hashutil.hash_to_hex(obj_id)) diff --git a/swh/objstorage/tests/test_objstorage_seaweedfs.py b/swh/objstorage/tests/test_objstorage_seaweedfs.py --- a/swh/objstorage/tests/test_objstorage_seaweedfs.py +++ b/swh/objstorage/tests/test_objstorage_seaweedfs.py @@ -3,50 +3,218 @@ # License: GNU General 
Public License version 3, or any later version # See top-level LICENSE file for more information +from itertools import dropwhile, islice +import json +import os import unittest -from swh.objstorage.backends.seaweed import DEFAULT_LIMIT, WeedObjStorage +from requests.utils import get_encoding_from_headers +import requests_mock +from requests_mock.contrib import fixture + +from swh.objstorage.backends.pathslicing import PathSlicer +from swh.objstorage.backends.seaweed import WeedObjStorage from swh.objstorage.exc import Error from swh.objstorage.objstorage import decompressors from swh.objstorage.tests.objstorage_testing import ObjStorageTestFixture -class MockWeedFiler: - """ WeedFiler mock that replicates its API """ +class PathDict: + """A dict-like object that handles "path-like" keys in a recursive dict + structure. + + For example: + + >>> a = PathDict() + >>> a['path/to/file'] = 'some file content' + + will create a dict structure (in self.data) like: + + >>> print(a.data) + {'path': {'to': {'file': 'some file content'}}} + >>> 'path/to/file' in a + True + + This is a helper class for the FilerRequestsMock below. 
+ """ + + def __init__(self): + self.data = {} + + def __setitem__(self, key, value): + if key.endswith("/"): + raise ValueError("Nope") + if key.startswith("/"): + key = key[1:] + path = key.split("/") + resu = self.data + for p in path[:-1]: + resu = resu.setdefault(p, {}) + resu[path[-1]] = value + + def __getitem__(self, key): + assert isinstance(key, str) + if key.startswith("/"): + key = key[1:] + if key.endswith("/"): + key = key[:-1] + + path = key.split("/") + resu = self.data + for p in path: + resu = resu[p] + return resu + + def __delitem__(self, key): + if key.startswith("/"): + key = key[1:] + if key.endswith("/"): + key = key[:-1] + path = key.split("/") + resu = self.data + for p in path[:-1]: + resu = resu.setdefault(p, {}) + del resu[path[-1]] + + def __contains__(self, key): + try: + self[key] + return True + except KeyError: + return False + + def flat(self): + def go(d): + for k, v in d.items(): + if isinstance(v, dict): + yield from go(v) + else: + yield k + + yield from go(self.data) + + +class FilerRequestsMock: + """This is a requests_mock based mock for the seaweedfs Filer API + + It does not implement the whole API, only the parts required to make the + WeedFiler (used by WeedObjStorage) work. + + It stores the files in a dict-based structure, e.g. the file + '0a/32/0a3245983255' will be stored in a dict like: + + {'0a': {'32': {'0a3245983255': b'content'}}} - def __init__(self, url): - self.url = url - self.content = {} + It uses the PathDict helper class to make it a bit easier to handle this + dict structure.
- def get(self, remote_path): - return self.content[remote_path] + """ - def put(self, fp, remote_path): - self.content[remote_path] = fp.read() + MODE_DIR = 0o20000000771 + MODE_FILE = 0o660 - def exists(self, remote_path): - return remote_path in self.content + def __init__(self): + self.content = PathDict() + self.requests_mock = fixture.Fixture() + self.requests_mock.setUp() + self.requests_mock.register_uri( + requests_mock.GET, requests_mock.ANY, content=self.get_cb + ) + self.requests_mock.register_uri( + requests_mock.POST, requests_mock.ANY, content=self.post_cb + ) + self.requests_mock.register_uri( + requests_mock.HEAD, requests_mock.ANY, content=self.head_cb + ) + self.requests_mock.register_uri( + requests_mock.DELETE, requests_mock.ANY, content=self.delete_cb + ) - def delete(self, remote_path): - del self.content[remote_path] + def head_cb(self, request, context): + if request.path not in self.content: + context.status_code = 404 - def list(self, dir, last_file_name=None, limit=DEFAULT_LIMIT): - keys = sorted(self.content.keys()) - if last_file_name is None: - idx = 0 + def get_cb(self, request, context): + content = None + if request.path not in self.content: + context.status_code = 404 else: - idx = keys.index(last_file_name) + 1 - return {"Entries": [{"FullPath": x} for x in keys[idx : idx + limit]]} + content = self.content[request.path] + if isinstance(content, dict): + if "limit" in request.qs: + limit = int(request.qs["limit"][0]) + assert limit > 0 + else: + limit = None + + items = sorted(content.items()) + if items and "lastfilename" in request.qs: + lastfilename = request.qs["lastfilename"][0] + # exclude all filenames up to lastfilename + items = dropwhile( + lambda kv: kv[0].split("/")[-1] <= lastfilename, items + ) + + if limit: + # +1 to easily detect if there are more + items = islice(items, limit + 1) + + entries = [ + { + "FullPath": os.path.join(request.path, fname), + "Mode": self.MODE_DIR + if isinstance(obj, dict) + else 
self.MODE_FILE, + } + for fname, obj in items + ] + + thereismore = False + if limit and len(entries) > limit: + entries = entries[:limit] + thereismore = True + + if entries: + lastfilename = entries[-1]["FullPath"].split("/")[-1] + else: + lastfilename = None + text = json.dumps( + { + "Path": request.path, + "Limit": limit, + "LastFileName": lastfilename, + "ShouldDisplayLoadMore": thereismore, + "Entries": entries, + } + ) + encoding = get_encoding_from_headers(request.headers) or "utf-8" + content = text.encode(encoding) + return content + + def post_cb(self, request, context): + from requests_toolbelt.multipart import decoder + + multipart_data = decoder.MultipartDecoder( + request.body, request.headers["content-type"] + ) + part = multipart_data.parts[0] + self.content[request.path] = part.content + + def delete_cb(self, request, context): + del self.content[request.path] class TestWeedObjStorage(ObjStorageTestFixture, unittest.TestCase): compression = "none" + slicing = "" def setUp(self): super().setUp() self.url = "http://127.0.0.1/test" - self.storage = WeedObjStorage(url=self.url, compression=self.compression) - self.storage.wf = MockWeedFiler(self.url) + self.storage = WeedObjStorage( + url=self.url, compression=self.compression, slicing=self.slicing + ) + self.mock = FilerRequestsMock() def test_compression(self): content, obj_id = self.hash_content(b"test compression") @@ -63,7 +231,7 @@ self.storage.add(content, obj_id=obj_id) path = self.storage._path(obj_id) - self.storage.wf.content[path] += b"trailing garbage" + self.mock.content[path] += b"trailing garbage" if self.compression == "none": with self.assertRaises(Error) as e: @@ -73,6 +241,13 @@ self.storage.get(obj_id) assert "trailing data" in e.exception.args[0] + def test_slicing(self): + content, obj_id = self.hash_content(b"test compression") + self.storage.add(content, obj_id=obj_id) + + slicer = PathSlicer("/test", self.slicing) + assert slicer.get_path(obj_id.hex()) in self.mock.content + 
class TestWeedObjStorageBz2(TestWeedObjStorage): compression = "bz2" @@ -88,3 +263,15 @@ class TestWeedObjStorageZlib(TestWeedObjStorage): compression = "zlib" + + +class TestWeedObjStorageWithSlicing(TestWeedObjStorage): + slicing = "0:2/2:4" + + +class TestWeedObjStorageWithSlicingAndSmallBatch(TestWeedObjStorage): + slicing = "0:2/2:4" + + def setUp(self): + super().setUp() + self.storage.wf.batchsize = 1