diff --git a/.gitignore b/.gitignore index 6d6899d..805c492 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,14 @@ *.pyc *.sw? *~ /.coverage /.coverage.* .eggs/ __pycache__ build dist swh.core.egg-info version.txt .tox .hypothesis +.mypy_cache/ diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..2d45f26 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,36 @@ +[mypy] +namespace_packages = True +warn_unused_ignores = True + + +# 3rd party libraries without stubs (yet) + +[mypy-aiohttp_utils.*] +ignore_missing_imports = True + +[mypy-arrow.*] +ignore_missing_imports = True + +[mypy-celery.*] +ignore_missing_imports = True + +[mypy-decorator.*] +ignore_missing_imports = True + +[mypy-deprecated.*] +ignore_missing_imports = True + +[mypy-msgpack.*] +ignore_missing_imports = True + +[mypy-pkg_resources.*] +ignore_missing_imports = True + +[mypy-psycopg2.*] +ignore_missing_imports = True + +[mypy-pytest.*] +ignore_missing_imports = True + +[mypy-systemd.*] +ignore_missing_imports = True diff --git a/swh/__init__.py b/swh/__init__.py index 69e3be5..de9df06 100644 --- a/swh/__init__.py +++ b/swh/__init__.py @@ -1 +1,4 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) +from typing import Iterable + +__path__ = __import__('pkgutil').extend_path(__path__, + __name__) # type: Iterable[str] diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index 7df9100..363c79f 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,347 +1,349 @@ # Copyright (C) 2015-2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import collections import functools import inspect import json import logging import pickle import requests import datetime +from typing import ClassVar, Optional, Type + from deprecated import deprecated from flask import Flask, Request, Response, request, abort from .serializers import (decode_response, encode_data_client as encode_data, msgpack_dumps, msgpack_loads, SWHJSONDecoder) from .negotiation import (Formatter as FormatterBase, Negotiator as NegotiatorBase, negotiate as _negotiate) logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, 'application/json') def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) class SWHJSONEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, (datetime.datetime, datetime.date)): return obj.isoformat() if isinstance(obj, datetime.timedelta): return str(obj) # Let the base class default method raise the TypeError return super().default(obj) class JSONFormatter(Formatter): format = 'json' mimetypes = ['application/json'] def render(self, obj): return json.dumps(obj, cls=SWHJSONEncoder) class MsgpackFormatter(Formatter): format = 'msgpack' mimetypes = ['application/x-msgpack'] def render(self, obj): return msgpack_dumps(obj) # base API classes class RemoteException(Exception): pass def remote_api_endpoint(path): def dec(f): f._endpoint_path = path return f return dec class APIError(Exception): """API Error""" def __str__(self): return ('An unexpected error occurred in the backend: {}' .format(self.args)) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get('backend_class', None) for base in bases: if backend_class is not None: break backend_class = getattr(base, 'backend_class', None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs( wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop('self') post_data.pop('cur', None) post_data.pop('db', None) # Send the request. return self.post(meth._endpoint_path, post_data) attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ - backend_class = None + backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" - api_exception = APIError + api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" def __init__(self, url, api_exception=None, timeout=None, chunk_size=4096, **kwargs): if api_exception: self.api_exception = api_exception base_url = url if url.endswith('/') else url + '/' self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get('max_retries', 3), pool_connections=kwargs.get('pool_connections', 20), pool_maxsize=kwargs.get('pool_maxsize', 100)) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return '%s%s' % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if 'chunk_size' in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. opts['stream'] = True if self.timeout and 'timeout' not in opts: opts['timeout'] = self.timeout try: return getattr(self.session, verb)( self._url(endpoint), **opts ) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (collections.Iterator, collections.Generator)): data = (encode_data(x) for x in data) else: data = encode_data(data) chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'post', endpoint, data=data, headers={'content-type': 'application/x-msgpack', 'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': return response.iter_content(chunk_size) else: return self._decode_response(response) post_stream = post def get(self, endpoint, **opts): chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'get', endpoint, headers={'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def _decode_response(self, response): if response.status_code == 404: return None if response.status_code == 500: data = decode_response(response) if 'exception_pickled' in data: raise pickle.loads(data['exception_pickled']) else: raise RemoteException(data['exception']) # XXX: this breaks language-independence and should be # replaced by proper unserialization if response.status_code == 400: raise pickle.loads(decode_response(response)) elif response.status_code != 200: raise RemoteException( "Unexpected status code for API request: %s (%s)" % ( response.status_code, response.content, ) ) return decode_response(response) def __repr__(self): return '<{} url={}>'.format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = 'utf-8' encoding_errors = 'surrogateescape' ENCODERS = { 'application/x-msgpack': msgpack_dumps, 'application/json': json.dumps, } def encode_data_server(data, content_type='application/x-msgpack'): encoded_data = ENCODERS[content_type](data) return Response( encoded_data, mimetype=content_type, ) def decode_request(request): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == 'application/x-msgpack': r = msgpack_loads(data) elif content_type == 'application/json': r = json.loads(data, cls=SWHJSONDecoder) else: raise ValueError('Wrong content type `%s` for API request' % content_type) return r def error_handler(exception, encoder): # XXX: this breaks language-independence and should be # replaced by proper serialization of errors logging.exception(exception) response = encoder(pickle.dumps(exception)) response.status_code = 400 return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Callable[[], backend_class] backend_factory: A function with no argument that returns an instance of `backend_class`.""" request_class = BytesRequest def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) if backend_class is not None: if backend_factory is None: raise TypeError('Missing argument backend_factory') for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route('/'+meth._endpoint_path, methods=['POST']) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) return encode_data_server(obj_meth(**decode_request(request))) @deprecated(version='0.0.64', reason='Use the RPCServerApp instead') class SWHServerAPIApp(RPCServerApp): pass @deprecated(version='0.0.64', reason='Use the MetaRPCClient instead') class MetaSWHRemoteAPI(MetaRPCClient): pass @deprecated(version='0.0.64', reason='Use the RPCClient instead') class SWHRemoteAPI(RPCClient): pass diff --git a/swh/core/api/negotiation.py b/swh/core/api/negotiation.py index 91d658d..de57742 100644 --- a/swh/core/api/negotiation.py +++ b/swh/core/api/negotiation.py @@ -1,152 +1,153 @@ # This code is a partial and adapted copy of # https://github.com/nickstenning/negotiate # # Copyright 2012-2013 Nick Stenning # 2019 The Software Heritage developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # from collections import defaultdict +from decorator import decorator from inspect import getcallargs -from decorator import decorator +from typing import Any, List, Optional class FormatterNotFound(Exception): pass class Formatter: - format = None - mimetypes = [] + format = None # type: Optional[str] + mimetypes = [] # type: List[Any] def __init__(self, request_mimetype=None): if request_mimetype is None or request_mimetype not in self.mimetypes: try: self.response_mimetype = self.mimetypes[0] except IndexError: raise NotImplementedError( "%s.mimetypes should be a non-empty list" % self.__class__.__name__) else: self.response_mimetype = request_mimetype def configure(self): pass def render(self, obj): raise NotImplementedError( "render() should be implemented by Formatter subclasses") def __call__(self, obj): return self._make_response( self.render(obj), content_type=self.response_mimetype) def _make_response(self, body, content_type): raise NotImplementedError( "_make_response() should be implemented by " "framework-specific subclasses of Formatter" ) class Negotiator: def __init__(self, func): self.func = func self._formatters = [] self._formatters_by_format = defaultdict(list) self._formatters_by_mimetype = defaultdict(list) def __call__(self, *args, **kwargs): result = self.func(*args, **kwargs) format = getcallargs(self.func, *args, **kwargs).get('format') mimetype = self.best_mimetype() try: formatter = self.get_formatter(format, mimetype) except FormatterNotFound as e: return self._abort(404, str(e)) return formatter(result) def register_formatter(self, formatter, *args, **kwargs): self._formatters.append(formatter) self._formatters_by_format[formatter.format].append( (formatter, args, kwargs)) for mimetype in formatter.mimetypes: self._formatters_by_mimetype[mimetype].append( (formatter, args, kwargs)) def get_formatter(self, format=None, mimetype=None): if format is None and mimetype is None: raise TypeError( "get_formatter expects one of the 'format' or 'mimetype' " "kwargs to be set") if format is not None: try: # the first added will be the most specific formatter_cls, args, kwargs = ( self._formatters_by_format[format][0]) except IndexError: raise FormatterNotFound( "Formatter for format '%s' not found!" % format) elif mimetype is not None: try: # the first added will be the most specific formatter_cls, args, kwargs = ( self._formatters_by_mimetype[mimetype][0]) except IndexError: raise FormatterNotFound( "Formatter for mimetype '%s' not found!" % mimetype) formatter = formatter_cls(request_mimetype=mimetype) formatter.configure(*args, **kwargs) return formatter @property def accept_mimetypes(self): return [m for f in self._formatters for m in f.mimetypes] def best_mimetype(self): raise NotImplementedError( "best_mimetype() should be implemented in " "framework-specific subclasses of Negotiator" ) def _abort(self, status_code, err=None): raise NotImplementedError( "_abort() should be implemented in framework-specific " "subclasses of Negotiator" ) def negotiate(negotiator_cls, formatter_cls, *args, **kwargs): def _negotiate(f, *args, **kwargs): return f.negotiator(*args, **kwargs) def decorate(f): if not hasattr(f, 'negotiator'): f.negotiator = negotiator_cls(f) f.negotiator.register_formatter(formatter_cls, *args, **kwargs) return decorator(_negotiate, f) return decorate diff --git a/swh/core/api/tests/test_api.py b/swh/core/api/tests/test_api.py index 32180f8..d414376 100644 --- a/swh/core/api/tests/test_api.py +++ b/swh/core/api/tests/test_api.py @@ -1,84 +1,84 @@ # Copyright (C) 2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import unittest -import requests_mock +import requests_mock # type: ignore from werkzeug.wrappers import BaseResponse from werkzeug.test import Client as WerkzeugTestClient from swh.core.api import ( error_handler, encode_data_server, remote_api_endpoint, RPCClient, RPCServerApp) class ApiTest(unittest.TestCase): def test_server(self): testcase = self nb_endpoint_calls = 0 class TestStorage: @remote_api_endpoint('test_endpoint_url') def test_endpoint(self, test_data, db=None, cur=None): nonlocal nb_endpoint_calls nb_endpoint_calls += 1 testcase.assertEqual(test_data, 'spam') return 'egg' app = RPCServerApp('testapp', backend_class=TestStorage, backend_factory=lambda: TestStorage()) @app.errorhandler(Exception) def my_error_handler(exception): return error_handler(exception, encode_data_server) client = WerkzeugTestClient(app, BaseResponse) res = client.post('/test_endpoint_url', headers={'Content-Type': 'application/x-msgpack'}, data=b'\x81\xa9test_data\xa4spam') self.assertEqual(nb_endpoint_calls, 1) self.assertEqual(b''.join(res.response), b'\xa3egg') def test_client(self): class TestStorage: @remote_api_endpoint('test_endpoint_url') def test_endpoint(self, test_data, db=None, cur=None): pass nb_http_calls = 0 def callback(request, context): nonlocal nb_http_calls nb_http_calls += 1 self.assertEqual(request.headers['Content-Type'], 'application/x-msgpack') self.assertEqual(request.body, b'\x81\xa9test_data\xa4spam') context.headers['Content-Type'] = 'application/x-msgpack' context.content = b'\xa3egg' return b'\xa3egg' adapter = requests_mock.Adapter() adapter.register_uri('POST', 'mock://example.com/test_endpoint_url', content=callback) class Testclient(RPCClient): backend_class = TestStorage def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # we need to mount the mock adapter on the base url to override # RPCClient's mechanism that also mounts an HTTPAdapter # (for configuration purpose) self.session.mount('mock://example.com/', adapter) c = Testclient(url='mock://example.com/') res = c.test_endpoint('spam') self.assertEqual(nb_http_calls, 1) self.assertEqual(res, 'egg') diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py index 96fec21..2de1ced 100644 --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -1,187 +1,186 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime -import json - import msgpack +import json import pytest from swh.core.api.asynchronous import RPCServerApp, Response from swh.core.api.asynchronous import encode_msgpack, decode_request from swh.core.api.serializers import msgpack_dumps, SWHJSONEncoder pytest_plugins = ['aiohttp.pytest_plugin', 'pytester'] async def root(request): return Response('toor') STRUCT = {'txt': 'something stupid', # 'date': datetime.date(2019, 6, 9), # not supported 'datetime': datetime.datetime(2019, 6, 9, 10, 12), 'timedelta': datetime.timedelta(days=-2, hours=3), 'int': 42, 'float': 3.14, 'subdata': {'int': 42, 'datetime': datetime.datetime(2019, 6, 10, 11, 12), }, 'list': [42, datetime.datetime(2019, 9, 10, 11, 12), 'ok'], } async def struct(request): return Response(STRUCT) async def echo(request): data = await decode_request(request) return Response(data) async def echo_no_nego(request): # let the content negotiation handle the serialization for us... data = await decode_request(request) ret = encode_msgpack(data) return ret def check_mimetype(src, dst): src = src.split(';')[0].strip() dst = dst.split(';')[0].strip() assert src == dst @pytest.fixture def app(): app = RPCServerApp() app.router.add_route('GET', '/', root) app.router.add_route('GET', '/struct', struct) app.router.add_route('POST', '/echo', echo) app.router.add_route('POST', '/echo-no-nego', echo_no_nego) return app async def test_get_simple(app, aiohttp_client) -> None: assert app is not None cli = await aiohttp_client(app) resp = await cli.get('/') assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') data = await resp.read() value = msgpack.unpackb(data, raw=False) assert value == 'toor' async def test_get_simple_nego(app, aiohttp_client) -> None: cli = await aiohttp_client(app) for ctype in ('x-msgpack', 'json'): resp = await cli.get('/', headers={'Accept': 'application/%s' % ctype}) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == 'toor' async def test_get_struct(app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(app) resp = await cli.get('/struct') assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(app) for ctype in ('x-msgpack', 'json'): resp = await cli.get('/struct', headers={'Accept': 'application/%s' % ctype}) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_msgpack(app, aiohttp_client) -> None: """Test that msgpack encoded posted struct data is returned as is""" cli = await aiohttp_client(app) # simple struct resp = await cli.post( '/echo', headers={'Content-Type': 'application/x-msgpack'}, data=msgpack_dumps({'toto': 42})) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == {'toto': 42} # complex struct resp = await cli.post( '/echo', headers={'Content-Type': 'application/x-msgpack'}, data=msgpack_dumps(STRUCT)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT async def test_post_struct_json(app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is""" cli = await aiohttp_client(app) resp = await cli.post( '/echo', headers={'Content-Type': 'application/json'}, data=json.dumps({'toto': 42}, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == {'toto': 42} resp = await cli.post( '/echo', headers={'Content-Type': 'application/json'}, data=json.dumps(STRUCT, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT async def test_post_struct_nego(app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is using content negotiation (accept json or msgpack). """ cli = await aiohttp_client(app) for ctype in ('x-msgpack', 'json'): resp = await cli.post( '/echo', headers={'Content-Type': 'application/json', 'Accept': 'application/%s' % ctype}, data=json.dumps(STRUCT, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_no_nego(app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as msgpack when using non-negotiation-compatible handlers. """ cli = await aiohttp_client(app) for ctype in ('x-msgpack', 'json'): resp = await cli.post( '/echo-no-nego', headers={'Content-Type': 'application/json', 'Accept': 'application/%s' % ctype}, data=json.dumps(STRUCT, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT diff --git a/swh/core/config.py b/swh/core/config.py index 316b27b..f5babfd 100644 --- a/swh/core/config.py +++ b/swh/core/config.py @@ -1,360 +1,362 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import configparser import logging import os import yaml from itertools import chain from copy import deepcopy +from typing import Any, Dict, Optional, Tuple + logger = logging.getLogger(__name__) SWH_CONFIG_DIRECTORIES = [ '~/.config/swh', '~/.swh', '/etc/softwareheritage', ] SWH_GLOBAL_CONFIG = 'global.ini' SWH_DEFAULT_GLOBAL_CONFIG = { 'content_size_limit': ('int', 100 * 1024 * 1024), 'log_db': ('str', 'dbname=softwareheritage-log'), } SWH_CONFIG_EXTENSIONS = [ '.yml', '.ini', ] # conversion per type _map_convert_fn = { 'int': int, 'bool': lambda x: x.lower() == 'true', 'list[str]': lambda x: [value.strip() for value in x.split(',')], 'list[int]': lambda x: [int(value.strip()) for value in x.split(',')], } _map_check_fn = { 'int': lambda x: isinstance(x, int), 'bool': lambda x: isinstance(x, bool), 'list[str]': lambda x: (isinstance(x, list) and all(isinstance(y, str) for y in x)), 'list[int]': lambda x: (isinstance(x, list) and all(isinstance(y, int) for y in x)), } def exists_accessible(file): """Check whether a file exists, and is accessible. Returns: True if the file exists and is accessible False if the file does not exist Raises: PermissionError if the file cannot be read. """ try: os.stat(file) except PermissionError: raise except FileNotFoundError: return False else: if os.access(file, os.R_OK): return True else: raise PermissionError("Permission denied: %r" % file) def config_basepath(config_path): """Return the base path of a configuration file""" if config_path.endswith(('.ini', '.yml')): return config_path[:-4] return config_path def read_raw_config(base_config_path): """Read the raw config corresponding to base_config_path. Can read yml or ini files. """ yml_file = base_config_path + '.yml' if exists_accessible(yml_file): logger.info('Loading config file %s', yml_file) with open(yml_file) as f: return yaml.safe_load(f) ini_file = base_config_path + '.ini' if exists_accessible(ini_file): config = configparser.ConfigParser() config.read(ini_file) if 'main' in config._sections: logger.info('Loading config file %s', ini_file) return config._sections['main'] else: logger.warning('Ignoring config file %s (no [main] section)', ini_file) return {} def config_exists(config_path): """Check whether the given config exists""" basepath = config_basepath(config_path) return any(exists_accessible(basepath + extension) for extension in SWH_CONFIG_EXTENSIONS) def read(conf_file=None, default_conf=None): """Read the user's configuration file. Fill in the gap using `default_conf`. `default_conf` is similar to this:: DEFAULT_CONF = { 'a': ('str', '/tmp/swh-loader-git/log'), 'b': ('str', 'dbname=swhloadergit') 'c': ('bool', true) 'e': ('bool', None) 'd': ('int', 10) } If conf_file is None, return the default config. """ conf = {} if conf_file: base_config_path = config_basepath(os.path.expanduser(conf_file)) conf = read_raw_config(base_config_path) if not default_conf: default_conf = {} # remaining missing default configuration key are set # also type conversion is enforced for underneath layer for key in default_conf: nature_type, default_value = default_conf[key] val = conf.get(key, None) if val is None: # fallback to default value conf[key] = default_value elif not _map_check_fn.get(nature_type, lambda x: True)(val): # value present but not in the proper format, force type conversion conf[key] = _map_convert_fn.get(nature_type, lambda x: x)(val) return conf def priority_read(conf_filenames, default_conf=None): """Try reading the configuration files from conf_filenames, in order, and return the configuration from the first one that exists. default_conf has the same specification as it does in read. """ # Try all the files in order for filename in conf_filenames: full_filename = os.path.expanduser(filename) if config_exists(full_filename): return read(full_filename, default_conf) # Else, return the default configuration return read(None, default_conf) def merge_default_configs(base_config, *other_configs): """Merge several default config dictionaries, from left to right""" full_config = base_config.copy() for config in other_configs: full_config.update(config) return full_config def merge_configs(base, other): """Merge two config dictionaries This does merge config dicts recursively, with the rules, for every value of the dicts (with 'val' not being a dict): - None + type -> type - type + None -> None - dict + dict -> dict (merged) - val + dict -> TypeError - dict + val -> TypeError - val + val -> val (other) for instance: >>> d1 = { ... 'key1': { ... 'skey1': 'value1', ... 'skey2': {'sskey1': 'value2'}, ... }, ... 'key2': 'value3', ... } with >>> d2 = { ... 'key1': { ... 'skey1': 'value4', ... 'skey2': {'sskey2': 'value5'}, ... }, ... 'key3': 'value6', ... } will give: >>> d3 = { ... 'key1': { ... 'skey1': 'value4', # <-- note this ... 'skey2': { ... 'sskey1': 'value2', ... 'sskey2': 'value5', ... }, ... }, ... 'key2': 'value3', ... 'key3': 'value6', ... } >>> assert merge_configs(d1, d2) == d3 Note that no type checking is done for anything but dicts. """ if not isinstance(base, dict) or not isinstance(other, dict): raise TypeError( 'Cannot merge a %s with a %s' % (type(base), type(other))) output = {} allkeys = set(chain(base.keys(), other.keys())) for k in allkeys: vb = base.get(k) vo = other.get(k) if isinstance(vo, dict): output[k] = merge_configs(vb is not None and vb or {}, vo) elif isinstance(vb, dict) and k in other and other[k] is not None: output[k] = merge_configs(vb, vo is not None and vo or {}) elif k in other: output[k] = deepcopy(vo) else: output[k] = deepcopy(vb) return output def swh_config_paths(base_filename): """Return the Software Heritage specific configuration paths for the given filename.""" return [os.path.join(dirname, base_filename) for dirname in SWH_CONFIG_DIRECTORIES] def prepare_folders(conf, *keys): """Prepare the folder mentioned in config under keys. """ def makedir(folder): if not os.path.exists(folder): os.makedirs(folder) for key in keys: makedir(conf[key]) def load_global_config(): """Load the global Software Heritage config""" return priority_read( swh_config_paths(SWH_GLOBAL_CONFIG), SWH_DEFAULT_GLOBAL_CONFIG, ) def load_named_config(name, default_conf=None, global_conf=True): """Load the config named `name` from the Software Heritage configuration paths. If global_conf is True (default), read the global configuration too. """ conf = {} if global_conf: conf.update(load_global_config()) conf.update(priority_read(swh_config_paths(name), default_conf)) return conf class SWHConfig: """Mixin to add configuration parsing abilities to classes The class should override the class attributes: - DEFAULT_CONFIG (default configuration to be parsed) - CONFIG_BASE_FILENAME (the filename of the configuration to be used) This class defines one classmethod, parse_config_file, which parses a configuration file using the default config as set in the class attribute. """ - DEFAULT_CONFIG = {} - CONFIG_BASE_FILENAME = '' + DEFAULT_CONFIG = {} # type: Dict[str, Tuple[str, Any]] + CONFIG_BASE_FILENAME = '' # type: Optional[str] @classmethod def parse_config_file(cls, base_filename=None, config_filename=None, additional_configs=None, global_config=True): """Parse the configuration file associated to the current class. By default, parse_config_file will load the configuration cls.CONFIG_BASE_FILENAME from one of the Software Heritage configuration directories, in order, unless it is overridden by base_filename or config_filename (which shortcuts the file lookup completely). Args: - base_filename (str): overrides the default cls.CONFIG_BASE_FILENAME - config_filename (str): sets the file to parse instead of the defaults set from cls.CONFIG_BASE_FILENAME - additional_configs: (list of default configuration dicts) allows to override or extend the configuration set in cls.DEFAULT_CONFIG. - global_config (bool): Load the global configuration (default: True) """ if config_filename: config_filenames = [config_filename] elif 'SWH_CONFIG_FILENAME' in os.environ: config_filenames = [os.environ['SWH_CONFIG_FILENAME']] else: if not base_filename: base_filename = cls.CONFIG_BASE_FILENAME config_filenames = swh_config_paths(base_filename) if not additional_configs: additional_configs = [] full_default_config = merge_default_configs(cls.DEFAULT_CONFIG, *additional_configs) config = {} if global_config: config = load_global_config() config.update(priority_read(config_filenames, full_default_config)) return config diff --git a/swh/core/db/tests/db_testing.py b/swh/core/db/tests/db_testing.py index 36eb914..c8bed92 100644 --- a/swh/core/db/tests/db_testing.py +++ b/swh/core/db/tests/db_testing.py @@ -1,317 +1,320 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import glob import subprocess import psycopg2 +from typing import Dict, Iterable, Optional, Tuple, Union + from swh.core.utils import numfile_sortkey as sortkey -DB_DUMP_TYPES = {'.sql': 'psql', '.dump': 'pg_dump'} + +DB_DUMP_TYPES = {'.sql': 'psql', '.dump': 'pg_dump'} # type: Dict[str, str] def swh_db_version(dbname_or_service): """Retrieve the swh version if any. In case of the db not initialized, this returns None. Otherwise, this returns the db's version. Args: dbname_or_service (str): The db's name or service Returns: Optional[Int]: Either the db's version or None """ query = 'select version from dbversion order by dbversion desc limit 1' cmd = [ 'psql', '--tuples-only', '--no-psqlrc', '--quiet', '-v', 'ON_ERROR_STOP=1', "--command=%s" % query, dbname_or_service ] try: r = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, universal_newlines=True) result = int(r.stdout.strip()) except Exception: # db not initialized result = None return result def pg_restore(dbname, dumpfile, dumptype='pg_dump'): """ Args: dbname: name of the DB to restore into dumpfile: path of the dump file dumptype: one of 'pg_dump' (for binary dumps), 'psql' (for SQL dumps) """ assert dumptype in ['pg_dump', 'psql'] if dumptype == 'pg_dump': subprocess.check_call(['pg_restore', '--no-owner', '--no-privileges', '--dbname', dbname, dumpfile]) elif dumptype == 'psql': subprocess.check_call(['psql', '--quiet', '--no-psqlrc', '-v', 'ON_ERROR_STOP=1', '-f', dumpfile, dbname]) def pg_dump(dbname, dumpfile): subprocess.check_call(['pg_dump', '--no-owner', '--no-privileges', '-Fc', '-f', dumpfile, dbname]) def pg_dropdb(dbname): subprocess.check_call(['dropdb', dbname]) def pg_createdb(dbname, check=True): """Create a db. If check is True and the db already exists, this will raise an exception (original behavior). If check is False and the db already exists, this will fail silently. If the db does not exist, the db will be created. """ subprocess.run(['createdb', dbname], check=check) def db_create(dbname, dumps=None): """create the test DB and load the test data dumps into it dumps is an iterable of couples (dump_file, dump_type). context: setUpClass """ try: pg_createdb(dbname) except subprocess.CalledProcessError: # try recovering once, in case pg_dropdb(dbname) # the db already existed pg_createdb(dbname) for dump, dtype in dumps: pg_restore(dbname, dump, dtype) return dbname def db_destroy(dbname): """destroy the test DB context: tearDownClass """ pg_dropdb(dbname) def db_connect(dbname): """connect to the test DB and open a cursor context: setUp """ conn = psycopg2.connect('dbname=' + dbname) return { 'conn': conn, 'cursor': conn.cursor() } def db_close(conn): """rollback current transaction and disconnect from the test DB context: tearDown """ if not conn.closed: conn.rollback() conn.close() class DbTestConn: def __init__(self, dbname): self.dbname = dbname def __enter__(self): self.db_setup = db_connect(self.dbname) self.conn = self.db_setup['conn'] self.cursor = self.db_setup['cursor'] return self def __exit__(self, *_): db_close(self.conn) class DbTestContext: def __init__(self, name='softwareheritage-test', dumps=None): self.dbname = name self.dumps = dumps def __enter__(self): db_create(dbname=self.dbname, dumps=self.dumps) return self def __exit__(self, *_): db_destroy(self.dbname) class DbTestFixture: """Mix this in a test subject class to get DB testing support. Use the class method add_db() to add a new database to be tested. Using this will create a DbTestConn entry in the `test_db` dictionary for all the tests, indexed by the name of the database. Example: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): cls.add_db('db_name', DUMP) super().setUpClass() def setUp(self): db = self.test_db['db_name'] print('conn: {}, cursor: {}'.format(db.conn, db.cursor)) To ensure test isolation, each test method of the test case class will execute in its own connection, cursor, and transaction. Note that if you want to define setup/teardown methods, you need to explicitly call super() to ensure that the fixture setup/teardown methods are invoked. Here is an example where all setup/teardown methods are defined in a test case: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): # your add_db() calls here super().setUpClass() # your class setup code here def setUp(self): super().setUp() # your instance setup code here def tearDown(self): # your instance teardown code here super().tearDown() @classmethod def tearDownClass(cls): # your class teardown code here super().tearDownClass() """ - _DB_DUMP_LIST = {} - _DB_LIST = {} + _DB_DUMP_LIST = {} # type: Dict[str, Iterable[Tuple[str, str]]] + _DB_LIST = {} # type: Dict[str, DbTestContext] DB_TEST_FIXTURE_IMPORTED = True @classmethod def add_db(cls, name='softwareheritage-test', dumps=None): cls._DB_DUMP_LIST[name] = dumps @classmethod def setUpClass(cls): for name, dumps in cls._DB_DUMP_LIST.items(): cls._DB_LIST[name] = DbTestContext(name, dumps) cls._DB_LIST[name].__enter__() super().setUpClass() @classmethod def tearDownClass(cls): super().tearDownClass() for name, context in cls._DB_LIST.items(): context.__exit__() def setUp(self, *args, **kwargs): self.test_db = {} for name in self._DB_LIST.keys(): self.test_db[name] = DbTestConn(name) self.test_db[name].__enter__() super().setUp(*args, **kwargs) def tearDown(self): super().tearDown() for name in self._DB_LIST.keys(): self.test_db[name].__exit__() def reset_db_tables(self, name, excluded=None): db = self.test_db[name] conn = db.conn cursor = db.cursor cursor.execute("""SELECT table_name FROM information_schema.tables WHERE table_schema = %s""", ('public',)) tables = set(table for (table,) in cursor.fetchall()) if excluded is not None: tables -= set(excluded) for table in tables: cursor.execute('truncate table %s cascade' % table) conn.commit() class SingleDbTestFixture(DbTestFixture): """Simplified fixture like DbTest but that can only handle a single DB. Gives access to shortcuts like self.cursor and self.conn. DO NOT use this with other fixtures that need to access databases, like StorageTestFixture. The class can override the following class attributes: TEST_DB_NAME: name of the DB used for testing TEST_DB_DUMP: DB dump to be restored before running test methods; can be set to None if no restore from dump is required. If the dump file name endswith" - '.sql' it will be loaded via psql, - '.dump' it will be loaded via pg_restore. Other file extensions will be ignored. Can be a string or a list of strings; each path will be expanded using glob pattern matching. The test case class will then have the following attributes, accessible via self: dbname: name of the test database conn: psycopg2 connection object cursor: open psycopg2 cursor to the DB """ TEST_DB_NAME = 'softwareheritage-test' - TEST_DB_DUMP = None + TEST_DB_DUMP = None # type: Optional[Union[str, Iterable[str]]] @classmethod def setUpClass(cls): cls.dbname = cls.TEST_DB_NAME # XXX to kill? dump_files = cls.TEST_DB_DUMP if dump_files is None: dump_files = [] elif isinstance(dump_files, str): dump_files = [dump_files] all_dump_files = [] for files in dump_files: all_dump_files.extend( sorted(glob.glob(files), key=sortkey)) all_dump_files = [(x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) for x in all_dump_files] cls.add_db(name=cls.TEST_DB_NAME, dumps=all_dump_files) super().setUpClass() def setUp(self, *args, **kwargs): super().setUp(*args, **kwargs) db = self.test_db[self.TEST_DB_NAME] self.conn = db.conn self.cursor = db.cursor diff --git a/swh/core/py.typed b/swh/core/py.typed new file mode 100644 index 0000000..1242d43 --- /dev/null +++ b/swh/core/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. diff --git a/tox.ini b/tox.ini index 586c09b..6909f8b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,44 +1,53 @@ [tox] -envlist=flake8,py3-{core,db,server} +envlist=flake8,py3-{core,db,server},mypy [testenv:py3-core] deps = -rrequirements-test.txt . commands = pytest --doctest-modules swh/core/tests {posargs} [testenv:py3-db] deps = -rrequirements-test.txt .[db] pifpaf commands = pifpaf run postgresql -- \ pytest swh/core/db/tests {posargs} [testenv:py3-server] deps = -rrequirements-test.txt .[http] commands = pytest swh/core/api/tests {posargs} [testenv:py3] deps = .[testing] pytest-cov pifpaf commands = pifpaf run postgresql -- \ pytest --doctest-modules \ --hypothesis-profile=slow \ --cov=swh --cov-branch \ {posargs} [testenv:flake8] skip_install = true deps = flake8 commands = {envpython} -m flake8 + +[testenv:mypy] +skip_install = true +deps = + .[testing] + mypy + django-stubs +commands = + mypy swh