diff --git a/PKG-INFO b/PKG-INFO index 7e35bae..bb75d9d 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,88 +1,88 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.0.67 +Version: 0.0.68 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN -Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate +Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. [1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Description-Content-Type: text/markdown -Provides-Extra: http Provides-Extra: db +Provides-Extra: http Provides-Extra: testing diff --git a/swh.core.egg-info/PKG-INFO b/swh.core.egg-info/PKG-INFO index 7e35bae..bb75d9d 100644 --- a/swh.core.egg-info/PKG-INFO +++ b/swh.core.egg-info/PKG-INFO @@ -1,88 +1,88 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.0.67 +Version: 0.0.68 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN -Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate +Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. [1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Description-Content-Type: text/markdown -Provides-Extra: http Provides-Extra: db +Provides-Extra: http Provides-Extra: testing diff --git a/swh.core.egg-info/SOURCES.txt b/swh.core.egg-info/SOURCES.txt index 8ec21be..1aafb65 100644 --- a/swh.core.egg-info/SOURCES.txt +++ b/swh.core.egg-info/SOURCES.txt @@ -1,48 +1,50 @@ MANIFEST.in Makefile README.md requirements-db.txt requirements-http.txt requirements-swh.txt requirements.txt setup.py version.txt swh/__init__.py swh.core.egg-info/PKG-INFO swh.core.egg-info/SOURCES.txt swh.core.egg-info/dependency_links.txt swh.core.egg-info/entry_points.txt swh.core.egg-info/requires.txt swh.core.egg-info/top_level.txt swh/core/__init__.py swh/core/api_async.py swh/core/config.py swh/core/logger.py swh/core/statsd.py swh/core/tarball.py swh/core/utils.py swh/core/api/__init__.py swh/core/api/asynchronous.py swh/core/api/negotiation.py swh/core/api/serializers.py swh/core/api/tests/__init__.py swh/core/api/tests/server_testing.py swh/core/api/tests/test_api.py swh/core/api/tests/test_async.py swh/core/api/tests/test_serializers.py swh/core/cli/__init__.py swh/core/cli/db.py swh/core/db/__init__.py swh/core/db/common.py swh/core/db/db_utils.py swh/core/db/tests/__init__.py swh/core/db/tests/conftest.py swh/core/db/tests/db_testing.py swh/core/db/tests/test_cli.py swh/core/db/tests/test_db.py swh/core/sql/log-schema.sql swh/core/tests/__init__.py swh/core/tests/test_cli.py swh/core/tests/test_config.py +swh/core/tests/test_logger.py swh/core/tests/test_statsd.py +swh/core/tests/test_tarball.py swh/core/tests/test_utils.py \ No newline at end of file diff --git a/swh.core.egg-info/requires.txt b/swh.core.egg-info/requires.txt index dbc39ec..55f8361 100644 --- a/swh.core.egg-info/requires.txt +++ b/swh.core.egg-info/requires.txt @@ -1,33 +1,34 @@ Deprecated PyYAML systemd-python [db] psycopg2 [http] aiohttp aiohttp_utils>=3.1.1 arrow decorator Flask msgpack>0.5 python-dateutil requests [testing] Click -pytest<4 +pytest pytest-postgresql requests-mock hypothesis>=3.11.0 pre-commit +pytz psycopg2 aiohttp aiohttp_utils>=3.1.1 arrow decorator Flask msgpack>0.5 python-dateutil requests diff --git a/swh/__init__.py b/swh/__init__.py index 69e3be5..de9df06 100644 --- a/swh/__init__.py +++ b/swh/__init__.py @@ -1 +1,4 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) +from typing import Iterable + +__path__ = __import__('pkgutil').extend_path(__path__, + __name__) # type: Iterable[str] diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index 7df9100..363c79f 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,347 +1,349 @@ # Copyright (C) 2015-2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import collections import functools import inspect import json import logging import pickle import requests import datetime +from typing import ClassVar, Optional, Type + from deprecated import deprecated from flask import Flask, Request, Response, request, abort from .serializers import (decode_response, encode_data_client as encode_data, msgpack_dumps, msgpack_loads, SWHJSONDecoder) from .negotiation import (Formatter as FormatterBase, Negotiator as NegotiatorBase, negotiate as _negotiate) logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, 'application/json') def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) class SWHJSONEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, (datetime.datetime, datetime.date)): return obj.isoformat() if isinstance(obj, datetime.timedelta): return str(obj) # Let the base class default method raise the TypeError return super().default(obj) class JSONFormatter(Formatter): format = 'json' mimetypes = ['application/json'] def render(self, obj): return json.dumps(obj, cls=SWHJSONEncoder) class MsgpackFormatter(Formatter): format = 'msgpack' mimetypes = ['application/x-msgpack'] def render(self, obj): return msgpack_dumps(obj) # base API classes class RemoteException(Exception): pass def remote_api_endpoint(path): def dec(f): f._endpoint_path = path return f return dec class APIError(Exception): """API Error""" def __str__(self): return ('An unexpected error occurred in the backend: {}' .format(self.args)) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get('backend_class', None) for base in bases: if backend_class is not None: break backend_class = getattr(base, 'backend_class', None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs( wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop('self') post_data.pop('cur', None) post_data.pop('db', None) # Send the request. return self.post(meth._endpoint_path, post_data) attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ - backend_class = None + backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" - api_exception = APIError + api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" def __init__(self, url, api_exception=None, timeout=None, chunk_size=4096, **kwargs): if api_exception: self.api_exception = api_exception base_url = url if url.endswith('/') else url + '/' self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get('max_retries', 3), pool_connections=kwargs.get('pool_connections', 20), pool_maxsize=kwargs.get('pool_maxsize', 100)) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return '%s%s' % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if 'chunk_size' in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. opts['stream'] = True if self.timeout and 'timeout' not in opts: opts['timeout'] = self.timeout try: return getattr(self.session, verb)( self._url(endpoint), **opts ) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (collections.Iterator, collections.Generator)): data = (encode_data(x) for x in data) else: data = encode_data(data) chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'post', endpoint, data=data, headers={'content-type': 'application/x-msgpack', 'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': return response.iter_content(chunk_size) else: return self._decode_response(response) post_stream = post def get(self, endpoint, **opts): chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'get', endpoint, headers={'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def _decode_response(self, response): if response.status_code == 404: return None if response.status_code == 500: data = decode_response(response) if 'exception_pickled' in data: raise pickle.loads(data['exception_pickled']) else: raise RemoteException(data['exception']) # XXX: this breaks language-independence and should be # replaced by proper unserialization if response.status_code == 400: raise pickle.loads(decode_response(response)) elif response.status_code != 200: raise RemoteException( "Unexpected status code for API request: %s (%s)" % ( response.status_code, response.content, ) ) return decode_response(response) def __repr__(self): return '<{} url={}>'.format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = 'utf-8' encoding_errors = 'surrogateescape' ENCODERS = { 'application/x-msgpack': msgpack_dumps, 'application/json': json.dumps, } def encode_data_server(data, content_type='application/x-msgpack'): encoded_data = ENCODERS[content_type](data) return Response( encoded_data, mimetype=content_type, ) def decode_request(request): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == 'application/x-msgpack': r = msgpack_loads(data) elif content_type == 'application/json': r = json.loads(data, cls=SWHJSONDecoder) else: raise ValueError('Wrong content type `%s` for API request' % content_type) return r def error_handler(exception, encoder): # XXX: this breaks language-independence and should be # replaced by proper serialization of errors logging.exception(exception) response = encoder(pickle.dumps(exception)) response.status_code = 400 return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Callable[[], backend_class] backend_factory: A function with no argument that returns an instance of `backend_class`.""" request_class = BytesRequest def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) if backend_class is not None: if backend_factory is None: raise TypeError('Missing argument backend_factory') for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route('/'+meth._endpoint_path, methods=['POST']) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) return encode_data_server(obj_meth(**decode_request(request))) @deprecated(version='0.0.64', reason='Use the RPCServerApp instead') class SWHServerAPIApp(RPCServerApp): pass @deprecated(version='0.0.64', reason='Use the MetaRPCClient instead') class MetaSWHRemoteAPI(MetaRPCClient): pass @deprecated(version='0.0.64', reason='Use the RPCClient instead') class SWHRemoteAPI(RPCClient): pass diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py index 761bce5..f38b913 100644 --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -1,96 +1,88 @@ import json import logging import pickle import sys import traceback from collections import OrderedDict import multidict import aiohttp.web from deprecated import deprecated from .serializers import msgpack_dumps, msgpack_loads from .serializers import SWHJSONDecoder, SWHJSONEncoder -try: - from aiohttp_utils import negotiation, Response -except ImportError: - from aiohttp import Response - negotiation = None +from aiohttp_utils import negotiation, Response def encode_msgpack(data, **kwargs): return aiohttp.web.Response( body=msgpack_dumps(data), headers=multidict.MultiDict( {'Content-Type': 'application/x-msgpack'}), **kwargs ) -if negotiation is None: - encode_data_server = encode_msgpack -else: - encode_data_server = Response +encode_data_server = Response def render_msgpack(request, data): return msgpack_dumps(data) def render_json(request, data): return json.dumps(data, cls=SWHJSONEncoder) async def decode_request(request): content_type = request.headers.get('Content-Type').split(';')[0].strip() data = await request.read() if not data: return {} if content_type == 'application/x-msgpack': r = msgpack_loads(data) elif content_type == 'application/json': r = json.loads(data.decode(), cls=SWHJSONDecoder) else: raise ValueError('Wrong content type `%s` for API request' % content_type) return r async def error_middleware(app, handler): async def middleware_handler(request): try: return await handler(request) except Exception as e: if isinstance(e, aiohttp.web.HTTPException): raise logging.exception(e) exception = traceback.format_exception(*sys.exc_info()) res = {'exception': exception, 'exception_pickled': pickle.dumps(e)} return encode_data_server(res, status=500) return middleware_handler class RPCServerApp(aiohttp.web.Application): def __init__(self, *args, middlewares=(), **kwargs): middlewares = (error_middleware,) + middlewares - if negotiation: - # renderers are sorted in order of increasing desirability (!) - # see mimeparse.best_match() docstring. - renderers = OrderedDict([ - ('application/json', render_json), - ('application/x-msgpack', render_msgpack), - ]) - nego_middleware = negotiation.negotiation_middleware( - renderers=renderers, - force_rendering=True) - middlewares = (nego_middleware,) + middlewares + # renderers are sorted in order of increasing desirability (!) + # see mimeparse.best_match() docstring. + renderers = OrderedDict([ + ('application/json', render_json), + ('application/x-msgpack', render_msgpack), + ]) + nego_middleware = negotiation.negotiation_middleware( + renderers=renderers, + force_rendering=True) + middlewares = (nego_middleware,) + middlewares super().__init__(*args, middlewares=middlewares, **kwargs) @deprecated(version='0.0.64', reason='Use the RPCServerApp instead') class SWHRemoteAPI(RPCServerApp): pass diff --git a/swh/core/api/negotiation.py b/swh/core/api/negotiation.py index 91d658d..de57742 100644 --- a/swh/core/api/negotiation.py +++ b/swh/core/api/negotiation.py @@ -1,152 +1,153 @@ # This code is a partial and adapted copy of # https://github.com/nickstenning/negotiate # # Copyright 2012-2013 Nick Stenning # 2019 The Software Heritage developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # from collections import defaultdict +from decorator import decorator from inspect import getcallargs -from decorator import decorator +from typing import Any, List, Optional class FormatterNotFound(Exception): pass class Formatter: - format = None - mimetypes = [] + format = None # type: Optional[str] + mimetypes = [] # type: List[Any] def __init__(self, request_mimetype=None): if request_mimetype is None or request_mimetype not in self.mimetypes: try: self.response_mimetype = self.mimetypes[0] except IndexError: raise NotImplementedError( "%s.mimetypes should be a non-empty list" % self.__class__.__name__) else: self.response_mimetype = request_mimetype def configure(self): pass def render(self, obj): raise NotImplementedError( "render() should be implemented by Formatter subclasses") def __call__(self, obj): return self._make_response( self.render(obj), content_type=self.response_mimetype) def _make_response(self, body, content_type): raise NotImplementedError( "_make_response() should be implemented by " "framework-specific subclasses of Formatter" ) class Negotiator: def __init__(self, func): self.func = func self._formatters = [] self._formatters_by_format = defaultdict(list) self._formatters_by_mimetype = defaultdict(list) def __call__(self, *args, **kwargs): result = self.func(*args, **kwargs) format = getcallargs(self.func, *args, **kwargs).get('format') mimetype = self.best_mimetype() try: formatter = self.get_formatter(format, mimetype) except FormatterNotFound as e: return self._abort(404, str(e)) return formatter(result) def register_formatter(self, formatter, *args, **kwargs): self._formatters.append(formatter) self._formatters_by_format[formatter.format].append( (formatter, args, kwargs)) for mimetype in formatter.mimetypes: self._formatters_by_mimetype[mimetype].append( (formatter, args, kwargs)) def get_formatter(self, format=None, mimetype=None): if format is None and mimetype is None: raise TypeError( "get_formatter expects one of the 'format' or 'mimetype' " "kwargs to be set") if format is not None: try: # the first added will be the most specific formatter_cls, args, kwargs = ( self._formatters_by_format[format][0]) except IndexError: raise FormatterNotFound( "Formatter for format '%s' not found!" % format) elif mimetype is not None: try: # the first added will be the most specific formatter_cls, args, kwargs = ( self._formatters_by_mimetype[mimetype][0]) except IndexError: raise FormatterNotFound( "Formatter for mimetype '%s' not found!" % mimetype) formatter = formatter_cls(request_mimetype=mimetype) formatter.configure(*args, **kwargs) return formatter @property def accept_mimetypes(self): return [m for f in self._formatters for m in f.mimetypes] def best_mimetype(self): raise NotImplementedError( "best_mimetype() should be implemented in " "framework-specific subclasses of Negotiator" ) def _abort(self, status_code, err=None): raise NotImplementedError( "_abort() should be implemented in framework-specific " "subclasses of Negotiator" ) def negotiate(negotiator_cls, formatter_cls, *args, **kwargs): def _negotiate(f, *args, **kwargs): return f.negotiator(*args, **kwargs) def decorate(f): if not hasattr(f, 'negotiator'): f.negotiator = negotiator_cls(f) f.negotiator.register_formatter(formatter_cls, *args, **kwargs) return decorator(_negotiate, f) return decorate diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py index 96fec21..2de1ced 100644 --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -1,187 +1,186 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime -import json - import msgpack +import json import pytest from swh.core.api.asynchronous import RPCServerApp, Response from swh.core.api.asynchronous import encode_msgpack, decode_request from swh.core.api.serializers import msgpack_dumps, SWHJSONEncoder pytest_plugins = ['aiohttp.pytest_plugin', 'pytester'] async def root(request): return Response('toor') STRUCT = {'txt': 'something stupid', # 'date': datetime.date(2019, 6, 9), # not supported 'datetime': datetime.datetime(2019, 6, 9, 10, 12), 'timedelta': datetime.timedelta(days=-2, hours=3), 'int': 42, 'float': 3.14, 'subdata': {'int': 42, 'datetime': datetime.datetime(2019, 6, 10, 11, 12), }, 'list': [42, datetime.datetime(2019, 9, 10, 11, 12), 'ok'], } async def struct(request): return Response(STRUCT) async def echo(request): data = await decode_request(request) return Response(data) async def echo_no_nego(request): # let the content negotiation handle the serialization for us... data = await decode_request(request) ret = encode_msgpack(data) return ret def check_mimetype(src, dst): src = src.split(';')[0].strip() dst = dst.split(';')[0].strip() assert src == dst @pytest.fixture def app(): app = RPCServerApp() app.router.add_route('GET', '/', root) app.router.add_route('GET', '/struct', struct) app.router.add_route('POST', '/echo', echo) app.router.add_route('POST', '/echo-no-nego', echo_no_nego) return app async def test_get_simple(app, aiohttp_client) -> None: assert app is not None cli = await aiohttp_client(app) resp = await cli.get('/') assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') data = await resp.read() value = msgpack.unpackb(data, raw=False) assert value == 'toor' async def test_get_simple_nego(app, aiohttp_client) -> None: cli = await aiohttp_client(app) for ctype in ('x-msgpack', 'json'): resp = await cli.get('/', headers={'Accept': 'application/%s' % ctype}) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == 'toor' async def test_get_struct(app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(app) resp = await cli.get('/struct') assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(app) for ctype in ('x-msgpack', 'json'): resp = await cli.get('/struct', headers={'Accept': 'application/%s' % ctype}) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_msgpack(app, aiohttp_client) -> None: """Test that msgpack encoded posted struct data is returned as is""" cli = await aiohttp_client(app) # simple struct resp = await cli.post( '/echo', headers={'Content-Type': 'application/x-msgpack'}, data=msgpack_dumps({'toto': 42})) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == {'toto': 42} # complex struct resp = await cli.post( '/echo', headers={'Content-Type': 'application/x-msgpack'}, data=msgpack_dumps(STRUCT)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT async def test_post_struct_json(app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is""" cli = await aiohttp_client(app) resp = await cli.post( '/echo', headers={'Content-Type': 'application/json'}, data=json.dumps({'toto': 42}, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == {'toto': 42} resp = await cli.post( '/echo', headers={'Content-Type': 'application/json'}, data=json.dumps(STRUCT, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT async def test_post_struct_nego(app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is using content negotiation (accept json or msgpack). """ cli = await aiohttp_client(app) for ctype in ('x-msgpack', 'json'): resp = await cli.post( '/echo', headers={'Content-Type': 'application/json', 'Accept': 'application/%s' % ctype}, data=json.dumps(STRUCT, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_no_nego(app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as msgpack when using non-negotiation-compatible handlers. """ cli = await aiohttp_client(app) for ctype in ('x-msgpack', 'json'): resp = await cli.post( '/echo-no-nego', headers={'Content-Type': 'application/json', 'Accept': 'application/%s' % ctype}, data=json.dumps(STRUCT, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT diff --git a/swh/core/config.py b/swh/core/config.py index e234210..f5babfd 100644 --- a/swh/core/config.py +++ b/swh/core/config.py @@ -1,360 +1,362 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import configparser import logging import os import yaml from itertools import chain from copy import deepcopy +from typing import Any, Dict, Optional, Tuple + logger = logging.getLogger(__name__) SWH_CONFIG_DIRECTORIES = [ '~/.config/swh', '~/.swh', '/etc/softwareheritage', ] SWH_GLOBAL_CONFIG = 'global.ini' SWH_DEFAULT_GLOBAL_CONFIG = { 'content_size_limit': ('int', 100 * 1024 * 1024), 'log_db': ('str', 'dbname=softwareheritage-log'), } SWH_CONFIG_EXTENSIONS = [ '.yml', '.ini', ] # conversion per type _map_convert_fn = { 'int': int, 'bool': lambda x: x.lower() == 'true', 'list[str]': lambda x: [value.strip() for value in x.split(',')], 'list[int]': lambda x: [int(value.strip()) for value in x.split(',')], } _map_check_fn = { 'int': lambda x: isinstance(x, int), 'bool': lambda x: isinstance(x, bool), 'list[str]': lambda x: (isinstance(x, list) and all(isinstance(y, str) for y in x)), 'list[int]': lambda x: (isinstance(x, list) and all(isinstance(y, int) for y in x)), } def exists_accessible(file): """Check whether a file exists, and is accessible. Returns: True if the file exists and is accessible False if the file does not exist Raises: PermissionError if the file cannot be read. """ try: os.stat(file) except PermissionError: raise except FileNotFoundError: return False else: if os.access(file, os.R_OK): return True else: raise PermissionError("Permission denied: %r" % file) def config_basepath(config_path): """Return the base path of a configuration file""" if config_path.endswith(('.ini', '.yml')): return config_path[:-4] return config_path def read_raw_config(base_config_path): """Read the raw config corresponding to base_config_path. Can read yml or ini files. """ yml_file = base_config_path + '.yml' if exists_accessible(yml_file): logger.info('Loading config file %s', yml_file) with open(yml_file) as f: return yaml.safe_load(f) ini_file = base_config_path + '.ini' if exists_accessible(ini_file): config = configparser.ConfigParser() config.read(ini_file) if 'main' in config._sections: logger.info('Loading config file %s', ini_file) return config._sections['main'] else: logger.warning('Ignoring config file %s (no [main] section)', ini_file) return {} def config_exists(config_path): """Check whether the given config exists""" basepath = config_basepath(config_path) return any(exists_accessible(basepath + extension) for extension in SWH_CONFIG_EXTENSIONS) def read(conf_file=None, default_conf=None): """Read the user's configuration file. Fill in the gap using `default_conf`. `default_conf` is similar to this:: DEFAULT_CONF = { 'a': ('str', '/tmp/swh-loader-git/log'), 'b': ('str', 'dbname=swhloadergit') 'c': ('bool', true) 'e': ('bool', None) 'd': ('int', 10) } If conf_file is None, return the default config. """ conf = {} if conf_file: base_config_path = config_basepath(os.path.expanduser(conf_file)) conf = read_raw_config(base_config_path) if not default_conf: default_conf = {} # remaining missing default configuration key are set # also type conversion is enforced for underneath layer for key in default_conf: nature_type, default_value = default_conf[key] val = conf.get(key, None) if val is None: # fallback to default value conf[key] = default_value elif not _map_check_fn.get(nature_type, lambda x: True)(val): # value present but not in the proper format, force type conversion conf[key] = _map_convert_fn.get(nature_type, lambda x: x)(val) return conf def priority_read(conf_filenames, default_conf=None): """Try reading the configuration files from conf_filenames, in order, and return the configuration from the first one that exists. default_conf has the same specification as it does in read. """ # Try all the files in order for filename in conf_filenames: full_filename = os.path.expanduser(filename) if config_exists(full_filename): return read(full_filename, default_conf) # Else, return the default configuration return read(None, default_conf) def merge_default_configs(base_config, *other_configs): """Merge several default config dictionaries, from left to right""" full_config = base_config.copy() for config in other_configs: full_config.update(config) return full_config def merge_configs(base, other): """Merge two config dictionaries This does merge config dicts recursively, with the rules, for every value of the dicts (with 'val' not being a dict): - None + type -> type - type + None -> None - dict + dict -> dict (merged) - val + dict -> TypeError - dict + val -> TypeError - val + val -> val (other) for instance: >>> d1 = { ... 'key1': { ... 'skey1': 'value1', ... 'skey2': {'sskey1': 'value2'}, ... }, ... 'key2': 'value3', ... } with >>> d2 = { ... 'key1': { ... 'skey1': 'value4', ... 'skey2': {'sskey2': 'value5'}, ... }, ... 'key3': 'value6', ... } will give: >>> d3 = { ... 'key1': { ... 'skey1': 'value4', # <-- note this ... 'skey2': { ... 'sskey1': 'value2', ... 'sskey2': 'value5', ... }, ... }, ... 'key2': 'value3', ... 'key3': 'value6', ... } >>> assert merge_configs(d1, d2) == d3 Note that no type checking is done for anything but dicts. """ if not isinstance(base, dict) or not isinstance(other, dict): raise TypeError( 'Cannot merge a %s with a %s' % (type(base), type(other))) output = {} allkeys = set(chain(base.keys(), other.keys())) for k in allkeys: vb = base.get(k) vo = other.get(k) if isinstance(vo, dict): output[k] = merge_configs(vb is not None and vb or {}, vo) elif isinstance(vb, dict) and k in other and other[k] is not None: output[k] = merge_configs(vb, vo is not None and vo or {}) elif k in other: output[k] = deepcopy(vo) else: output[k] = deepcopy(vb) return output def swh_config_paths(base_filename): """Return the Software Heritage specific configuration paths for the given filename.""" return [os.path.join(dirname, base_filename) for dirname in SWH_CONFIG_DIRECTORIES] def prepare_folders(conf, *keys): """Prepare the folder mentioned in config under keys. """ def makedir(folder): if not os.path.exists(folder): os.makedirs(folder) for key in keys: makedir(conf[key]) def load_global_config(): """Load the global Software Heritage config""" return priority_read( swh_config_paths(SWH_GLOBAL_CONFIG), SWH_DEFAULT_GLOBAL_CONFIG, ) def load_named_config(name, default_conf=None, global_conf=True): """Load the config named `name` from the Software Heritage configuration paths. If global_conf is True (default), read the global configuration too. """ conf = {} if global_conf: conf.update(load_global_config()) conf.update(priority_read(swh_config_paths(name), default_conf)) return conf class SWHConfig: """Mixin to add configuration parsing abilities to classes The class should override the class attributes: - DEFAULT_CONFIG (default configuration to be parsed) - CONFIG_BASE_FILENAME (the filename of the configuration to be used) This class defines one classmethod, parse_config_file, which parses a configuration file using the default config as set in the class attribute. """ - DEFAULT_CONFIG = {} - CONFIG_BASE_FILENAME = '' + DEFAULT_CONFIG = {} # type: Dict[str, Tuple[str, Any]] + CONFIG_BASE_FILENAME = '' # type: Optional[str] @classmethod def parse_config_file(cls, base_filename=None, config_filename=None, additional_configs=None, global_config=True): """Parse the configuration file associated to the current class. By default, parse_config_file will load the configuration cls.CONFIG_BASE_FILENAME from one of the Software Heritage configuration directories, in order, unless it is overridden by base_filename or config_filename (which shortcuts the file lookup completely). Args: - - base_filename (str) overrides the default - cls.CONFIG_BASE_FILENAME - - config_filename (str) sets the file to parse instead of - the defaults set from cls.CONFIG_BASE_FILENAME - - additional_configs (list of default configuration dicts) - allows to override or extend the configuration set in - cls.DEFAULT_CONFIG. + - base_filename (str): overrides the default + cls.CONFIG_BASE_FILENAME + - config_filename (str): sets the file to parse instead of + the defaults set from cls.CONFIG_BASE_FILENAME + - additional_configs: (list of default configuration dicts) + allows to override or extend the configuration set in + cls.DEFAULT_CONFIG. - global_config (bool): Load the global configuration (default: - True) + True) """ if config_filename: config_filenames = [config_filename] elif 'SWH_CONFIG_FILENAME' in os.environ: config_filenames = [os.environ['SWH_CONFIG_FILENAME']] else: if not base_filename: base_filename = cls.CONFIG_BASE_FILENAME config_filenames = swh_config_paths(base_filename) if not additional_configs: additional_configs = [] full_default_config = merge_default_configs(cls.DEFAULT_CONFIG, *additional_configs) config = {} if global_config: config = load_global_config() config.update(priority_read(config_filenames, full_default_config)) return config diff --git a/swh/core/db/tests/db_testing.py b/swh/core/db/tests/db_testing.py index 63cbcaf..c8bed92 100644 --- a/swh/core/db/tests/db_testing.py +++ b/swh/core/db/tests/db_testing.py @@ -1,315 +1,320 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import glob import subprocess import psycopg2 +from typing import Dict, Iterable, Optional, Tuple, Union + from swh.core.utils import numfile_sortkey as sortkey -DB_DUMP_TYPES = {'.sql': 'psql', '.dump': 'pg_dump'} + +DB_DUMP_TYPES = {'.sql': 'psql', '.dump': 'pg_dump'} # type: Dict[str, str] def swh_db_version(dbname_or_service): """Retrieve the swh version if any. In case of the db not initialized, this returns None. Otherwise, this returns the db's version. Args: dbname_or_service (str): The db's name or service Returns: Optional[Int]: Either the db's version or None """ query = 'select version from dbversion order by dbversion desc limit 1' cmd = [ 'psql', '--tuples-only', '--no-psqlrc', '--quiet', '-v', 'ON_ERROR_STOP=1', "--command=%s" % query, dbname_or_service ] try: r = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, universal_newlines=True) result = int(r.stdout.strip()) except Exception: # db not initialized result = None return result def pg_restore(dbname, dumpfile, dumptype='pg_dump'): """ Args: dbname: name of the DB to restore into dumpfile: path of the dump file dumptype: one of 'pg_dump' (for binary dumps), 'psql' (for SQL dumps) """ assert dumptype in ['pg_dump', 'psql'] if dumptype == 'pg_dump': subprocess.check_call(['pg_restore', '--no-owner', '--no-privileges', '--dbname', dbname, dumpfile]) elif dumptype == 'psql': subprocess.check_call(['psql', '--quiet', '--no-psqlrc', '-v', 'ON_ERROR_STOP=1', '-f', dumpfile, dbname]) def pg_dump(dbname, dumpfile): subprocess.check_call(['pg_dump', '--no-owner', '--no-privileges', '-Fc', '-f', dumpfile, dbname]) def pg_dropdb(dbname): subprocess.check_call(['dropdb', dbname]) def pg_createdb(dbname, check=True): """Create a db. If check is True and the db already exists, this will raise an exception (original behavior). If check is False and the db already exists, this will fail silently. If the db does not exist, the db will be created. """ subprocess.run(['createdb', dbname], check=check) def db_create(dbname, dumps=None): """create the test DB and load the test data dumps into it dumps is an iterable of couples (dump_file, dump_type). context: setUpClass """ try: pg_createdb(dbname) except subprocess.CalledProcessError: # try recovering once, in case pg_dropdb(dbname) # the db already existed pg_createdb(dbname) for dump, dtype in dumps: pg_restore(dbname, dump, dtype) return dbname def db_destroy(dbname): """destroy the test DB context: tearDownClass """ pg_dropdb(dbname) def db_connect(dbname): """connect to the test DB and open a cursor context: setUp """ conn = psycopg2.connect('dbname=' + dbname) return { 'conn': conn, 'cursor': conn.cursor() } def db_close(conn): """rollback current transaction and disconnect from the test DB context: tearDown """ if not conn.closed: conn.rollback() conn.close() class DbTestConn: def __init__(self, dbname): self.dbname = dbname def __enter__(self): self.db_setup = db_connect(self.dbname) self.conn = self.db_setup['conn'] self.cursor = self.db_setup['cursor'] return self def __exit__(self, *_): db_close(self.conn) class DbTestContext: def __init__(self, name='softwareheritage-test', dumps=None): self.dbname = name self.dumps = dumps def __enter__(self): db_create(dbname=self.dbname, dumps=self.dumps) return self def __exit__(self, *_): db_destroy(self.dbname) class DbTestFixture: """Mix this in a test subject class to get DB testing support. Use the class method add_db() to add a new database to be tested. Using this will create a DbTestConn entry in the `test_db` dictionary for all the tests, indexed by the name of the database. Example: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): cls.add_db('db_name', DUMP) super().setUpClass() def setUp(self): db = self.test_db['db_name'] print('conn: {}, cursor: {}'.format(db.conn, db.cursor)) To ensure test isolation, each test method of the test case class will execute in its own connection, cursor, and transaction. Note that if you want to define setup/teardown methods, you need to explicitly call super() to ensure that the fixture setup/teardown methods are invoked. Here is an example where all setup/teardown methods are defined in a test case: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): # your add_db() calls here super().setUpClass() # your class setup code here def setUp(self): super().setUp() # your instance setup code here def tearDown(self): # your instance teardown code here super().tearDown() @classmethod def tearDownClass(cls): # your class teardown code here super().tearDownClass() """ - _DB_DUMP_LIST = {} - _DB_LIST = {} + _DB_DUMP_LIST = {} # type: Dict[str, Iterable[Tuple[str, str]]] + _DB_LIST = {} # type: Dict[str, DbTestContext] DB_TEST_FIXTURE_IMPORTED = True @classmethod def add_db(cls, name='softwareheritage-test', dumps=None): cls._DB_DUMP_LIST[name] = dumps @classmethod def setUpClass(cls): for name, dumps in cls._DB_DUMP_LIST.items(): cls._DB_LIST[name] = DbTestContext(name, dumps) cls._DB_LIST[name].__enter__() super().setUpClass() @classmethod def tearDownClass(cls): super().tearDownClass() for name, context in cls._DB_LIST.items(): context.__exit__() def setUp(self, *args, **kwargs): self.test_db = {} for name in self._DB_LIST.keys(): self.test_db[name] = DbTestConn(name) self.test_db[name].__enter__() super().setUp(*args, **kwargs) def tearDown(self): super().tearDown() for name in self._DB_LIST.keys(): self.test_db[name].__exit__() def reset_db_tables(self, name, excluded=None): db = self.test_db[name] conn = db.conn cursor = db.cursor cursor.execute("""SELECT table_name FROM information_schema.tables WHERE table_schema = %s""", ('public',)) tables = set(table for (table,) in cursor.fetchall()) if excluded is not None: tables -= set(excluded) for table in tables: cursor.execute('truncate table %s cascade' % table) conn.commit() class SingleDbTestFixture(DbTestFixture): """Simplified fixture like DbTest but that can only handle a single DB. Gives access to shortcuts like self.cursor and self.conn. DO NOT use this with other fixtures that need to access databases, like StorageTestFixture. The class can override the following class attributes: TEST_DB_NAME: name of the DB used for testing TEST_DB_DUMP: DB dump to be restored before running test methods; can be set to None if no restore from dump is required. If the dump file name endswith" - '.sql' it will be loaded via psql, - '.dump' it will be loaded via pg_restore. Other file extensions will be ignored. Can be a string or a list of strings; each path will be expanded using glob pattern matching. The test case class will then have the following attributes, accessible via self: dbname: name of the test database conn: psycopg2 connection object cursor: open psycopg2 cursor to the DB """ TEST_DB_NAME = 'softwareheritage-test' - TEST_DB_DUMP = None + TEST_DB_DUMP = None # type: Optional[Union[str, Iterable[str]]] @classmethod def setUpClass(cls): cls.dbname = cls.TEST_DB_NAME # XXX to kill? dump_files = cls.TEST_DB_DUMP - if isinstance(dump_files, str): + if dump_files is None: + dump_files = [] + elif isinstance(dump_files, str): dump_files = [dump_files] all_dump_files = [] for files in dump_files: all_dump_files.extend( sorted(glob.glob(files), key=sortkey)) all_dump_files = [(x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) for x in all_dump_files] cls.add_db(name=cls.TEST_DB_NAME, dumps=all_dump_files) super().setUpClass() def setUp(self, *args, **kwargs): super().setUp(*args, **kwargs) db = self.test_db[self.TEST_DB_NAME] self.conn = db.conn self.cursor = db.cursor diff --git a/swh/core/statsd.py b/swh/core/statsd.py index 30dcff4..30be881 100644 --- a/swh/core/statsd.py +++ b/swh/core/statsd.py @@ -1,425 +1,430 @@ # Copyright (C) 2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # Initially imported from https://github.com/DataDog/datadogpy/ # at revision 62b3a3e89988dc18d78c282fe3ff5d1813917436 # # Copyright (c) 2015, Datadog # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of Datadog nor the names of its contributors may be # used to endorse or promote products derived from this software without # specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # # # Vastly adapted for integration in swh.core: # # - Removed python < 3.5 compat code # - trimmed the imports down to be a single module # - adjust some options: # - drop unix socket connection option # - add environment variable support for setting the statsd host and # port (pulled the idea from the main python statsd module) # - only send timer metrics in milliseconds (that's what # prometheus-statsd-exporter expects) # - drop DataDog-specific metric types (that are unsupported in # prometheus-statsd-exporter) # - made the tags a dict instead of a list (prometheus-statsd-exporter only # supports tags with a value, mirroring prometheus) # - switch from time.time to time.monotonic # - improve unit test coverage # - documentation cleanup from asyncio import iscoroutinefunction from functools import wraps from random import random from time import monotonic import itertools import logging import os import socket +import threading import warnings log = logging.getLogger('swh.core.statsd') class TimedContextManagerDecorator(object): """ A context manager and a decorator which will report the elapsed time in the context OR in a function call. Attributes: elapsed (float): the elapsed time at the point of completion """ def __init__(self, statsd, metric=None, error_metric=None, tags=None, sample_rate=1): self.statsd = statsd self.metric = metric self.error_metric = error_metric self.tags = tags self.sample_rate = sample_rate self.elapsed = None # this is for testing purpose def __call__(self, func): """ Decorator which returns the elapsed time of the function call. Default to the function name if metric was not provided. """ if not self.metric: self.metric = '%s.%s' % (func.__module__, func.__name__) # Coroutines if iscoroutinefunction(func): @wraps(func) async def wrapped_co(*args, **kwargs): start = monotonic() try: result = await func(*args, **kwargs) except: # noqa self._send_error() raise self._send(start) return result return wrapped_co # Others @wraps(func) def wrapped(*args, **kwargs): start = monotonic() try: result = func(*args, **kwargs) except: # noqa self._send_error() raise self._send(start) return result return wrapped def __enter__(self): if not self.metric: raise TypeError("Cannot used timed without a metric!") self._start = monotonic() return self def __exit__(self, type, value, traceback): # Report the elapsed time of the context manager if no error. if type is None: self._send(self._start) else: self._send_error() def _send(self, start): elapsed = (monotonic() - start) * 1000 self.statsd.timing(self.metric, elapsed, tags=self.tags, sample_rate=self.sample_rate) self.elapsed = elapsed def _send_error(self): if self.error_metric is None: self.error_metric = self.metric + '_error_count' self.statsd.increment(self.error_metric, tags=self.tags) def start(self): """Start the timer""" self.__enter__() def stop(self): """Stop the timer, send the metric value""" self.__exit__(None, None, None) class Statsd(object): """Initialize a client to send metrics to a StatsD server. Arguments: host (str): the host of the StatsD server. Defaults to localhost. port (int): the port of the StatsD server. Defaults to 8125. max_buffer_size (int): Maximum number of metrics to buffer before sending to the server if sending metrics in batch namespace (str): Namespace to prefix all metric names constant_tags (Dict[str, str]): Tags to attach to all metrics Note: This class also supports the following environment variables: STATSD_HOST Override the default host of the statsd server STATSD_PORT Override the default port of the statsd server STATSD_TAGS Tags to attach to every metric reported. Example value: "label:value,other_label:other_value" """ def __init__(self, host=None, port=None, max_buffer_size=50, namespace=None, constant_tags=None): # Connection if host is None: host = os.environ.get('STATSD_HOST') or 'localhost' self.host = host if port is None: port = os.environ.get('STATSD_PORT') or 8125 self.port = int(port) # Socket - self.socket = None + self._socket = None + self.lock = threading.Lock() self.max_buffer_size = max_buffer_size self._send = self._send_to_server self.encoding = 'utf-8' # Tags self.constant_tags = {} tags_envvar = os.environ.get('STATSD_TAGS', '') for tag in tags_envvar.split(','): if not tag: continue if ':' not in tag: warnings.warn( 'STATSD_TAGS needs to be in key:value format, ' '%s invalid' % tag, UserWarning, ) continue k, v = tag.split(':', 1) self.constant_tags[k] = v if constant_tags: self.constant_tags.update({ str(k): str(v) for k, v in constant_tags.items() }) # Namespace if namespace is not None: namespace = str(namespace) self.namespace = namespace def __enter__(self): self.open_buffer(self.max_buffer_size) return self def __exit__(self, type, value, traceback): self.close_buffer() def gauge(self, metric, value, tags=None, sample_rate=1): """ Record the value of a gauge, optionally setting a list of tags and a sample rate. >>> statsd.gauge('users.online', 123) >>> statsd.gauge('active.connections', 1001, tags={"protocol": "http"}) """ return self._report(metric, 'g', value, tags, sample_rate) def increment(self, metric, value=1, tags=None, sample_rate=1): """ Increment a counter, optionally setting a value, tags and a sample rate. >>> statsd.increment('page.views') >>> statsd.increment('files.transferred', 124) """ self._report(metric, 'c', value, tags, sample_rate) def decrement(self, metric, value=1, tags=None, sample_rate=1): """ Decrement a counter, optionally setting a value, tags and a sample rate. >>> statsd.decrement('files.remaining') >>> statsd.decrement('active.connections', 2) """ metric_value = -value if value else value self._report(metric, 'c', metric_value, tags, sample_rate) def histogram(self, metric, value, tags=None, sample_rate=1): """ Sample a histogram value, optionally setting tags and a sample rate. >>> statsd.histogram('uploaded.file.size', 1445) >>> statsd.histogram('file.count', 26, tags={"filetype": "python"}) """ self._report(metric, 'h', value, tags, sample_rate) def timing(self, metric, value, tags=None, sample_rate=1): """ Record a timing, optionally setting tags and a sample rate. >>> statsd.timing("query.response.time", 1234) """ self._report(metric, 'ms', value, tags, sample_rate) def timed(self, metric=None, error_metric=None, tags=None, sample_rate=1): """ A decorator or context manager that will measure the distribution of a function's/context's run time. Optionally specify a list of tags or a sample rate. If the metric is not defined as a decorator, the module name and function name will be used. The metric is required as a context manager. :: @statsd.timed('user.query.time', sample_rate=0.5) def get_user(user_id): # Do what you need to ... pass # Is equivalent to ... with statsd.timed('user.query.time', sample_rate=0.5): # Do what you need to ... pass # Is equivalent to ... start = time.monotonic() try: get_user(user_id) finally: statsd.timing('user.query.time', time.monotonic() - start) """ return TimedContextManagerDecorator( statsd=self, metric=metric, error_metric=error_metric, tags=tags, sample_rate=sample_rate) def set(self, metric, value, tags=None, sample_rate=1): """ Sample a set value. >>> statsd.set('visitors.uniques', 999) """ self._report(metric, 's', value, tags, sample_rate) - def get_socket(self): + @property + def socket(self): """ Return a connected socket. Note: connect the socket before assigning it to the class instance to avoid bad thread race conditions. """ - if not self.socket: - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.connect((self.host, self.port)) - self.socket = sock + with self.lock: + if not self._socket: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.connect((self.host, self.port)) + self._socket = sock - return self.socket + return self._socket def open_buffer(self, max_buffer_size=50): """ Open a buffer to send a batch of metrics in one packet. You can also use this as a context manager. >>> with Statsd() as batch: ... batch.gauge('users.online', 123) ... batch.gauge('active.connections', 1001) """ self.max_buffer_size = max_buffer_size self.buffer = [] self._send = self._send_to_buffer def close_buffer(self): """ Flush the buffer and switch back to single metric packets. """ self._send = self._send_to_server if self.buffer: # Only send packets if there are packets to send self._flush_buffer() def close_socket(self): """ Closes connected socket if connected. """ - if self.socket: - self.socket.close() - self.socket = None + with self.lock: + if self._socket: + self._socket.close() + self._socket = None def _report(self, metric, metric_type, value, tags, sample_rate): """ Create a metric packet and send it. """ if value is None: return if sample_rate != 1 and random() > sample_rate: return # Resolve the full tag list tags = self._add_constant_tags(tags) # Create/format the metric packet payload = "%s%s:%s|%s%s%s" % ( (self.namespace + ".") if self.namespace else "", metric, value, metric_type, ("|@" + str(sample_rate)) if sample_rate != 1 else "", ("|#" + ",".join( "%s:%s" % (k, v) for (k, v) in sorted(tags.items()) )) if tags else "", ) # Send it self._send(payload) def _send_to_server(self, packet): try: # If set, use socket directly - (self.socket or self.get_socket()).send(packet.encode('utf-8')) + self.socket.send(packet.encode('utf-8')) except socket.timeout: return except socket.error: log.debug( "Error submitting statsd packet." " Dropping the packet and closing the socket." ) self.close_socket() def _send_to_buffer(self, packet): self.buffer.append(packet) if len(self.buffer) >= self.max_buffer_size: self._flush_buffer() def _flush_buffer(self): self._send_to_server("\n".join(self.buffer)) self.buffer = [] def _add_constant_tags(self, tags): return { str(k): str(v) for k, v in itertools.chain( self.constant_tags.items(), (tags if tags else {}).items(), ) } statsd = Statsd() diff --git a/swh/core/tests/test_logger.py b/swh/core/tests/test_logger.py new file mode 100644 index 0000000..85727bf --- /dev/null +++ b/swh/core/tests/test_logger.py @@ -0,0 +1,120 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from datetime import datetime +import logging +import pytz +import inspect + +from unittest.mock import patch + +from swh.core import logger + + +def lineno(): + """Returns the current line number in our program.""" + return inspect.currentframe().f_back.f_lineno + + +def test_db_level(): + assert logger.db_level_of_py_level(10) == 'debug' + assert logger.db_level_of_py_level(20) == 'info' + assert logger.db_level_of_py_level(30) == 'warning' + assert logger.db_level_of_py_level(40) == 'error' + assert logger.db_level_of_py_level(50) == 'critical' + + +def test_flatten_scalar(): + assert list(logger.flatten('')) == [('', '')] + assert list(logger.flatten('toto')) == [('', 'toto')] + + assert list(logger.flatten(10)) == [('', 10)] + assert list(logger.flatten(10.5)) == [('', 10.5)] + + +def test_flatten_list(): + assert list(logger.flatten([])) == [] + assert list(logger.flatten([1])) == [('0', 1)] + + assert list(logger.flatten([1, 2, ['a', 'b']])) == [ + ('0', 1), + ('1', 2), + ('2_0', 'a'), + ('2_1', 'b'), + ] + + assert list(logger.flatten([1, 2, ['a', ('x', 1)]])) == [ + ('0', 1), + ('1', 2), + ('2_0', 'a'), + ('2_1_0', 'x'), + ('2_1_1', 1), + ] + + +def test_flatten_dict(): + assert list(logger.flatten({})) == [] + assert list(logger.flatten({'a': 1})) == [('a', 1)] + + assert sorted(logger.flatten({'a': 1, + 'b': (2, 3,), + 'c': {'d': 4, 'e': 'f'}})) == [ + ('a', 1), + ('b_0', 2), + ('b_1', 3), + ('c_d', 4), + ('c_e', 'f'), + ] + + +def test_stringify(): + assert logger.stringify(None) == 'None' + assert logger.stringify(123) == '123' + assert logger.stringify('abc') == 'abc' + + date = datetime(2019, 9, 1, 16, 32) + assert logger.stringify(date) == '2019-09-01T16:32:00' + + tzdate = datetime(2019, 9, 1, 16, 32, tzinfo=pytz.utc) + assert logger.stringify(tzdate) == '2019-09-01T16:32:00+00:00' + + +@patch('swh.core.logger.send') +def test_journal_handler(send): + log = logging.getLogger('test_logger') + log.addHandler(logger.JournalHandler()) + log.setLevel(logging.DEBUG) + + _, ln = log.info('hello world'), lineno() + + send.assert_called_with( + 'hello world', + CODE_FILE=__file__, + CODE_FUNC='test_journal_handler', + CODE_LINE=ln, + LOGGER='test_logger', + PRIORITY='6', + THREAD_NAME='MainThread') + + +@patch('swh.core.logger.send') +def test_journal_handler_w_data(send): + log = logging.getLogger('test_logger') + log.addHandler(logger.JournalHandler()) + log.setLevel(logging.DEBUG) + + _, ln = log.debug('something cool %s', ['with', {'extra': 'data'}]), lineno() # noqa + + send.assert_called_with( + "something cool ['with', {'extra': 'data'}]", + CODE_FILE=__file__, + CODE_FUNC='test_journal_handler_w_data', + CODE_LINE=ln, + LOGGER='test_logger', + PRIORITY='7', + THREAD_NAME='MainThread', + SWH_LOGGING_ARGS_0_0='with', + SWH_LOGGING_ARGS_0_1_EXTRA='data' + ) diff --git a/swh/core/tests/test_statsd.py b/swh/core/tests/test_statsd.py index 7b5dd62..56d1aa9 100644 --- a/swh/core/tests/test_statsd.py +++ b/swh/core/tests/test_statsd.py @@ -1,563 +1,563 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # Initially imported from https://github.com/DataDog/datadogpy/ # at revision 62b3a3e89988dc18d78c282fe3ff5d1813917436 # # Copyright (c) 2015, Datadog # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of Datadog nor the names of its contributors may be # used to endorse or promote products derived from this software without # specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # from collections import deque from contextlib import contextmanager import os import socket import time import unittest import pytest from swh.core.statsd import Statsd, TimedContextManagerDecorator @contextmanager def preserve_envvars(*envvars): """Context manager preserving the value of environment variables""" preserved = {} to_delete = object() for var in envvars: preserved[var] = os.environ.get(var, to_delete) yield for var in envvars: old = preserved[var] if old is not to_delete: os.environ[var] = old else: del os.environ[var] class FakeSocket(object): """ A fake socket for testing. """ def __init__(self): self.payloads = deque() def send(self, payload): assert type(payload) == bytes self.payloads.append(payload) def recv(self): try: return self.payloads.popleft().decode('utf-8') except IndexError: return None def close(self): pass def __repr__(self): return str(self.payloads) class BrokenSocket(FakeSocket): def send(self, payload): raise socket.error("Socket error") class SlowSocket(FakeSocket): def send(self, payload): raise socket.timeout("Socket timeout") class TestStatsd(unittest.TestCase): def setUp(self): """ Set up a default Statsd instance and mock the socket. """ # self.statsd = Statsd() - self.statsd.socket = FakeSocket() + self.statsd._socket = FakeSocket() def recv(self): return self.statsd.socket.recv() def test_set(self): self.statsd.set('set', 123) assert self.recv() == 'set:123|s' def test_gauge(self): self.statsd.gauge('gauge', 123.4) assert self.recv() == 'gauge:123.4|g' def test_counter(self): self.statsd.increment('page.views') self.assertEqual('page.views:1|c', self.recv()) self.statsd.increment('page.views', 11) self.assertEqual('page.views:11|c', self.recv()) self.statsd.decrement('page.views') self.assertEqual('page.views:-1|c', self.recv()) self.statsd.decrement('page.views', 12) self.assertEqual('page.views:-12|c', self.recv()) def test_histogram(self): self.statsd.histogram('histo', 123.4) self.assertEqual('histo:123.4|h', self.recv()) def test_tagged_gauge(self): self.statsd.gauge('gt', 123.4, tags={'country': 'china', 'age': 45}) self.assertEqual('gt:123.4|g|#age:45,country:china', self.recv()) def test_tagged_counter(self): self.statsd.increment('ct', tags={'country': 'españa'}) self.assertEqual('ct:1|c|#country:españa', self.recv()) def test_tagged_histogram(self): self.statsd.histogram('h', 1, tags={'test_tag': 'tag_value'}) self.assertEqual('h:1|h|#test_tag:tag_value', self.recv()) def test_sample_rate(self): self.statsd.increment('c', sample_rate=0) assert not self.recv() for i in range(10000): self.statsd.increment('sampled_counter', sample_rate=0.3) self.assert_almost_equal(3000, len(self.statsd.socket.payloads), 150) self.assertEqual('sampled_counter:1|c|@0.3', self.recv()) def test_tags_and_samples(self): for i in range(100): self.statsd.gauge('gst', 23, tags={"sampled": True}, sample_rate=0.9) self.assert_almost_equal(90, len(self.statsd.socket.payloads), 10) self.assertEqual('gst:23|g|@0.9|#sampled:True', self.recv()) def test_timing(self): self.statsd.timing('t', 123) self.assertEqual('t:123|ms', self.recv()) def test_metric_namespace(self): """ Namespace prefixes all metric names. """ self.statsd.namespace = "foo" self.statsd.gauge('gauge', 123.4) self.assertEqual('foo.gauge:123.4|g', self.recv()) # Test Client level constant tags def test_gauge_constant_tags(self): self.statsd.constant_tags = { 'bar': 'baz', } self.statsd.gauge('gauge', 123.4) assert self.recv() == 'gauge:123.4|g|#bar:baz' def test_counter_constant_tag_with_metric_level_tags(self): self.statsd.constant_tags = { 'bar': 'baz', 'foo': True, } self.statsd.increment('page.views', tags={'extra': 'extra'}) self.assertEqual( 'page.views:1|c|#bar:baz,extra:extra,foo:True', self.recv(), ) def test_gauge_constant_tags_with_metric_level_tags_twice(self): metric_level_tag = {'foo': 'bar'} self.statsd.constant_tags = {'bar': 'baz'} self.statsd.gauge('gauge', 123.4, tags=metric_level_tag) assert self.recv() == 'gauge:123.4|g|#bar:baz,foo:bar' # sending metrics multiple times with same metric-level tags # should not duplicate the tags being sent self.statsd.gauge('gauge', 123.4, tags=metric_level_tag) assert self.recv() == 'gauge:123.4|g|#bar:baz,foo:bar' def assert_almost_equal(self, a, b, delta): self.assertTrue( 0 <= abs(a - b) <= delta, "%s - %s not within %s" % (a, b, delta) ) def test_socket_error(self): - self.statsd.socket = BrokenSocket() + self.statsd._socket = BrokenSocket() self.statsd.gauge('no error', 1) assert True, 'success' def test_socket_timeout(self): - self.statsd.socket = SlowSocket() + self.statsd._socket = SlowSocket() self.statsd.gauge('no error', 1) assert True, 'success' def test_timed(self): """ Measure the distribution of a function's run time. """ @self.statsd.timed('timed.test') def func(a, b, c=1, d=1): """docstring""" time.sleep(0.5) return (a, b, c, d) self.assertEqual('func', func.__name__) self.assertEqual('docstring', func.__doc__) result = func(1, 2, d=3) # Assert it handles args and kwargs correctly. self.assertEqual(result, (1, 2, 1, 3)) packet = self.recv() name_value, type_ = packet.split('|') name, value = name_value.split(':') self.assertEqual('ms', type_) self.assertEqual('timed.test', name) self.assert_almost_equal(500, float(value), 100) def test_timed_exception(self): """ Exception bubble out of the decorator and is reported to statsd as a dedicated counter. """ @self.statsd.timed('timed.test') def func(a, b, c=1, d=1): """docstring""" time.sleep(0.5) return (a / b, c, d) self.assertEqual('func', func.__name__) self.assertEqual('docstring', func.__doc__) with self.assertRaises(ZeroDivisionError): func(1, 0) packet = self.recv() name_value, type_ = packet.split('|') name, value = name_value.split(':') self.assertEqual('c', type_) self.assertEqual('timed.test_error_count', name) self.assertEqual(int(value), 1) def test_timed_no_metric(self, ): """ Test using a decorator without providing a metric. """ @self.statsd.timed() def func(a, b, c=1, d=1): """docstring""" time.sleep(0.5) return (a, b, c, d) self.assertEqual('func', func.__name__) self.assertEqual('docstring', func.__doc__) result = func(1, 2, d=3) # Assert it handles args and kwargs correctly. self.assertEqual(result, (1, 2, 1, 3)) packet = self.recv() name_value, type_ = packet.split('|') name, value = name_value.split(':') self.assertEqual('ms', type_) self.assertEqual('swh.core.tests.test_statsd.func', name) self.assert_almost_equal(500, float(value), 100) def test_timed_coroutine(self): """ Measure the distribution of a coroutine function's run time. Warning: Python >= 3.5 only. """ import asyncio @self.statsd.timed('timed.test') @asyncio.coroutine def print_foo(): """docstring""" time.sleep(0.5) print("foo") loop = asyncio.new_event_loop() loop.run_until_complete(print_foo()) loop.close() # Assert packet = self.recv() name_value, type_ = packet.split('|') name, value = name_value.split(':') self.assertEqual('ms', type_) self.assertEqual('timed.test', name) self.assert_almost_equal(500, float(value), 100) def test_timed_context(self): """ Measure the distribution of a context's run time. """ # In milliseconds with self.statsd.timed('timed_context.test') as timer: self.assertIsInstance(timer, TimedContextManagerDecorator) time.sleep(0.5) packet = self.recv() name_value, type_ = packet.split('|') name, value = name_value.split(':') self.assertEqual('ms', type_) self.assertEqual('timed_context.test', name) self.assert_almost_equal(500, float(value), 100) self.assert_almost_equal(500, timer.elapsed, 100) def test_timed_context_exception(self): """ Exception bubbles out of the `timed` context manager and is reported to statsd as a dedicated counter. """ class ContextException(Exception): pass def func(self): with self.statsd.timed('timed_context.test'): time.sleep(0.5) raise ContextException() # Ensure the exception was raised. self.assertRaises(ContextException, func, self) # Ensure the timing was recorded. packet = self.recv() name_value, type_ = packet.split('|') name, value = name_value.split(':') self.assertEqual('c', type_) self.assertEqual('timed_context.test_error_count', name) self.assertEqual(int(value), 1) def test_timed_context_no_metric_name_exception(self): """Test that an exception occurs if using a context manager without a metric name. """ def func(self): with self.statsd.timed(): time.sleep(0.5) # Ensure the exception was raised. self.assertRaises(TypeError, func, self) # Ensure the timing was recorded. packet = self.recv() self.assertEqual(packet, None) def test_timed_start_stop_calls(self): timer = self.statsd.timed('timed_context.test') timer.start() time.sleep(0.5) timer.stop() packet = self.recv() name_value, type_ = packet.split('|') name, value = name_value.split(':') self.assertEqual('ms', type_) self.assertEqual('timed_context.test', name) self.assert_almost_equal(500, float(value), 100) def test_batched(self): self.statsd.open_buffer() self.statsd.gauge('page.views', 123) self.statsd.timing('timer', 123) self.statsd.close_buffer() self.assertEqual('page.views:123|g\ntimer:123|ms', self.recv()) def test_context_manager(self): fake_socket = FakeSocket() with Statsd() as statsd: - statsd.socket = fake_socket + statsd._socket = fake_socket statsd.gauge('page.views', 123) statsd.timing('timer', 123) self.assertEqual('page.views:123|g\ntimer:123|ms', fake_socket.recv()) def test_batched_buffer_autoflush(self): fake_socket = FakeSocket() with Statsd() as statsd: - statsd.socket = fake_socket + statsd._socket = fake_socket for i in range(51): statsd.increment('mycounter') self.assertEqual( '\n'.join(['mycounter:1|c' for i in range(50)]), fake_socket.recv(), ) self.assertEqual('mycounter:1|c', fake_socket.recv()) def test_module_level_instance(self): from swh.core.statsd import statsd self.assertTrue(isinstance(statsd, Statsd)) def test_instantiating_does_not_connect(self): local_statsd = Statsd() - self.assertEqual(None, local_statsd.socket) + self.assertEqual(None, local_statsd._socket) def test_accessing_socket_opens_socket(self): local_statsd = Statsd() try: - self.assertIsNotNone(local_statsd.get_socket()) + self.assertIsNotNone(local_statsd.socket) finally: - local_statsd.socket.close() + local_statsd.close_socket() def test_accessing_socket_multiple_times_returns_same_socket(self): local_statsd = Statsd() fresh_socket = FakeSocket() - local_statsd.socket = fresh_socket - self.assertEqual(fresh_socket, local_statsd.get_socket()) - self.assertNotEqual(FakeSocket(), local_statsd.get_socket()) + local_statsd._socket = fresh_socket + self.assertEqual(fresh_socket, local_statsd.socket) + self.assertNotEqual(FakeSocket(), local_statsd.socket) def test_tags_from_environment(self): with preserve_envvars('STATSD_TAGS'): os.environ['STATSD_TAGS'] = 'country:china,age:45' statsd = Statsd() - statsd.socket = FakeSocket() + statsd._socket = FakeSocket() statsd.gauge('gt', 123.4) self.assertEqual('gt:123.4|g|#age:45,country:china', statsd.socket.recv()) def test_tags_from_environment_and_constant(self): with preserve_envvars('STATSD_TAGS'): os.environ['STATSD_TAGS'] = 'country:china,age:45' statsd = Statsd(constant_tags={'country': 'canada'}) - statsd.socket = FakeSocket() + statsd._socket = FakeSocket() statsd.gauge('gt', 123.4) self.assertEqual('gt:123.4|g|#age:45,country:canada', statsd.socket.recv()) def test_tags_from_environment_warning(self): with preserve_envvars('STATSD_TAGS'): os.environ['STATSD_TAGS'] = 'valid:tag,invalid_tag' with pytest.warns(UserWarning) as record: statsd = Statsd() assert len(record) == 1 assert 'invalid_tag' in record[0].message.args[0] assert 'valid:tag' not in record[0].message.args[0] assert statsd.constant_tags == {'valid': 'tag'} def test_gauge_doesnt_send_none(self): self.statsd.gauge('metric', None) assert self.recv() is None def test_increment_doesnt_send_none(self): self.statsd.increment('metric', None) assert self.recv() is None def test_decrement_doesnt_send_none(self): self.statsd.decrement('metric', None) assert self.recv() is None def test_timing_doesnt_send_none(self): self.statsd.timing('metric', None) assert self.recv() is None def test_histogram_doesnt_send_none(self): self.statsd.histogram('metric', None) assert self.recv() is None def test_param_host(self): with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): os.environ['STATSD_HOST'] = 'test-value' os.environ['STATSD_PORT'] = '' local_statsd = Statsd(host='actual-test-value') self.assertEqual(local_statsd.host, 'actual-test-value') self.assertEqual(local_statsd.port, 8125) def test_param_port(self): with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): os.environ['STATSD_HOST'] = '' os.environ['STATSD_PORT'] = '12345' local_statsd = Statsd(port=4321) self.assertEqual(local_statsd.host, 'localhost') self.assertEqual(local_statsd.port, 4321) def test_envvar_host(self): with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): os.environ['STATSD_HOST'] = 'test-value' os.environ['STATSD_PORT'] = '' local_statsd = Statsd() self.assertEqual(local_statsd.host, 'test-value') self.assertEqual(local_statsd.port, 8125) def test_envvar_port(self): with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): os.environ['STATSD_HOST'] = '' os.environ['STATSD_PORT'] = '12345' local_statsd = Statsd() self.assertEqual(local_statsd.host, 'localhost') self.assertEqual(local_statsd.port, 12345) def test_namespace_added(self): local_statsd = Statsd(namespace='test-namespace') - local_statsd.socket = FakeSocket() + local_statsd._socket = FakeSocket() local_statsd.gauge('gauge', 123.4) assert local_statsd.socket.recv() == 'test-namespace.gauge:123.4|g' def test_contextmanager_empty(self): with self.statsd: assert True, 'success' def test_contextmanager_buffering(self): with self.statsd as s: s.gauge('gauge', 123.4) s.gauge('gauge_other', 456.78) self.assertIsNone(s.socket.recv()) self.assertEqual(self.recv(), 'gauge:123.4|g\ngauge_other:456.78|g') def test_timed_elapsed(self): with self.statsd.timed('test_timer') as t: pass self.assertGreaterEqual(t.elapsed, 0) self.assertEqual(self.recv(), 'test_timer:%s|ms' % t.elapsed) diff --git a/swh/core/tests/test_tarball.py b/swh/core/tests/test_tarball.py new file mode 100644 index 0000000..92e4f5e --- /dev/null +++ b/swh/core/tests/test_tarball.py @@ -0,0 +1,67 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from zipfile import ZipFile + +from swh.core import tarball + + +def test_is_tarball(tmp_path): + + nozip = tmp_path / 'nozip.zip' + nozip.write_text('Im no zip') + + assert tarball.is_tarball(str(nozip)) is False + + notar = tmp_path / 'notar.tar' + notar.write_text('Im no tar') + + assert tarball.is_tarball(str(notar)) is False + + zipfile = tmp_path / 'truezip.zip' + with ZipFile(str(zipfile), 'w') as myzip: + myzip.writestr('file1.txt', 'some content') + + assert tarball.is_tarball(str(zipfile)) is True + + +def test_compress_uncompress_zip(tmp_path): + tocompress = tmp_path / 'compressme' + tocompress.mkdir() + + for i in range(10): + fpath = tocompress / ('file%s.txt' % i) + fpath.write_text('content of file %s' % i) + + zipfile = tmp_path / 'archive.zip' + tarball.compress(str(zipfile), 'zip', str(tocompress)) + + assert tarball.is_tarball(str(zipfile)) + + destdir = tmp_path / 'destdir' + tarball.uncompress(str(zipfile), str(destdir)) + + lsdir = sorted(x.name for x in destdir.iterdir()) + assert ['file%s.txt' % i for i in range(10)] == lsdir + + +def test_compress_uncompress_tar(tmp_path): + tocompress = tmp_path / 'compressme' + tocompress.mkdir() + + for i in range(10): + fpath = tocompress / ('file%s.txt' % i) + fpath.write_text('content of file %s' % i) + + tarfile = tmp_path / 'archive.tar' + tarball.compress(str(tarfile), 'tar', str(tocompress)) + + assert tarball.is_tarball(str(tarfile)) + + destdir = tmp_path / 'destdir' + tarball.uncompress(str(tarfile), str(destdir)) + + lsdir = sorted(x.name for x in destdir.iterdir()) + assert ['file%s.txt' % i for i in range(10)] == lsdir diff --git a/version.txt b/version.txt index 16db69b..d00a869 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -v0.0.67-0-g9c720a1 \ No newline at end of file +v0.0.68-0-g2980105 \ No newline at end of file