diff --git a/swh/storage/api/server.py b/swh/storage/api/server.py index 61d06d0b..80132fa0 100644 --- a/swh/storage/api/server.py +++ b/swh/storage/api/server.py @@ -1,461 +1,448 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import json import logging import click from flask import request from swh.core import config from swh.storage import get_storage as get_swhstorage from swh.core.api import (SWHServerAPIApp, decode_request, error_handler, encode_data_server as encode_data) -DEFAULT_CONFIG_PATH = 'storage/storage' -DEFAULT_CONFIG = { - 'storage': ('dict', { - 'cls': 'local', - 'args': { - 'db': 'dbname=softwareheritage-dev', - 'objstorage': { - 'cls': 'pathslicing', - 'args': { - 'root': '/srv/softwareheritage/objects', - 'slicing': '0:2/2:4/4:6', - }, - }, - }, - }) -} - - app = SWHServerAPIApp(__name__) storage = None @app.errorhandler(Exception) def my_error_handler(exception): return error_handler(exception, encode_data) def get_storage(): global storage if not storage: storage = get_swhstorage(**app.config['storage']) return storage @app.route('/') def index(): return ''' Software Heritage storage server

You have reached the Software Heritage storage server.
See its documentation and API for more information

''' @app.route('/check_config', methods=['POST']) def check_config(): return encode_data(get_storage().check_config(**decode_request(request))) @app.route('/content/missing', methods=['POST']) def content_missing(): return encode_data(get_storage().content_missing( **decode_request(request))) @app.route('/content/missing/sha1', methods=['POST']) def content_missing_per_sha1(): return encode_data(get_storage().content_missing_per_sha1( **decode_request(request))) @app.route('/content/present', methods=['POST']) def content_find(): return encode_data(get_storage().content_find(**decode_request(request))) @app.route('/content/add', methods=['POST']) def content_add(): return encode_data(get_storage().content_add(**decode_request(request))) @app.route('/content/update', methods=['POST']) def content_update(): return encode_data(get_storage().content_update(**decode_request(request))) @app.route('/content/data', methods=['POST']) def content_get(): return encode_data(get_storage().content_get(**decode_request(request))) @app.route('/content/metadata', methods=['POST']) def content_get_metadata(): return encode_data(get_storage().content_get_metadata( **decode_request(request))) @app.route('/content/range', methods=['POST']) def content_get_range(): return encode_data(get_storage().content_get_range( **decode_request(request))) @app.route('/directory/missing', methods=['POST']) def directory_missing(): return encode_data(get_storage().directory_missing( **decode_request(request))) @app.route('/directory/add', methods=['POST']) def directory_add(): return encode_data(get_storage().directory_add(**decode_request(request))) @app.route('/directory/path', methods=['POST']) def directory_entry_get_by_path(): return encode_data(get_storage().directory_entry_get_by_path( **decode_request(request))) @app.route('/directory/ls', methods=['GET']) def directory_ls(): dir = request.args['directory'].encode('utf-8', 'surrogateescape') rec = json.loads(request.args.get('recursive', 'False').lower()) return encode_data(get_storage().directory_ls(dir, recursive=rec)) @app.route('/revision/add', methods=['POST']) def revision_add(): return encode_data(get_storage().revision_add(**decode_request(request))) @app.route('/revision', methods=['POST']) def revision_get(): return encode_data(get_storage().revision_get(**decode_request(request))) @app.route('/revision/log', methods=['POST']) def revision_log(): return encode_data(get_storage().revision_log(**decode_request(request))) @app.route('/revision/shortlog', methods=['POST']) def revision_shortlog(): return encode_data(get_storage().revision_shortlog( **decode_request(request))) @app.route('/revision/missing', methods=['POST']) def revision_missing(): return encode_data(get_storage().revision_missing( **decode_request(request))) @app.route('/release/add', methods=['POST']) def release_add(): return encode_data(get_storage().release_add(**decode_request(request))) @app.route('/release', methods=['POST']) def release_get(): return encode_data(get_storage().release_get(**decode_request(request))) @app.route('/release/missing', methods=['POST']) def release_missing(): return encode_data(get_storage().release_missing( **decode_request(request))) @app.route('/object/find_by_sha1_git', methods=['POST']) def object_find_by_sha1_git(): return encode_data(get_storage().object_find_by_sha1_git( **decode_request(request))) @app.route('/snapshot/add', methods=['POST']) def snapshot_add(): return encode_data(get_storage().snapshot_add(**decode_request(request))) @app.route('/snapshot', methods=['POST']) def snapshot_get(): return encode_data(get_storage().snapshot_get(**decode_request(request))) @app.route('/snapshot/by_origin_visit', methods=['POST']) def snapshot_get_by_origin_visit(): return encode_data(get_storage().snapshot_get_by_origin_visit( **decode_request(request))) @app.route('/snapshot/latest', methods=['POST']) def snapshot_get_latest(): return encode_data(get_storage().snapshot_get_latest( **decode_request(request))) @app.route('/snapshot/count_branches', methods=['POST']) def snapshot_count_branches(): return encode_data(get_storage().snapshot_count_branches( **decode_request(request))) @app.route('/snapshot/get_branches', methods=['POST']) def snapshot_get_branches(): return encode_data(get_storage().snapshot_get_branches( **decode_request(request))) @app.route('/origin/get', methods=['POST']) def origin_get(): return encode_data(get_storage().origin_get(**decode_request(request))) @app.route('/origin/get_range', methods=['POST']) def origin_get_range(): return encode_data(get_storage().origin_get_range( **decode_request(request))) @app.route('/origin/search', methods=['POST']) def origin_search(): return encode_data(get_storage().origin_search(**decode_request(request))) @app.route('/origin/count', methods=['POST']) def origin_count(): return encode_data(get_storage().origin_count(**decode_request(request))) @app.route('/origin/add_multi', methods=['POST']) def origin_add(): return encode_data(get_storage().origin_add(**decode_request(request))) @app.route('/origin/add', methods=['POST']) def origin_add_one(): return encode_data(get_storage().origin_add_one(**decode_request(request))) @app.route('/origin/visit/get', methods=['POST']) def origin_visit_get(): return encode_data(get_storage().origin_visit_get( **decode_request(request))) @app.route('/origin/visit/getby', methods=['POST']) def origin_visit_get_by(): return encode_data( get_storage().origin_visit_get_by(**decode_request(request))) @app.route('/origin/visit/add', methods=['POST']) def origin_visit_add(): return encode_data(get_storage().origin_visit_add( **decode_request(request))) @app.route('/origin/visit/update', methods=['POST']) def origin_visit_update(): return encode_data(get_storage().origin_visit_update( **decode_request(request))) @app.route('/person', methods=['POST']) def person_get(): return encode_data(get_storage().person_get(**decode_request(request))) @app.route('/fetch_history', methods=['GET']) def fetch_history_get(): return encode_data(get_storage().fetch_history_get(request.args['id'])) @app.route('/fetch_history/start', methods=['POST']) def fetch_history_start(): return encode_data( get_storage().fetch_history_start(**decode_request(request))) @app.route('/fetch_history/end', methods=['POST']) def fetch_history_end(): return encode_data( get_storage().fetch_history_end(**decode_request(request))) @app.route('/entity/add', methods=['POST']) def entity_add(): return encode_data( get_storage().entity_add(**decode_request(request))) @app.route('/entity/get', methods=['POST']) def entity_get(): return encode_data( get_storage().entity_get(**decode_request(request))) @app.route('/entity', methods=['GET']) def entity_get_one(): return encode_data(get_storage().entity_get_one(request.args['uuid'])) @app.route('/entity/from_lister_metadata', methods=['POST']) def entity_from_lister_metadata(): return encode_data(get_storage().entity_get_from_lister_metadata( **decode_request(request))) @app.route('/tool/data', methods=['POST']) def tool_get(): return encode_data(get_storage().tool_get( **decode_request(request))) @app.route('/tool/add', methods=['POST']) def tool_add(): return encode_data(get_storage().tool_add( **decode_request(request))) @app.route('/origin/metadata/add', methods=['POST']) def origin_metadata_add(): return encode_data(get_storage().origin_metadata_add(**decode_request( request))) @app.route('/origin/metadata/get', methods=['POST']) def origin_metadata_get_by(): return encode_data(get_storage().origin_metadata_get_by(**decode_request( request))) @app.route('/provider/add', methods=['POST']) def metadata_provider_add(): return encode_data(get_storage().metadata_provider_add(**decode_request( request))) @app.route('/provider/get', methods=['POST']) def metadata_provider_get(): return encode_data(get_storage().metadata_provider_get(**decode_request( request))) @app.route('/provider/getby', methods=['POST']) def metadata_provider_get_by(): return encode_data(get_storage().metadata_provider_get_by(**decode_request( request))) @app.route('/stat/counters', methods=['GET']) def stat_counters(): return encode_data(get_storage().stat_counters()) @app.route('/algos/diff_directories', methods=['POST']) def diff_directories(): return encode_data(get_storage().diff_directories( **decode_request(request))) @app.route('/algos/diff_revisions', methods=['POST']) def diff_revisions(): return encode_data(get_storage().diff_revisions(**decode_request(request))) @app.route('/algos/diff_revision', methods=['POST']) def diff_revision(): return encode_data(get_storage().diff_revision(**decode_request(request))) api_cfg = None -def load_and_check_config(config_file, type='any'): - """Check minimal configuration is set or raise an error explanation. +def load_and_check_config(config_file, type='local'): + """Check the minimal configuration is set to run the api or raise an + error explanation. Args: config_file (str): Path to the configuration file to load - type (str): configuration type. For 'production' type, more + type (str): configuration type. For 'local' type, more checks are done. Raises: Error if the setup is not as expected Returns: configuration as a dict """ + if not config_file: + raise EnvironmentError('Configuration file must be defined') + if not os.path.exists(config_file): - raise ValueError('Configuration file %s does not exist.' % config_file) + raise EnvironmentError('Configuration file %s does not exist' % ( + config_file, )) - cfg = config.read(config_file, DEFAULT_CONFIG) + cfg = config.read(config_file) if 'storage' not in cfg: - raise ValueError("missing '%storage' configuration") + raise EnvironmentError("Missing '%storage' configuration") - if type == 'production': + if type == 'local': vcfg = cfg['storage'] - if vcfg['cls'] != 'local': + cls = vcfg.get('cls') + if cls != 'local': raise EnvironmentError( "The storage backend can only be started with a 'local' " - "configuration", err=True) + "configuration") args = vcfg['args'] for key in ('db', 'objstorage'): if not args.get(key): - raise ValueError( - "invalid configuration; missing %s config entry." % key) + raise EnvironmentError( + "Invalid configuration; missing '%s' config entry" % key) return cfg -def run_from_webserver(environ, start_response, - config_path=DEFAULT_CONFIG_PATH): +def make_app_from_configfile(environ, start_response): """Run the WSGI app from the webserver, loading the configuration from a configuration file. - SWH_CONFIG_FILENAME environment variables takes precedence over - the config_path provided to this function. + SWH_CONFIG_FILENAME environment variable defines the + configuration path to load. """ global api_cfg if not api_cfg: - config_file = environ.get('SWH_CONFIG_FILENAME', config_path) + config_file = environ.get('SWH_CONFIG_FILENAME') api_cfg = load_and_check_config(config_file) app.config.update(api_cfg) handler = logging.StreamHandler() app.logger.addHandler(handler) return app(environ, start_response) @click.command() @click.argument('config-path', required=1) @click.option('--host', default='0.0.0.0', help="Host to run the server") @click.option('--port', default=5002, type=click.INT, help="Binding port of the server") @click.option('--debug/--nodebug', default=True, help="Indicates if the server should run in debug mode") def launch(config_path, host, port, debug): api_cfg = load_and_check_config(config_path, type='any') app.config.update(api_cfg) app.run(host, port=int(port), debug=bool(debug)) if __name__ == '__main__': launch() diff --git a/swh/storage/tests/test_server.py b/swh/storage/tests/test_server.py new file mode 100644 index 00000000..534d85fa --- /dev/null +++ b/swh/storage/tests/test_server.py @@ -0,0 +1,119 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import pytest +import yaml + +from swh.storage.api.server import load_and_check_config + + +def test_load_and_check_config_no_configuration(): + """Inexistant configuration files raises""" + with pytest.raises(EnvironmentError) as e: + load_and_check_config(None) + + assert e.value.args[0] == 'Configuration file must be defined' + + config_path = '/some/inexistant/config.yml' + with pytest.raises(EnvironmentError) as e: + load_and_check_config(config_path) + + assert e.value.args[0] == 'Configuration file %s does not exist' % ( + config_path, ) + + +def test_load_and_check_config_wrong_configuration(tmpdir): + """Wrong configuration raises""" + config_path = tmpdir / 'config.yml' + config_path.write_text('something: useless', encoding='utf-8') + + with pytest.raises(EnvironmentError) as e: + load_and_check_config(config_path) + + assert e.value.args[0] == 'Missing \'%storage\' configuration' + + +def test_load_and_check_config_remote_config_local_type_raise(tmpdir): + """'local' configuration without 'local' storage raises""" + config_path = tmpdir / 'config.yml' + config = { + 'storage': { + 'cls': 'remote', + 'args': {} + } + } + + config_path.write_text(yaml.dump(config), encoding='utf-8') + with pytest.raises(EnvironmentError) as e: + load_and_check_config(config_path, type='local') + + assert ( + e.value.args[0] == + "The storage backend can only be started with a 'local' configuration" + ) + + +def test_load_and_check_config_local_incomplete_configuration(tmpdir): + """Incomplete 'local' configuration should raise""" + config_path = tmpdir / 'config.yml' + + config = { + 'storage': { + 'cls': 'local', + 'args': { + 'db': 'database', + 'objstorage': 'object_storage' + } + } + } + + import copy + for key in ('db', 'objstorage'): + c = copy.deepcopy(config) + c['storage']['args'].pop(key) + config_path.write_text(yaml.dump(c), encoding='utf-8') + with pytest.raises(EnvironmentError) as e: + load_and_check_config(config_path) + + assert ( + e.value.args[0] == + "Invalid configuration; missing '%s' config entry" % key + ) + + +def test_load_and_check_config_local_config_fine(tmpdir): + """'Remote configuration is fine""" + config_path = tmpdir / 'config.yml' + + config = { + 'storage': { + 'cls': 'local', + 'args': { + 'db': 'db', + 'objstorage': 'something', + } + } + } + + config_path.write_text(yaml.dump(config), encoding='utf-8') + cfg = load_and_check_config(config_path, type='local') + assert cfg == config + + +def test_load_and_check_config_remote_config_fine(tmpdir): + """'Remote configuration is fine""" + config_path = tmpdir / 'config.yml' + + config = { + 'storage': { + 'cls': 'remote', + 'args': {} + } + } + + config_path.write_text(yaml.dump(config), encoding='utf-8') + cfg = load_and_check_config(config_path, type='any') + + assert cfg == config