diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -4,8 +4,7 @@ vcversioner # remote storage API server -flask - +aiohttp click # optional dependencies diff --git a/swh/objstorage/api/client.py b/swh/objstorage/api/client.py --- a/swh/objstorage/api/client.py +++ b/swh/objstorage/api/client.py @@ -5,8 +5,9 @@ from swh.core.api import SWHRemoteAPI +from swh.model import hashutil -from ..objstorage import ObjStorage +from ..objstorage import ObjStorage, DEFAULT_CHUNK_SIZE from ..exc import ObjStorageAPIError @@ -42,7 +43,22 @@ return self.post('content/get/batch', {'obj_ids': obj_ids}) def check(self, obj_id): - self.post('content/check', {'obj_id': obj_id}) + return self.post('content/check', {'obj_id': obj_id}) + + # Management methods def get_random(self, batch_size): return self.post('content/get/random', {'batch_size': batch_size}) + + # Streaming methods + + def add_stream(self, content_iter, obj_id, check_presence=True): + obj_id = hashutil.hash_to_hex(obj_id) + return self.post_stream('content/add_stream/{}'.format(obj_id), + params={'check_presence': check_presence}, + data=content_iter) + + def get_stream(self, obj_id, chunk_size=DEFAULT_CHUNK_SIZE): + obj_id = hashutil.hash_to_hex(obj_id) + return super().get_stream('content/get_stream/{}'.format(obj_id), + chunk_size=chunk_size) diff --git a/swh/objstorage/api/server.py b/swh/objstorage/api/server.py --- a/swh/objstorage/api/server.py +++ b/swh/objstorage/api/server.py @@ -3,15 +3,14 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import asyncio +import aiohttp.web import click -import logging - -from flask import g, request from swh.core import config -from swh.core.api import (SWHServerAPIApp, decode_request, - error_handler, - encode_data_server as encode_data) +from swh.core.api_async import (SWHRemoteAPI, decode_request, + encode_data_server as encode_data) +from swh.model import 
hashutil from swh.objstorage import get_objstorage @@ -23,78 +22,107 @@ }) } -app = SWHServerAPIApp(__name__) - - -@app.errorhandler(Exception) -def my_error_handler(exception): - return error_handler(exception, encode_data) +@asyncio.coroutine +def index(request): + return aiohttp.web.Response(text="SWH Objstorage API server") -@app.before_request -def before_request(): - g.objstorage = get_objstorage(app.config['cls'], app.config['args']) +@asyncio.coroutine +def check_config(request): + req = yield from decode_request(request) + return encode_data(request.app['objstorage'].check_config(**req)) -@app.route('/') -def index(): - return "SWH Objstorage API server" +@asyncio.coroutine +def contains(request): + req = yield from decode_request(request) + return encode_data(request.app['objstorage'].__contains__(**req)) -@app.route('/check_config', methods=['POST']) -def check_config(): - return encode_data(g.objstorage.check_config(**decode_request(request))) +@asyncio.coroutine +def add_bytes(request): + req = yield from decode_request(request) + return encode_data(request.app['objstorage'].add(**req)) -@app.route('/content') -def content(): - return str(list(g.storage)) +@asyncio.coroutine +def get_bytes(request): + req = yield from decode_request(request) + return encode_data(request.app['objstorage'].get(**req)) -@app.route('/content/contains', methods=['POST']) -def contains(): - return encode_data(g.objstorage.__contains__(**decode_request(request))) +@asyncio.coroutine +def get_batch(request): + req = yield from decode_request(request) + return encode_data(request.app['objstorage'].get_batch(**req)) -@app.route('/content/add', methods=['POST']) -def add_bytes(): - return encode_data(g.objstorage.add(**decode_request(request))) +@asyncio.coroutine +def check(request): + req = yield from decode_request(request) + return encode_data(request.app['objstorage'].check(**req)) -@app.route('/content/get', methods=['POST']) -def get_bytes(): - return
encode_data(g.objstorage.get(**decode_request(request))) +# Management methods -@app.route('/content/get/batch', methods=['POST']) -def get_batch(): - return encode_data(g.objstorage.get_batch(**decode_request(request))) +@asyncio.coroutine +def get_random_contents(request): + req = yield from decode_request(request) + return encode_data(request.app['objstorage'].get_random(**req)) -@app.route('/content/get/random', methods=['POST']) -def get_random_contents(): - return encode_data( - g.objstorage.get_random(**decode_request(request)) - ) +# Streaming methods +@asyncio.coroutine +def add_stream(request): + hex_id = request.match_info['hex_id'] + obj_id = hashutil.hash_to_bytes(hex_id) + check_pres = (request.query.get('check_presence', '').lower() == 'true') + objstorage = request.app['objstorage'] -@app.route('/content/check', methods=['POST']) -def check(): - return encode_data(g.objstorage.check(**decode_request(request))) + if check_pres and obj_id in objstorage: + return encode_data(obj_id) + with objstorage.chunk_writer(obj_id) as write: + # XXX (3.5): use 'async for chunk in request.content.iter_any()' + while not request.content.at_eof(): + chunk = yield from request.content.readany() + write(chunk) -def run_from_webserver(environ, start_response): - """Run the WSGI app from the webserver, loading the configuration.
+ return encode_data(obj_id) - """ - config_path = '/etc/softwareheritage/storage/objstorage.yml' - app.config.update(config.read(config_path, DEFAULT_CONFIG)) +@asyncio.coroutine +def get_stream(request): + hex_id = request.match_info['hex_id'] + obj_id = hashutil.hash_to_bytes(hex_id) + # Get the iterator before prepare() so that errors raised eagerly by + # the backend (e.g. ObjNotFoundError) occur before headers are sent. + chunks = request.app['objstorage'].get_stream(obj_id, 2 << 20) + response = aiohttp.web.StreamResponse() + yield from response.prepare(request) + for chunk in chunks: + response.write(chunk) + yield from response.drain() + return response - handler = logging.StreamHandler() - app.logger.addHandler(handler) - return app(environ, start_response) +def make_app(config, **kwargs): + app = SWHRemoteAPI(**kwargs) + app.router.add_route('GET', '/', index) + app.router.add_route('POST', '/check_config', check_config) + app.router.add_route('POST', '/content/contains', contains) + app.router.add_route('POST', '/content/add', add_bytes) + app.router.add_route('POST', '/content/get', get_bytes) + app.router.add_route('POST', '/content/get/batch', get_batch) + app.router.add_route('POST', '/content/get/random', get_random_contents) + app.router.add_route('POST', '/content/check', check) + app.router.add_route('POST', '/content/add_stream/{hex_id}', add_stream) + app.router.add_route('GET', '/content/get_stream/{hex_id}', get_stream) + app.update(config) + app['objstorage'] = get_objstorage(app['cls'], app['args']) + return app @click.command() @@ -105,8 +133,8 @@ @click.option('--debug/--nodebug', default=True, help="Indicates if the server should run in debug mode") def launch(config_path, host, port, debug): - app.config.update(config.read(config_path, DEFAULT_CONFIG)) - app.run(host, port=int(port), debug=bool(debug)) + app = make_app(config.read(config_path, DEFAULT_CONFIG), debug=bool(debug)) + aiohttp.web.run_app(app, host=host, port=int(port)) if __name__ == '__main__': diff --git a/swh/objstorage/objstorage.py b/swh/objstorage/objstorage.py --- a/swh/objstorage/objstorage.py +++
b/swh/objstorage/objstorage.py @@ -12,6 +12,7 @@ ID_HASH_ALGO = 'sha1' ID_HASH_LENGTH = 40 # Size in bytes of the hash hexadecimal representation. +DEFAULT_CHUNK_SIZE = 2 * 1024 * 1024 # Size in bytes of the streaming chunks def compute_hash(content): @@ -38,6 +39,12 @@ - get_random() get random object id of existing contents (used for the content integrity checker). + Some of the methods have available streaming equivalents: + + - add_stream() same as add() but with a chunked iterator + - restore_stream() same as add_stream() but overwrites existing content + - get_stream() same as get() but returns a chunked iterator + Each implementation of this interface can have a different behavior and its own way to store the contents. """ @@ -90,7 +97,7 @@ def restore(self, content, obj_id=None, *args, **kwargs): """Restore a content that have been corrupted. - This function is identical to add_bytes but does not check if + This function is identical to add but does not check if the object id is already in the file system. The default implementation provided by the current class is suitable for most cases. @@ -164,6 +171,8 @@ """ pass + # Management methods + def get_random(self, batch_size, *args, **kwargs): """Get random ids of existing contents. @@ -179,3 +188,57 @@ """ pass + + # Streaming methods + + def add_stream(self, content_iter, obj_id, check_presence=True): + """Add a new object to the object storage using streaming. + + This function is identical to add() except it takes a generator that + yields the chunked content instead of the whole content at once. + + Args: + content_iter (iterator of bytes): chunks of the object's raw + content to add in storage. + obj_id (bytes): object identifier + check_presence (bool): indicate if the presence of the + content should be verified before adding the file. + + Returns: + the id (bytes) of the object in the storage.
+ + """ + raise NotImplementedError + + def restore_stream(self, content_iter, obj_id=None): + """Restore a content that has been corrupted, using streaming. + + This function is identical to restore() except it takes a generator + that yields the chunked content instead of the whole content at once. + The default implementation provided by the current class is + suitable for most cases. + + Args: + content_iter (iterator of bytes): chunks of the object's raw + content to add in storage. + obj_id (bytes): object identifier + + """ + # check_presence to false will erase the potential previous content. + return self.add_stream(content_iter, obj_id, check_presence=False) + + def get_stream(self, obj_id, chunk_size=DEFAULT_CHUNK_SIZE): + """Retrieve the content of a given object as a chunked iterator. + + Args: + obj_id (bytes): object id. + chunk_size (int): maximum size in bytes of each yielded chunk. + + Returns: + an iterator over the object's content, as chunks of bytes. + + Raises: + ObjNotFoundError: if the requested object is missing. + + """ + raise NotImplementedError diff --git a/swh/objstorage/objstorage_pathslicing.py b/swh/objstorage/objstorage_pathslicing.py --- a/swh/objstorage/objstorage_pathslicing.py +++ b/swh/objstorage/objstorage_pathslicing.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import functools import os import gzip import tempfile @@ -12,7 +13,8 @@ from swh.model import hashutil -from .objstorage import ObjStorage, compute_hash, ID_HASH_ALGO, ID_HASH_LENGTH +from .objstorage import (ObjStorage, compute_hash, ID_HASH_ALGO, + ID_HASH_LENGTH, DEFAULT_CHUNK_SIZE) from .exc import ObjNotFoundError, Error @@ -269,6 +271,8 @@ # IOError is for compatibility with older python versions raise Error('Corrupt object %s is not a gzip file' % obj_id) + # Management methods + def get_random(self, batch_size): def get_random_content(self, batch_size): """ Get a batch of content inside a single directory.
@@ -294,3 +298,36 @@ length, it = get_random_content(self, batch_size) batch_size = batch_size - length yield from it + + # Streaming methods + + @contextmanager + def chunk_writer(self, obj_id): + hex_obj_id = hashutil.hash_to_hex(obj_id) + with _write_obj_file(hex_obj_id, self) as f: + yield f.write + + def add_stream(self, content_iter, obj_id, check_presence=True): + if check_presence and obj_id in self: + return obj_id + + with self.chunk_writer(obj_id) as writer: + for chunk in content_iter: + writer(chunk) + + return obj_id + + def get_stream(self, obj_id, chunk_size=DEFAULT_CHUNK_SIZE): + # Check existence eagerly: in a generator body this would be deferred + # until the first next() call, hiding ObjNotFoundError from callers. + if obj_id not in self: + raise ObjNotFoundError(obj_id) + + hex_obj_id = hashutil.hash_to_hex(obj_id) + return self._get_stream_chunks(hex_obj_id, chunk_size) + + def _get_stream_chunks(self, hex_obj_id, chunk_size): + # Yield chunks of at most chunk_size bytes until read() returns b''. + with _read_obj_file(hex_obj_id, self) as f: + reader = functools.partial(f.read, chunk_size) + yield from iter(reader, b'') diff --git a/swh/objstorage/tests/objstorage_testing.py b/swh/objstorage/tests/objstorage_testing.py --- a/swh/objstorage/tests/objstorage_testing.py +++ b/swh/objstorage/tests/objstorage_testing.py @@ -3,6 +3,8 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import time + from nose.tools import istest from swh.model import hashutil @@ -98,3 +100,38 @@ self.storage.check(obj_id) except: self.fail('Integrity check failed') + + @istest + def add_stream(self): + content = [b'chunk1', b'chunk2'] + _, obj_id = self.hash_content(b''.join(content)) + try: + self.storage.add_stream(iter(content), obj_id=obj_id) + except NotImplementedError: + return + self.assertContentMatch(obj_id, b''.join(content)) + + @istest + def add_stream_sleep(self): + def gen_content(): + yield b'chunk1' + time.sleep(0.5) + yield b'chunk2' + _, obj_id = self.hash_content(b'placeholder_id') + try: + self.storage.add_stream(gen_content(), obj_id=obj_id) + except NotImplementedError: + return + self.assertContentMatch(obj_id, b'chunk1chunk2') + + @istest + def get_stream(self): + content_l = [b'1', b'2',
b'3', b'4', b'5', b'6', b'7', b'8', b'9'] + content = b''.join(content_l) + _, obj_id = self.hash_content(content) + self.storage.add(content, obj_id=obj_id) + try: + r = list(self.storage.get_stream(obj_id, chunk_size=1)) + except NotImplementedError: + return + self.assertEqual(r, content_l) diff --git a/swh/objstorage/tests/server_testing.py b/swh/objstorage/tests/server_testing.py --- a/swh/objstorage/tests/server_testing.py +++ b/swh/objstorage/tests/server_testing.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import aiohttp.web import multiprocessing import socket import time @@ -46,7 +47,7 @@ # WSGI app configuration for key, value in self.config.items(): - self.app.config[key] = value + self.app[key] = value # Get an available port number sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(('127.0.0.1', 0)) @@ -55,7 +56,7 @@ # Worker function for multiprocessing def worker(app, port): - return app.run(port=port, use_reloader=False) + return aiohttp.web.run_app(app, port=port) self.process = multiprocessing.Process( target=worker, args=(self.app, self.port) diff --git a/swh/objstorage/tests/test_objstorage_api.py b/swh/objstorage/tests/test_objstorage_api.py --- a/swh/objstorage/tests/test_objstorage_api.py +++ b/swh/objstorage/tests/test_objstorage_api.py @@ -11,7 +11,7 @@ from swh.objstorage import get_objstorage from swh.objstorage.tests.objstorage_testing import ObjStorageTestFixture from swh.objstorage.tests.server_testing import ServerTestFixture -from swh.objstorage.api.server import app +from swh.objstorage.api.server import make_app @attr('db') @@ -29,7 +29,7 @@ } } - self.app = app + self.app = make_app(self.config) super().setUp() self.storage = get_objstorage('remote', { 'url': self.url() diff --git a/swh/objstorage/tests/test_objstorage_instantiation.py b/swh/objstorage/tests/test_objstorage_instantiation.py --- 
a/swh/objstorage/tests/test_objstorage_instantiation.py +++ b/swh/objstorage/tests/test_objstorage_instantiation.py @@ -8,14 +8,12 @@ from nose.tools import istest -from swh.objstorage.tests.server_testing import ServerTestFixture from swh.objstorage import get_objstorage from swh.objstorage.objstorage_pathslicing import PathSlicingObjStorage from swh.objstorage.api.client import RemoteObjStorage -from swh.objstorage.api.server import app -class TestObjStorageInitialization(ServerTestFixture, unittest.TestCase): +class TestObjStorageInitialization(unittest.TestCase): """ Test that the methods for ObjStorage initializations with `get_objstorage` works properly. """ @@ -23,7 +21,6 @@ def setUp(self): self.path = tempfile.mkdtemp() # Server is launched at self.url() - self.app = app self.config = {'storage_base': tempfile.mkdtemp(), 'storage_slicing': '0:1/0:5'} super().setUp() @@ -42,7 +39,7 @@ conf = { 'cls': 'remote', 'args': { - 'url': self.url() + 'url': 'http://127.0.0.1:4242/' } } st = get_objstorage(**conf)