diff --git a/conftest.py b/conftest.py index 888e3c3..8b6908f 100644 --- a/conftest.py +++ b/conftest.py @@ -1,20 +1,20 @@ -import pytest from hypothesis import settings +import pytest from swh.core.cli import swh as _swhmain # define tests profile. Full documentation is at: # https://hypothesis.readthedocs.io/en/latest/settings.html#settings-profiles settings.register_profile("fast", max_examples=5, deadline=5000) settings.register_profile("slow", max_examples=20, deadline=5000) @pytest.fixture def swhmain(): """Yield an instance of the main `swh` click command that cleans the added subcommands up on teardown.""" commands = _swhmain.commands.copy() aliases = _swhmain.aliases.copy() yield _swhmain _swhmain.commands = commands _swhmain.aliases = aliases diff --git a/setup.py b/setup.py index f7426e5..0c550c9 100755 --- a/setup.py +++ b/setup.py @@ -1,88 +1,88 @@ #!/usr/bin/env python3 # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from io import open import os -from setuptools import setup, find_packages - from os import path -from io import open + +from setuptools import find_packages, setup here = path.abspath(path.dirname(__file__)) # Get the long description from the README file with open(path.join(here, "README.md"), encoding="utf-8") as f: long_description = f.read() def parse_requirements(*names): requirements = [] for name in names: if name: reqf = "requirements-%s.txt" % name else: reqf = "requirements.txt" if not os.path.exists(reqf): return requirements with open(reqf) as f: for line in f.readlines(): line = line.strip() if not line or line.startswith("#"): continue requirements.append(line) return requirements setup( name="swh.core", description="Software Heritage core utilities", long_description=long_description, long_description_content_type="text/markdown", python_requires=">=3.7", author="Software Heritage developers", author_email="swh-devel@inria.fr", url="https://forge.softwareheritage.org/diffusion/DCORE/", packages=find_packages(), py_modules=["pytest_swh_core"], scripts=[], install_requires=parse_requirements(None, "swh"), setup_requires=["setuptools-scm"], use_scm_version=True, extras_require={ "testing-core": parse_requirements("test"), "logging": parse_requirements("logging"), "db": parse_requirements("db"), "testing-db": parse_requirements("test-db"), "http": parse_requirements("http"), # kitchen sink, please do not use "testing": parse_requirements("test", "test-db", "db", "http", "logging"), }, include_package_data=True, entry_points=""" [console_scripts] swh=swh.core.cli:main swh-db-init=swh.core.cli.db:db_init [swh.cli.subcommands] db=swh.core.cli.db:db db-init=swh.core.cli.db:db_init [pytest11] pytest_swh_core = swh.core.pytest_plugin """, classifiers=[ "Programming Language :: Python :: 3", "Intended Audience :: Developers", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", "Development Status :: 5 - Production/Stable", ], project_urls={ "Bug Reports": "https://forge.softwareheritage.org/maniphest", "Funding": "https://www.softwareheritage.org/donate", "Source": "https://forge.softwareheritage.org/source/swh-core", "Documentation": "https://docs.softwareheritage.org/devel/swh-core/", }, ) diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index 376397d..2094698 100644 --- 
a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,457 +1,452 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc import functools import inspect import logging import pickle -import requests - from typing import ( Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union, ) -from flask import Flask, Request, Response, request, abort +from flask import Flask, Request, Response, abort, request +import requests from werkzeug.exceptions import HTTPException +from .negotiation import Formatter as FormatterBase +from .negotiation import Negotiator as NegotiatorBase +from .negotiation import negotiate as _negotiate from .serializers import ( - decode_response, - encode_data_client as encode_data, - msgpack_dumps, - msgpack_loads, + exception_to_dict, json_dumps, json_loads, - exception_to_dict, -) - -from .negotiation import ( - Formatter as FormatterBase, - Negotiator as NegotiatorBase, - negotiate as _negotiate, + msgpack_dumps, + msgpack_loads, ) - +from .serializers import decode_response +from .serializers import encode_data_client as encode_data logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, "application/json" ) def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) def configure(self, extra_encoders=None): self.extra_encoders = extra_encoders class JSONFormatter(Formatter): format = "json" mimetypes = ["application/json"] def render(self, obj): return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): format = "msgpack" mimetypes = ["application/x-msgpack"] def render(self, obj): return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes class RemoteException(Exception): """raised when the remote returned an out-of-band failure notification, e.g., as an HTTP status code or serialized exception Attributes: response: HTTP response corresponding to the failure """ def __init__( self, payload: Optional[Any] = None, response: Optional[requests.Response] = None, ): if payload is not None: super().__init__(payload) else: super().__init__() self.response = response def __str__(self): if ( self.args and isinstance(self.args[0], dict) and "type" in self.args[0] and "args" in self.args[0] ): return ( f"<RemoteException {self.response.status_code} " f'{self.args[0]["type"]}: {self.args[0]["args"]}>' ) else: return super().__str__() F = TypeVar("F", bound=Callable) def remote_api_endpoint(path) -> Callable[[F], F]: def dec(f: F) -> F: f._endpoint_path = path # type: ignore return f return dec class APIError(Exception): """API Error""" def __str__(self): return "An unexpected error occurred in the backend: {}".format(self.args) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg.
:class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wraps), # this never actually calls an IndexerStorage method. backend_class = attributes.get("backend_class", None) for base in bases: if backend_class is not None: break backend_class = getattr(base, "backend_class", None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop("self") post_data.pop("cur", None) post_data.pop("db", None) # Send the request. return self.post(meth._endpoint_path, post_data) if meth_name not in attributes: attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" reraise_exceptions: ClassVar[List[Type[Exception]]] = [] """On server errors, if any of the exception classes in this list has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__( self, url, api_exception=None, timeout=None, chunk_size=4096, reraise_exceptions=None, **kwargs, ): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions base_url = url if url.endswith("/") else url + "/" self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get("max_retries", 3), pool_connections=kwargs.get("pool_connections", 20), pool_maxsize=kwargs.get("pool_maxsize", 100), ) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return "%s%s" % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if "chunk_size" in opts: # if the chunk_size argument has been passed, assume the user # also wants stream=True, otherwise, what's the point.
opts["stream"] = True if self.timeout and "timeout" not in opts: opts["timeout"] = self.timeout try: return getattr(self.session, verb)(self._url(endpoint), **opts) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): data = (self._encode_data(x) for x in data) else: data = self._encode_data(data) chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "post", endpoint, data=data, headers={ "content-type": "application/x-msgpack", "accept": "application/x-msgpack", }, **opts, ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def _encode_data(self, data): return encode_data(data, extra_encoders=self.extra_type_encoders) post_stream = post def get(self, endpoint, **opts): chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def raise_for_status(self, response) -> None: """check response HTTP status code and raise an exception if it denotes an error; do nothing otherwise """ status_code = response.status_code status_class = response.status_code // 100 if status_code == 404: raise RemoteException(payload="404 not found", response=response) exception = None # TODO: only old servers send pickled error; stop trying to unpickle # after they are all upgraded try: if status_class == 4: data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: if exc_type.__name__ == data["exception"]["type"]: exception = exc_type(*data["exception"]["args"]) break else: exception = RemoteException( payload=data["exception"], response=response ) else: exception = pickle.loads(data) elif status_class == 5: data = self._decode_response(response, check_status=False) if "exception_pickled" in data: exception = pickle.loads(data["exception_pickled"]) else: exception = RemoteException( payload=data["exception"], response=response ) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) if exception: raise exception from None if status_class != 2: raise RemoteException( payload=f"API HTTP error: {status_code} {response.content}", response=response, ) def _decode_response(self, response, check_status=True): if check_status: self.raise_for_status(response) return decode_response(response, extra_decoders=self.extra_type_decoders) def __repr__(self): return "<{} url={}>".format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = "utf-8" encoding_errors = "surrogateescape" ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { "application/x-msgpack": msgpack_dumps, "application/json": json_dumps, } def encode_data_server( data, content_type="application/x-msgpack", extra_type_encoders=None ): encoded_data = ENCODERS[content_type](data, extra_encoders=extra_type_encoders) return Response(encoded_data, mimetype=content_type,) def decode_request(request, 
extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == "application/x-msgpack": r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == "application/json": # XXX this .decode() is needed for py35. # Should not be needed any more with py37 r = json_loads(data.decode("utf-8"), extra_decoders=extra_decoders) else: raise ValueError("Wrong content type `%s` for API request" % content_type) return r def error_handler(exception, encoder, status_code=500): logging.exception(exception) response = encoder(exception_to_dict(exception)) if isinstance(exception, HTTPException): response.status_code = exception.code else: # TODO: differentiate between server errors and client errors response.status_code = status_code return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ request_class = BytesRequest extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) self.backend_class = backend_class if backend_class is not None: if backend_factory is None: backend_factory = backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route("/" + meth._endpoint_path, methods=["POST"]) @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) kw = decode_request(request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py index 0483043..4522834 100644 --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -1,97 +1,100 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import OrderedDict import logging from typing import Tuple, Type import aiohttp.web +from aiohttp_utils import Response, negotiation from deprecated import deprecated import multidict -from .serializers import msgpack_dumps, msgpack_loads -from .serializers import json_dumps, json_loads -from .serializers import exception_to_dict - -from aiohttp_utils import negotiation, Response +from .serializers import ( + exception_to_dict, + json_dumps, + json_loads, + msgpack_dumps, + msgpack_loads, +) def 
encode_msgpack(data, **kwargs): return aiohttp.web.Response( body=msgpack_dumps(data), headers=multidict.MultiDict({"Content-Type": "application/x-msgpack"}), **kwargs, ) encode_data_server = Response def render_msgpack(request, data): return msgpack_dumps(data) def render_json(request, data): return json_dumps(data) async def decode_request(request): content_type = request.headers.get("Content-Type").split(";")[0].strip() data = await request.read() if not data: return {} if content_type == "application/x-msgpack": r = msgpack_loads(data) elif content_type == "application/json": r = json_loads(data) else: raise ValueError("Wrong content type `%s` for API request" % content_type) return r async def error_middleware(app, handler): async def middleware_handler(request): try: return await handler(request) except Exception as e: if isinstance(e, aiohttp.web.HTTPException): raise logging.exception(e) res = exception_to_dict(e) if isinstance(e, app.client_exception_classes): status = 400 else: status = 500 return encode_data_server(res, status=status) return middleware_handler class RPCServerApp(aiohttp.web.Application): client_exception_classes: Tuple[Type[Exception], ...] = () """Exceptions that should be handled as a client error (eg. object not found, invalid argument)""" def __init__(self, *args, middlewares=(), **kwargs): middlewares = (error_middleware,) + middlewares # renderers are sorted in order of increasing desirability (!) # see mimeparse.best_match() docstring. renderers = OrderedDict( [ ("application/json", render_json), ("application/x-msgpack", render_msgpack), ] ) nego_middleware = negotiation.negotiation_middleware( renderers=renderers, force_rendering=True ) middlewares = (nego_middleware,) + middlewares super().__init__(*args, middlewares=middlewares, **kwargs) @deprecated(version="0.0.64", reason="Use the RPCServerApp instead") class SWHRemoteAPI(RPCServerApp): pass diff --git a/swh/core/api/classes.py b/swh/core/api/classes.py index 0aefcae..c84db28 100644 --- a/swh/core/api/classes.py +++ b/swh/core/api/classes.py @@ -1,44 +1,35 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass, field - -from typing import ( - Callable, - Generic, - Iterable, - List, - Optional, - TypeVar, -) - +from typing import Callable, Generic, Iterable, List, Optional, TypeVar TResult = TypeVar("TResult") TToken = TypeVar("TToken") @dataclass(eq=True) class PagedResult(Generic[TResult, TToken]): """Represents a page of results; with a token to get the next page""" results: List[TResult] = field(default_factory=list) next_page_token: Optional[TToken] = field(default=None) def stream_results( f: Callable[..., PagedResult[TResult, TToken]], *args, **kwargs ) -> Iterable[TResult]: """Consume the paginated result and stream the page results """ if "page_token" in kwargs: raise TypeError('stream_results has no argument "page_token".') page_token = None while True: page_result = f(*args, page_token=page_token, **kwargs) yield from page_result.results page_token = page_result.next_page_token if page_token is None: break diff --git a/swh/core/api/negotiation.py b/swh/core/api/negotiation.py index 4e2abab..8f6398b 100644 --- a/swh/core/api/negotiation.py +++ b/swh/core/api/negotiation.py @@ -1,157 +1,156 @@ # This code is a partial and adapted copy of # 
https://github.com/nickstenning/negotiate # # Copyright 2012-2013 Nick Stenning # 2019 The Software Heritage developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # from collections import defaultdict -from decorator import decorator from inspect import getcallargs +from typing import Any, Callable, DefaultDict, List, NoReturn, Optional, Type -from typing import Any, List, Optional, Callable, Type, NoReturn, DefaultDict - +from decorator import decorator from requests import Response class FormatterNotFound(Exception): pass class Formatter: format: Optional[str] = None mimetypes: List[str] = [] def __init__(self, request_mimetype: Optional[str] = None) -> None: if request_mimetype is None or request_mimetype not in self.mimetypes: try: self.response_mimetype = self.mimetypes[0] except IndexError: raise NotImplementedError( "%s.mimetypes should be a non-empty list" % self.__class__.__name__ ) else: self.response_mimetype = request_mimetype def configure(self) -> None: pass def render(self, obj: Any) -> bytes: raise NotImplementedError( "render() should be implemented by Formatter subclasses" ) def __call__(self, obj: Any) -> Response: return self._make_response( self.render(obj), content_type=self.response_mimetype ) def _make_response(self, body: bytes, content_type: str) -> Response: raise NotImplementedError( "_make_response() should be implemented by " "framework-specific subclasses of Formatter" ) class Negotiator: def __init__(self, func: Callable[..., Any]) -> None: self.func = func self._formatters: List[Type[Formatter]] = [] self._formatters_by_format: DefaultDict = defaultdict(list) self._formatters_by_mimetype: DefaultDict = defaultdict(list) def __call__(self, *args, **kwargs) -> Response: result = self.func(*args, **kwargs) format = getcallargs(self.func, *args, **kwargs).get("format") mimetype = self.best_mimetype() try: formatter = self.get_formatter(format, mimetype) except FormatterNotFound as e: return self._abort(404, str(e)) return formatter(result) def register_formatter(self, formatter: Type[Formatter], *args, **kwargs) -> None: self._formatters.append(formatter) self._formatters_by_format[formatter.format].append((formatter, args, kwargs)) for mimetype in formatter.mimetypes: self._formatters_by_mimetype[mimetype].append((formatter, args, kwargs)) def get_formatter( self, format: Optional[str] = None, mimetype: Optional[str] = None ) -> Formatter: if format is None and mimetype is None: raise TypeError( "get_formatter expects one of the 'format' or 'mimetype' " "kwargs to 
be set" ) if format is not None: try: # the first added will be the most specific formatter_cls, args, kwargs = self._formatters_by_format[format][0] except IndexError: raise FormatterNotFound("Formatter for format '%s' not found!" % format) elif mimetype is not None: try: # the first added will be the most specific formatter_cls, args, kwargs = self._formatters_by_mimetype[mimetype][0] except IndexError: raise FormatterNotFound( "Formatter for mimetype '%s' not found!" % mimetype ) formatter = formatter_cls(request_mimetype=mimetype) formatter.configure(*args, **kwargs) return formatter @property def accept_mimetypes(self) -> List[str]: return [m for f in self._formatters for m in f.mimetypes] def best_mimetype(self) -> str: raise NotImplementedError( "best_mimetype() should be implemented in " "framework-specific subclasses of Negotiator" ) def _abort(self, status_code: int, err: Optional[str] = None) -> NoReturn: raise NotImplementedError( "_abort() should be implemented in framework-specific " "subclasses of Negotiator" ) def negotiate( negotiator_cls: Type[Negotiator], formatter_cls: Type[Formatter], *args, **kwargs ) -> Callable: def _negotiate(f, *args, **kwargs): return f.negotiator(*args, **kwargs) def decorate(f): if not hasattr(f, "negotiator"): f.negotiator = negotiator_cls(f) f.negotiator.register_formatter(formatter_cls, *args, **kwargs) return decorator(_negotiate, f) return decorate diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index 8d8cbf1..3b6c139 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,294 +1,293 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime from enum import Enum import json import traceback import types +from typing import Any, Dict, Tuple, Union from uuid import UUID import arrow import iso8601 import msgpack - -from typing import Any, Dict, Union, Tuple from requests import Response from swh.core.api.classes import PagedResult def encode_datetime(dt: datetime.datetime) -> str: """Wrapper of datetime.datetime.isoformat() that forbids naive datetimes.""" if dt.tzinfo is None: raise ValueError(f"{dt} is a naive datetime.") return dt.isoformat() def _encode_paged_result(obj: PagedResult) -> Dict[str, Any]: """Serialize PagedResult to a Dict.""" return { "results": obj.results, "next_page_token": obj.next_page_token, } def _decode_paged_result(obj: Dict[str, Any]) -> PagedResult: """Deserialize Dict into PagedResult""" return PagedResult(results=obj["results"], next_page_token=obj["next_page_token"],) ENCODERS = [ (arrow.Arrow, "arrow", arrow.Arrow.isoformat), (datetime.datetime, "datetime", encode_datetime), ( datetime.timedelta, "timedelta", lambda o: { "days": o.days, "seconds": o.seconds, "microseconds": o.microseconds, }, ), (UUID, "uuid", str), (PagedResult, "paged_result", _encode_paged_result), # Only for JSON: (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")), ] DECODERS = { "arrow": arrow.get, "datetime": lambda d: iso8601.parse_date(d, default_timezone=None), "timedelta": lambda d: datetime.timedelta(**d), "uuid": UUID, "paged_result": _decode_paged_result, # Only for JSON: "bytes": base64.b85decode, } class MsgpackExtTypeCodes(Enum): LONG_INT = 1 def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: return msgpack_dumps(data, 
extra_encoders=extra_encoders) except OverflowError as e: raise ValueError("Limits were reached. Please, check your input.\n" + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers["content-type"] if content_type.startswith("application/x-msgpack"): r = msgpack_loads(response.content, extra_decoders=extra_decoders) elif content_type.startswith("application/json"): r = json_loads(response.text, extra_decoders=extra_decoders) elif content_type.startswith("text/"): r = response.text else: raise ValueError("Wrong content type `%s` for API response" % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ def __init__(self, extra_encoders=None, **kwargs): super().__init__(**kwargs) self.encoders = ENCODERS if extra_encoders: self.encoders += extra_encoders def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): return { "swhtype": type_name, "d": encoder(o), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact of our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is.
""" def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) self.decoders = DECODERS if extra_decoders: self.decoders = {**self.decoders, **extra_decoders} def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {"d", "swhtype"}: if o["swhtype"] == "bytes": return base64.b85decode(o["d"]) decoder = self.decoders.get(o["swhtype"]) if decoder: return decoder(self.decode_data(o["d"])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index def json_dumps(data: Any, extra_encoders=None) -> str: return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = ENCODERS if extra_encoders: encoders += extra_encoders def encode_types(obj): if isinstance(obj, int): # integer overflowed while packing. Handle it as an extended type length, rem = divmod(obj.bit_length(), 8) if rem: length += 1 return msgpack.ExtType( MsgpackExtTypeCodes.LONG_INT.value, int.to_bytes(obj, length, "big") ) if isinstance(obj, types.GeneratorType): return list(obj) for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { b"swhtype": type_name, b"d": encoder(obj), } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream. .. Caution:: This function is used by swh.journal to decode the contents of the journal. This function **must** be kept backwards-compatible. 
""" decoders = DECODERS if extra_decoders: decoders = {**decoders, **extra_decoders} def ext_hook(code, data): if code == MsgpackExtTypeCodes.LONG_INT.value: return int.from_bytes(data, "big") raise ValueError("Unknown msgpack extended code %s" % code) def decode_types(obj): # Support for current encodings if set(obj.keys()) == {b"d", b"swhtype"}: decoder = decoders.get(obj[b"swhtype"]) if decoder: return decoder(obj[b"d"]) # Support for legacy encodings if b"__datetime__" in obj and obj[b"__datetime__"]: return iso8601.parse_date(obj[b"s"], default_timezone=None) if b"__uuid__" in obj and obj[b"__uuid__"]: return UUID(obj[b"s"]) if b"__timedelta__" in obj and obj[b"__timedelta__"]: return datetime.timedelta(**obj[b"s"]) if b"__arrow__" in obj and obj[b"__arrow__"]: return arrow.get(obj[b"s"]) # Fallthrough return obj try: try: return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook, strict_map_key=False, ) except TypeError: # msgpack < 0.6.0 return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook ) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb( data, encoding="utf-8", object_hook=decode_types, ext_hook=ext_hook ) def exception_to_dict(exception): tb = traceback.format_exception(None, exception, exception.__traceback__) return { "exception": { "type": type(exception).__name__, "args": exception.args, "message": str(exception), "traceback": tb, } } diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py index 7b8400a..4b25aa3 100644 --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -1,241 +1,243 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime -import msgpack import json +import msgpack import pytest -from swh.core.api.asynchronous import RPCServerApp, Response -from swh.core.api.asynchronous import encode_msgpack, decode_request - -from swh.core.api.serializers import msgpack_dumps, SWHJSONEncoder - +from swh.core.api.asynchronous import ( + Response, + RPCServerApp, + decode_request, + encode_msgpack, +) +from swh.core.api.serializers import SWHJSONEncoder, msgpack_dumps pytest_plugins = ["aiohttp.pytest_plugin", "pytester"] class TestServerException(Exception): pass class TestClientError(Exception): pass async def root(request): return Response("toor") STRUCT = { "txt": "something stupid", # 'date': datetime.date(2019, 6, 9), # not supported "datetime": datetime.datetime(2019, 6, 9, 10, 12, tzinfo=datetime.timezone.utc), "timedelta": datetime.timedelta(days=-2, hours=3), "int": 42, "float": 3.14, "subdata": { "int": 42, "datetime": datetime.datetime( 2019, 6, 10, 11, 12, tzinfo=datetime.timezone.utc ), }, "list": [ 42, datetime.datetime(2019, 9, 10, 11, 12, tzinfo=datetime.timezone.utc), "ok", ], } async def struct(request): return Response(STRUCT) async def echo(request): data = await decode_request(request) return Response(data) async def server_exception(request): raise TestServerException() async def client_error(request): raise TestClientError() async def echo_no_nego(request): # let the content negotiation handle the serialization for us... 
data = await decode_request(request) ret = encode_msgpack(data) return ret def check_mimetype(src, dst): src = src.split(";")[0].strip() dst = dst.split(";")[0].strip() assert src == dst @pytest.fixture def async_app(): app = RPCServerApp() app.client_exception_classes = (TestClientError,) app.router.add_route("GET", "/", root) app.router.add_route("GET", "/struct", struct) app.router.add_route("POST", "/echo", echo) app.router.add_route("GET", "/server_exception", server_exception) app.router.add_route("GET", "/client_error", client_error) app.router.add_route("POST", "/echo-no-nego", echo_no_nego) return app async def test_get_simple(async_app, aiohttp_client) -> None: assert async_app is not None cli = await aiohttp_client(async_app) resp = await cli.get("/") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") data = await resp.read() value = msgpack.unpackb(data, raw=False) assert value == "toor" async def test_get_server_exception(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) resp = await cli.get("/server_exception") assert resp.status == 500 data = await resp.read() data = msgpack.unpackb(data, raw=False) assert data["exception"]["type"] == "TestServerException" async def test_get_client_error(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) resp = await cli.get("/client_error") assert resp.status == 400 data = await resp.read() data = msgpack.unpackb(data, raw=False) assert data["exception"]["type"] == "TestClientError" async def test_get_simple_nego(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.get("/", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == "toor" async def test_get_struct(async_app, aiohttp_client) -> None: """Test that the structure returned by a simple GET is OK""" cli = await aiohttp_client(async_app) resp = await cli.get("/struct") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(async_app, aiohttp_client) -> None: """Test that the structure returned by a simple GET is OK""" cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.get("/struct", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_msgpack(async_app, aiohttp_client) -> None: """Test that msgpack encoded posted struct data is returned as is""" cli = await aiohttp_client(async_app) # simple struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps({"toto": 42}), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} # complex struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps(STRUCT), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_post_struct_json(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is""" cli = await
aiohttp_client(async_app) resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps({"toto": 42}, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT async def test_post_struct_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is using content negotiation (accept json or msgpack). """ cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_no_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as msgpack when using non-negotiation-compatible handlers. """ cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo-no-nego", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT diff --git a/swh/core/api/tests/test_classes.py b/swh/core/api/tests/test_classes.py index d63d410..7509f51 100644 --- a/swh/core/api/tests/test_classes.py +++ b/swh/core/api/tests/test_classes.py @@ -1,60 +1,60 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from swh.core.api.classes import PagedResult as CorePagedResult, stream_results - from typing import TypeVar +from swh.core.api.classes import PagedResult as CorePagedResult +from swh.core.api.classes import stream_results T = TypeVar("T") TestPagedResult = CorePagedResult[T, bytes] def test_stream_results_no_result(): def paged_results(page_token) -> TestPagedResult: return TestPagedResult(results=[], next_page_token=None) # only 1 call, no pagination actual_data = stream_results(paged_results) assert list(actual_data) == [] def test_stream_results_no_pagination(): input_data = [ {"url": "something"}, {"url": "something2"}, ] def paged_results(page_token) -> TestPagedResult: return TestPagedResult(results=input_data, next_page_token=None) # only 1 call, no pagination actual_data = stream_results(paged_results) assert list(actual_data) == input_data def test_stream_results_pagination(): input_data = [ {"url": "something"}, {"url": "something2"}, ] input_data2 = [ {"url": "something3"}, ] input_data3 = [ {"url": "something4"}, ] def page_results2(page_token=None) -> TestPagedResult: result_per_token = { None: TestPagedResult(results=input_data, next_page_token=b"two"), b"two": TestPagedResult(results=input_data2, next_page_token=b"three"), b"three": TestPagedResult(results=input_data3, 
next_page_token=None), } return result_per_token[page_token] # multiple calls to solve the pagination calls actual_data = stream_results(page_results2) assert list(actual_data) == input_data + input_data2 + input_data3 diff --git a/swh/core/api/tests/test_gunicorn.py b/swh/core/api/tests/test_gunicorn.py index c0d12ef..92a3284 100644 --- a/swh/core/api/tests/test_gunicorn.py +++ b/swh/core/api/tests/test_gunicorn.py @@ -1,116 +1,117 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os -import pkg_resources from unittest.mock import patch +import pkg_resources + import swh.core.api.gunicorn_config as gunicorn_config def test_post_fork_default(): with patch("sentry_sdk.init") as sentry_sdk_init: gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_not_called() def test_post_fork_with_dsn_env(): flask_integration = object() # unique object to check for equality with patch( "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration ): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_called_once_with( dsn="test_dsn", integrations=[flask_integration], debug=False, release=None, environment=None, ) def test_post_fork_with_package_env(): flask_integration = object() # unique object to check for equality with patch( "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration ): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict( os.environ, { "SWH_SENTRY_DSN": "test_dsn", "SWH_SENTRY_ENVIRONMENT": "tests", "SWH_MAIN_PACKAGE": "swh.core", }, ): gunicorn_config.post_fork(None, None) version = pkg_resources.get_distribution("swh.core").version sentry_sdk_init.assert_called_once_with( dsn="test_dsn", integrations=[flask_integration], debug=False, release="swh.core@" + version, environment="tests", ) def test_post_fork_debug(): flask_integration = object() # unique object to check for equality with patch( "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration ): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict( os.environ, {"SWH_SENTRY_DSN": "test_dsn", "SWH_SENTRY_DEBUG": "1"} ): gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_called_once_with( dsn="test_dsn", integrations=[flask_integration], debug=True, release=None, environment=None, ) def test_post_fork_no_flask(): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork(None, None, flask=False) sentry_sdk_init.assert_called_once_with( dsn="test_dsn", integrations=[], debug=False, release=None, environment=None, ) def test_post_fork_extras(): flask_integration = object() # unique object to check for equality with patch( "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration ): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork( None, None, sentry_integrations=["foo"], extra_sentry_kwargs={"bar": "baz"}, ) sentry_sdk_init.assert_called_once_with( dsn="test_dsn", integrations=["foo", flask_integration], debug=False, bar="baz", release=None, environment=None, ) diff --git a/swh/core/api/tests/test_rpc_client.py 
b/swh/core/api/tests/test_rpc_client.py index 24a2fba..7dffac1 100644 --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -1,85 +1,86 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re + import pytest -from swh.core.api import remote_api_endpoint, RPCClient +from swh.core.api import RPCClient, remote_api_endpoint -from .test_serializers import ExtraType, extra_encoders, extra_decoders +from .test_serializers import ExtraType, extra_decoders, extra_encoders @pytest.fixture def rpc_client(requests_mock): class TestStorage: @remote_api_endpoint("test_endpoint_url") def test_endpoint(self, test_data, db=None, cur=None): ... @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): ... @remote_api_endpoint("serializer_test") def serializer_test(self, data, db=None, cur=None): ... @remote_api_endpoint("overridden/endpoint") def overridden_method(self, data): return "foo" class Testclient(RPCClient): backend_class = TestStorage extra_type_encoders = extra_encoders extra_type_decoders = extra_decoders def overridden_method(self, data): return "bar" def callback(request, context): assert request.headers["Content-Type"] == "application/x-msgpack" context.headers["Content-Type"] = "application/x-msgpack" if request.path == "/test_endpoint_url": context.content = b"\xa3egg" elif request.path == "/path/to/endpoint": context.content = b"\xa4spam" elif request.path == "/serializer_test": context.content = ( b"\x82\xc4\x07swhtype\xa9extratype" b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux" ) else: assert False return context.content requests_mock.post(re.compile("mock://example.com/"), content=callback) return Testclient(url="mock://example.com") def test_client(rpc_client): assert hasattr(rpc_client, "test_endpoint") assert hasattr(rpc_client, "something") res = rpc_client.test_endpoint("spam") assert res == "egg" res = rpc_client.test_endpoint(test_data="spam") assert res == "egg" res = rpc_client.something("whatever") assert res == "spam" res = rpc_client.something(data="whatever") assert res == "spam" def test_client_extra_serializers(rpc_client): res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")]) assert res == ExtraType({"spam": "egg"}, "qux") def test_client_overridden_method(rpc_client): res = rpc_client.overridden_method("foo") assert res == "bar" diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py index ec651eb..81b0afa 100644 --- a/swh/core/api/tests/test_rpc_client_server.py +++ b/swh/core/api/tests/test_rpc_client_server.py @@ -1,111 +1,117 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest -from swh.core.api import remote_api_endpoint, RPCServerApp, RPCClient -from swh.core.api import error_handler, encode_data_server, RemoteException +from swh.core.api import ( + RemoteException, + RPCClient, + RPCServerApp, + encode_data_server, + error_handler, + remote_api_endpoint, +) # this class is used on the server part class RPCTest: @remote_api_endpoint("endpoint_url") def endpoint(self, test_data, db=None, cur=None): assert test_data == "spam" 
return "egg" @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data @remote_api_endpoint("raises_typeerror") def raise_typeerror(self): raise TypeError("Did I pass through?") # this class is used on the client part. We cannot inherit from RPCTest # because the automagic metaclass based code that generates the RPCClient # proxy class from this does not handle inheritance properly. # We do add an endpoint on the client side that has no implementation # server-side to test this very situation (in should generate a 404) class RPCTest2: @remote_api_endpoint("endpoint_url") def endpoint(self, test_data, db=None, cur=None): assert test_data == "spam" return "egg" @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data @remote_api_endpoint("not_on_server") def not_on_server(self, db=None, cur=None): return "ok" @remote_api_endpoint("raises_typeerror") def raise_typeerror(self): return "data" class RPCTestClient(RPCClient): backend_class = RPCTest2 @pytest.fixture def app(): # This fixture is used by the 'swh_rpc_adapter' fixture # which is defined in swh/core/pytest_plugin.py application = RPCServerApp("testapp", backend_class=RPCTest) @application.errorhandler(Exception) def my_error_handler(exception): return error_handler(exception, encode_data_server) return application @pytest.fixture def swh_rpc_client_class(): # This fixture is used by the 'swh_rpc_client' fixture # which is defined in swh/core/pytest_plugin.py return RPCTestClient def test_api_client_endpoint_missing(swh_rpc_client): with pytest.raises(AttributeError): swh_rpc_client.missing(data="whatever") def test_api_server_endpoint_missing(swh_rpc_client): # A 'missing' endpoint (server-side) should raise an exception # due to a 404, since at the end, we do a GET/POST an inexistent URL with pytest.raises(Exception, match="404 not found"): swh_rpc_client.not_on_server() def test_api_endpoint_kwargs(swh_rpc_client): res = swh_rpc_client.something(data="whatever") assert res == "whatever" res = swh_rpc_client.endpoint(test_data="spam") assert res == "egg" def test_api_endpoint_args(swh_rpc_client): res = swh_rpc_client.something("whatever") assert res == "whatever" res = swh_rpc_client.endpoint("spam") assert res == "egg" def test_api_typeerror(swh_rpc_client): with pytest.raises(RemoteException) as exc_info: swh_rpc_client.raise_typeerror() assert exc_info.value.args[0]["type"] == "TypeError" assert exc_info.value.args[0]["args"] == ["Did I pass through?"] assert ( str(exc_info.value) == "" ) diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py index 31fdb27..84bdf2f 100644 --- a/swh/core/api/tests/test_rpc_server.py +++ b/swh/core/api/tests/test_rpc_server.py @@ -1,122 +1,128 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import pytest import json -import msgpack from flask import url_for +import msgpack +import pytest + +from swh.core.api import ( + JSONFormatter, + MsgpackFormatter, + RPCServerApp, + negotiate, + remote_api_endpoint, +) -from swh.core.api import remote_api_endpoint, RPCServerApp -from swh.core.api import negotiate, JSONFormatter, MsgpackFormatter -from .test_serializers import ExtraType, extra_encoders, extra_decoders +from .test_serializers import ExtraType, extra_decoders, extra_encoders 
class MyRPCServerApp(RPCServerApp): extra_type_encoders = extra_encoders extra_type_decoders = extra_decoders @pytest.fixture def app(): class TestStorage: @remote_api_endpoint("test_endpoint_url") def test_endpoint(self, test_data, db=None, cur=None): assert test_data == "spam" return "egg" @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data @remote_api_endpoint("serializer_test") def serializer_test(self, data, db=None, cur=None): assert data == ["foo", ExtraType("bar", b"baz")] return ExtraType({"spam": "egg"}, "qux") return MyRPCServerApp("testapp", backend_class=TestStorage) def test_api_endpoint(flask_app_client): res = flask_app_client.post( url_for("something"), headers=[("Content-Type", "application/json"), ("Accept", "application/json")], data=json.dumps({"data": "toto"}), ) assert res.status_code == 200 assert res.mimetype == "application/json" def test_api_nego_default(flask_app_client): res = flask_app_client.post( url_for("something"), headers=[("Content-Type", "application/json")], data=json.dumps({"data": "toto"}), ) assert res.status_code == 200 assert res.mimetype == "application/json" assert res.data == b'"toto"' def test_api_nego_accept(flask_app_client): res = flask_app_client.post( url_for("something"), headers=[ ("Accept", "application/x-msgpack"), ("Content-Type", "application/x-msgpack"), ], data=msgpack.dumps({"data": "toto"}), ) assert res.status_code == 200 assert res.mimetype == "application/x-msgpack" assert res.data == b"\xa4toto" def test_rpc_server(flask_app_client): res = flask_app_client.post( url_for("test_endpoint"), headers=[ ("Content-Type", "application/x-msgpack"), ("Accept", "application/x-msgpack"), ], data=b"\x81\xa9test_data\xa4spam", ) assert res.status_code == 200 assert res.mimetype == "application/x-msgpack" assert res.data == b"\xa3egg" def test_rpc_server_extra_serializers(flask_app_client): res = flask_app_client.post( url_for("serializer_test"), headers=[ ("Content-Type", "application/x-msgpack"), ("Accept", "application/x-msgpack"), ], data=b"\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype" b"\xc4\x01d\x92\xa3bar\xc4\x03baz", ) assert res.status_code == 200 assert res.mimetype == "application/x-msgpack" assert res.data == ( b"\x82\xc4\x07swhtype\xa9extratype\xc4" b"\x01d\x92\x81\xa4spam\xa3egg\xa3qux" ) def test_api_negotiate_no_extra_encoders(app, flask_app_client): url = "/test/negotiate/no/extra/encoders" @app.route(url, methods=["POST"]) @negotiate(MsgpackFormatter) @negotiate(JSONFormatter) def endpoint(): return "test" res = flask_app_client.post(url, headers=[("Content-Type", "application/json")],) assert res.status_code == 200 assert res.mimetype == "application/json" assert res.data == b'"test"' diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 0b4803d..9d7c261 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,252 +1,250 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json - from typing import Any, Callable, List, Tuple, Union - -from arrow import Arrow from uuid import UUID -import pytest import arrow +from arrow import Arrow +import pytest import requests from swh.core.api.classes import PagedResult from swh.core.api.serializers import ( 
SWHJSONDecoder, SWHJSONEncoder, + decode_response, msgpack_dumps, msgpack_loads, - decode_response, ) class ExtraType: def __init__(self, arg1, arg2): self.arg1 = arg1 self.arg2 = arg2 def __repr__(self): return f"ExtraType({self.arg1}, {self.arg2})" def __eq__(self, other): return isinstance(other, ExtraType) and (self.arg1, self.arg2) == ( other.arg1, other.arg2, ) extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ (ExtraType, "extratype", lambda o: (o.arg1, o.arg2)) ] extra_decoders = { "extratype": lambda o: ExtraType(*o), } TZ = datetime.timezone(datetime.timedelta(minutes=118)) DATA_BYTES = b"123456789\x99\xaf\xff\x00\x12" ENCODED_DATA_BYTES = {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"} DATA_DATETIME = datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=TZ,) ENCODED_DATA_DATETIME = { "swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+01:58", } DATA_TIMEDELTA = datetime.timedelta(64) ENCODED_DATA_TIMEDELTA = { "swhtype": "timedelta", "d": {"days": 64, "seconds": 0, "microseconds": 0}, } DATA_ARROW = arrow.get("2018-04-25T16:17:53.533672+00:00") ENCODED_DATA_ARROW = {"swhtype": "arrow", "d": "2018-04-25T16:17:53.533672+00:00"} DATA_UUID = UUID("cdd8f804-9db6-40c3-93ab-5955d3836234") ENCODED_DATA_UUID = {"swhtype": "uuid", "d": "cdd8f804-9db6-40c3-93ab-5955d3836234"} # For test demonstration purposes TestPagedResultStr = PagedResult[ Union[UUID, datetime.datetime, datetime.timedelta], str ] DATA_PAGED_RESULT = TestPagedResultStr( results=[DATA_UUID, DATA_DATETIME, DATA_TIMEDELTA], next_page_token="10", ) ENCODED_DATA_PAGED_RESULT = { "d": { "results": [ENCODED_DATA_UUID, ENCODED_DATA_DATETIME, ENCODED_DATA_TIMEDELTA,], "next_page_token": "10", }, "swhtype": "paged_result", } TestPagedResultTuple = PagedResult[Union[str, bytes, Arrow], List[Union[str, UUID]]] DATA_PAGED_RESULT2 = TestPagedResultTuple( results=["data0", DATA_BYTES, DATA_ARROW], next_page_token=["10", DATA_UUID], ) ENCODED_DATA_PAGED_RESULT2 = { "d": { "results": ["data0", ENCODED_DATA_BYTES, ENCODED_DATA_ARROW,], "next_page_token": ["10", ENCODED_DATA_UUID], }, "swhtype": "paged_result", } DATA = { "bytes": DATA_BYTES, "datetime_tz": DATA_DATETIME, "datetime_utc": datetime.datetime( 2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc ), "datetime_delta": DATA_TIMEDELTA, "arrow_date": DATA_ARROW, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": DATA_UUID, "paged-result": DATA_PAGED_RESULT, "paged-result2": DATA_PAGED_RESULT2, } ENCODED_DATA = { "bytes": ENCODED_DATA_BYTES, "datetime_tz": ENCODED_DATA_DATETIME, "datetime_utc": {"swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+00:00",}, "datetime_delta": ENCODED_DATA_TIMEDELTA, "arrow_date": ENCODED_DATA_ARROW, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": ENCODED_DATA_UUID, "paged-result": ENCODED_DATA_PAGED_RESULT, "paged-result2": ENCODED_DATA_PAGED_RESULT2, } def test_serializers_round_trip_json(): json_data = json.dumps(DATA, cls=SWHJSONEncoder) actual_data = json.loads(json_data, cls=SWHJSONDecoder) assert actual_data == DATA def test_serializers_round_trip_json_extra_types(): expected_original_data = [ExtraType("baz", DATA), "qux"] data = json.dumps( expected_original_data, cls=SWHJSONEncoder, extra_encoders=extra_encoders ) actual_data = json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) assert actual_data == expected_original_data def test_serializers_encode_swh_json(): json_str = json.dumps(DATA, 
cls=SWHJSONEncoder) actual_data = json.loads(json_str) assert actual_data == ENCODED_DATA def test_serializers_round_trip_msgpack(): expected_original_data = { **DATA, "none_dict_key": {None: 42}, "long_int_is_loooong": 10000000000000000000000000000000, } data = msgpack_dumps(expected_original_data) actual_data = msgpack_loads(data) assert actual_data == expected_original_data def test_serializers_round_trip_msgpack_extra_types(): original_data = [ExtraType("baz", DATA), "qux"] data = msgpack_dumps(original_data, extra_encoders=extra_encoders) actual_data = msgpack_loads(data, extra_decoders=extra_decoders) assert actual_data == original_data def test_serializers_generator_json(): data = json.dumps((i for i in range(5)), cls=SWHJSONEncoder) assert json.loads(data, cls=SWHJSONDecoder) == [i for i in range(5)] def test_serializers_generator_msgpack(): data = msgpack_dumps((i for i in range(5))) assert msgpack_loads(data) == [i for i in range(5)] def test_serializers_decode_response_json(requests_mock): requests_mock.get( "https://example.org/test/data", json=ENCODED_DATA, headers={"content-type": "application/json"}, ) response = requests.get("https://example.org/test/data") assert decode_response(response) == DATA def test_serializers_decode_legacy_msgpack(): legacy_msgpack = { "bytes": b"\xc4\x0e123456789\x99\xaf\xff\x00\x12", "datetime_tz": ( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " b"2015-03-04T18:25:13.001234+01:58" ), "datetime_utc": ( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " b"2015-03-04T18:25:13.001234+00:00" ), "datetime_delta": ( b"\x82\xc4\r__timedelta__\xc3\xc4\x01s\x83\xa4" b"days@\xa7seconds\x00\xacmicroseconds\x00" ), "arrow_date": ( b"\x82\xc4\t__arrow__\xc3\xc4\x01s\xd9 2018-04-25T16:17:53.533672+00:00" ), "swhtype": b"\xa4fake", "swh_dict": b"\x82\xa7swhtype*\xa1d\xa4test", "random_dict": b"\x81\xa7swhtype+", "uuid": ( b"\x82\xc4\x08__uuid__\xc3\xc4\x01s\xd9$" b"cdd8f804-9db6-40c3-93ab-5955d3836234" ), } for k, v in legacy_msgpack.items(): assert msgpack_loads(v) == DATA[k] def test_serializers_encode_native_datetime(): dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) with pytest.raises(ValueError, match="naive datetime"): msgpack_dumps(dt) def test_serializers_decode_naive_datetime(): expected_dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) # Current encoding assert ( msgpack_loads( b"\x82\xc4\x07swhtype\xa8datetime\xc4\x01d\xba" b"2015-01-01T12:04:42.231455" ) == expected_dt ) # Legacy encoding assert ( msgpack_loads( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba2015-01-01T12:04:42.231455" ) == expected_dt ) diff --git a/swh/core/cli/__init__.py b/swh/core/cli/__init__.py index a840d55..1488bd3 100644 --- a/swh/core/cli/__init__.py +++ b/swh/core/cli/__init__.py @@ -1,126 +1,128 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import logging.config import click import pkg_resources LOG_LEVEL_NAMES = ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) logger = logging.getLogger(__name__) class AliasedGroup(click.Group): """A simple Group that supports command aliases, as well as notes related to options""" def __init__(self, name=None, commands=None, **attrs): self.option_notes = attrs.pop("option_notes", None) self.aliases = {} super().__init__(name, commands, **attrs) def 
get_command(self, ctx, cmd_name): return super().get_command(ctx, self.aliases.get(cmd_name, cmd_name)) def add_alias(self, name, alias): if not isinstance(name, str): name = name.name self.aliases[alias] = name def format_options(self, ctx, formatter): click.Command.format_options(self, ctx, formatter) if self.option_notes: with formatter.section("Notes"): formatter.write_text(self.option_notes) self.format_commands(ctx, formatter) def clean_exit_on_signal(signal, frame): """Raise a SystemExit exception to let command-line clients wind themselves down on exit""" raise SystemExit(0) @click.group( context_settings=CONTEXT_SETTINGS, cls=AliasedGroup, option_notes="""\ If both options are present, --log-level will override the root logger configuration set in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. """, ) @click.option( "--log-level", "-l", default=None, type=click.Choice(LOG_LEVEL_NAMES), help="Log level (defaults to INFO).", ) @click.option( "--log-config", default=None, type=click.File("r"), help="Python yaml logging configuration file.", ) @click.option( "--sentry-dsn", default=None, help="DSN of the Sentry instance to report to" ) @click.option( "--sentry-debug/--no-sentry-debug", default=False, hidden=True, help="Enable debugging of sentry", ) @click.pass_context def swh(ctx, log_level, log_config, sentry_dsn, sentry_debug): """Command line interface for Software Heritage. """ import signal + import yaml + from ..sentry import init_sentry signal.signal(signal.SIGTERM, clean_exit_on_signal) signal.signal(signal.SIGINT, clean_exit_on_signal) init_sentry(sentry_dsn, debug=sentry_debug) if log_level is None and log_config is None: log_level = "INFO" if log_config: logging.config.dictConfig(yaml.safe_load(log_config.read())) if log_level: log_level = logging.getLevelName(log_level) logging.root.setLevel(log_level) ctx.ensure_object(dict) ctx.obj["log_level"] = log_level def main(): # Even though swh() sets up logging, we need an earlier basic logging setup # for the next few logging statements logging.basicConfig() # load plugins that define cli sub commands for entry_point in pkg_resources.iter_entry_points("swh.cli.subcommands"): try: cmd = entry_point.load() swh.add_command(cmd, name=entry_point.name) except Exception as e: logger.warning("Could not load subcommand %s: %s", entry_point.name, str(e)) return swh(auto_envvar_prefix="SWH") if __name__ == "__main__": main() diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py index 8090df6..86f81be 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -1,208 +1,208 @@ #!/usr/bin/env python3 # Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging -from os import path, environ +from os import environ, path from typing import Tuple - import warnings import click from swh.core.cli import CONTEXT_SETTINGS - warnings.filterwarnings("ignore") # noqa prevent psycopg from telling us sh*t logger = logging.getLogger(__name__) @click.group(name="db", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False), help="Configuration file.", ) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools. 
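    If the --config-file option is not given, the configuration file is
    looked up in the SWH_CONFIG_FILENAME environment variable instead, e.g.
    (path illustrative)::

        SWH_CONFIG_FILENAME=/etc/softwareheritage/storage.yml swh db init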
""" from swh.core.config import read as config_read ctx.ensure_object(dict) if config_file is None: config_file = environ.get("SWH_CONFIG_FILENAME") cfg = config_read(config_file) ctx.obj["config"] = cfg @db.command(name="init", context_settings=CONTEXT_SETTINGS) @click.pass_context def init(ctx): """Initialize the database for every Software Heritage module found in the configuration file. For every configuration section in the config file that: 1. has the name of an existing swh package, 2. has credentials for a local db access, it will run the initialization scripts from the swh package against the given database. Example for the config file:: \b storage: cls: local args: db: postgresql:///?service=swh-storage objstorage: cls: remote args: url: http://swh-objstorage:5003/ the command: swh db -C /path/to/config.yml init will initialize the database for the `storage` section using initialization scripts from the `swh.storage` package. """ for modname, cfg in ctx.obj["config"].items(): if cfg.get("cls") == "local" and cfg.get("args", {}).get("db"): try: initialized, dbversion = populate_database_for_package( modname, cfg["args"]["db"] ) except click.BadParameter: logger.info( "Failed to load/find sql initialization files for %s", modname ) click.secho( "DONE database for {} {} at version {}".format( modname, "initialized" if initialized else "exists", dbversion ), fg="green", bold=True, ) @click.command(context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) @click.option( "--create-db/--no-create-db", "-C", help="Attempt to create the database.", default=False, ) def db_init(module, db_name, create_db): """Initialize a database for the Software Heritage . By default, does not attempt to create the database. Example: swh db-init -d swh-test storage If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables. See psql(1) man page (section ENVIRONMENTS) for details. Example: PGPORT=5434 swh db-init indexer """ logger.debug("db_init %s dn_name=%s", module, db_name) if create_db: from swh.core.db.tests.db_testing import pg_createdb # Create the db (or fail silently if already existing) pg_createdb(db_name, check=False) initialized, dbversion = populate_database_for_package(module, db_name) # TODO: Ideally migrate the version from db_version to the latest # db version click.secho( "DONE database for {} {} at version {}".format( module, "initialized" if initialized else "exists", dbversion ), fg="green", bold=True, ) def get_sql_for_package(modname): import glob from importlib import import_module + from swh.core.utils import numfile_sortkey as sortkey if not modname.startswith("swh."): modname = "swh.{}".format(modname) try: m = import_module(modname) except ImportError: raise click.BadParameter("Unable to load module {}".format(modname)) sqldir = path.join(path.dirname(m.__file__), "sql") if not path.isdir(sqldir): raise click.BadParameter( "Module {} does not provide a db schema " "(no sql/ dir)".format(modname) ) return list(sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey)) def populate_database_for_package(modname: str, conninfo: str) -> Tuple[bool, int]: """Populate the database, pointed at with `conninfo`, using the SQL files found in the package `modname`. 
Args: modname: Name of the module whose SQL files we're loading conninfo: connection info string for the SQL database Returns: Tuple with two elements: whether the database has been initialized; the current version of the database. """ import subprocess + from swh.core.db.tests.db_testing import swh_db_version current_version = swh_db_version(conninfo) if current_version is not None: return False, current_version sqlfiles = get_sql_for_package(modname) for sqlfile in sqlfiles: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", sqlfile, ] ) current_version = swh_db_version(conninfo) return True, current_version diff --git a/swh/core/config.py b/swh/core/config.py index f05983a..91a8a4a 100644 --- a/swh/core/config.py +++ b/swh/core/config.py @@ -1,355 +1,354 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from copy import deepcopy +from itertools import chain import logging import os -import yaml -from itertools import chain -from copy import deepcopy - from typing import Any, Callable, Dict, List, Optional, Tuple +import yaml logger = logging.getLogger(__name__) SWH_CONFIG_DIRECTORIES = [ "~/.config/swh", "~/.swh", "/etc/softwareheritage", ] SWH_GLOBAL_CONFIG = "global.yml" SWH_DEFAULT_GLOBAL_CONFIG = { "max_content_size": ("int", 100 * 1024 * 1024), } SWH_CONFIG_EXTENSIONS = [ ".yml", ] # conversion per type _map_convert_fn: Dict[str, Callable] = { "int": int, "bool": lambda x: x.lower() == "true", "list[str]": lambda x: [value.strip() for value in x.split(",")], "list[int]": lambda x: [int(value.strip()) for value in x.split(",")], } _map_check_fn: Dict[str, Callable] = { "int": lambda x: isinstance(x, int), "bool": lambda x: isinstance(x, bool), "list[str]": lambda x: (isinstance(x, list) and all(isinstance(y, str) for y in x)), "list[int]": lambda x: (isinstance(x, list) and all(isinstance(y, int) for y in x)), } def exists_accessible(filepath: str) -> bool: """Check whether a file exists, and is accessible. Returns: True if the file exists and is accessible False if the file does not exist Raises: PermissionError if the file cannot be read. """ try: os.stat(filepath) except PermissionError: raise except FileNotFoundError: return False else: if os.access(filepath, os.R_OK): return True else: raise PermissionError(f"Permission denied: {filepath!r}") def config_basepath(config_path: str) -> str: """Return the base path of a configuration file""" if config_path.endswith(".yml"): return config_path[:-4] return config_path def read_raw_config(base_config_path: str) -> Dict[str, Any]: """Read the raw config corresponding to base_config_path. Can read yml files. """ yml_file = f"{base_config_path}.yml" if exists_accessible(yml_file): logger.info("Loading config file %s", yml_file) with open(yml_file) as f: return yaml.safe_load(f) return {} def config_exists(config_path): """Check whether the given config exists""" basepath = config_basepath(config_path) return any( exists_accessible(basepath + extension) for extension in SWH_CONFIG_EXTENSIONS ) def read( conf_file: Optional[str] = None, default_conf: Optional[Dict[str, Tuple[str, Any]]] = None, ) -> Dict[str, Any]: """Read the user's configuration file. Fill in the gaps using `default_conf`.
`default_conf` is similar to this:: DEFAULT_CONF = { 'a': ('str', '/tmp/swh-loader-git/log'), 'b': ('str', 'dbname=swhloadergit'), 'c': ('bool', True), 'e': ('bool', None), 'd': ('int', 10), } If conf_file is None, return the default config. """ conf: Dict[str, Any] = {} if conf_file: base_config_path = config_basepath(os.path.expanduser(conf_file)) conf = read_raw_config(base_config_path) or {} if not default_conf: return conf # remaining missing default configuration keys are set # also, type conversion is enforced for the underlying layer for key, (nature_type, default_value) in default_conf.items(): val = conf.get(key, None) if val is None: # fallback to default value conf[key] = default_value elif not _map_check_fn.get(nature_type, lambda x: True)(val): # value present but not in the proper format, force type conversion conf[key] = _map_convert_fn.get(nature_type, lambda x: x)(val) return conf def priority_read( conf_filenames: List[str], default_conf: Optional[Dict[str, Tuple[str, Any]]] = None ): """Try reading the configuration files from conf_filenames, in order, and return the configuration from the first one that exists. default_conf has the same specification as it does in read. """ # Try all the files in order for filename in conf_filenames: full_filename = os.path.expanduser(filename) if config_exists(full_filename): return read(full_filename, default_conf) # Else, return the default configuration return read(None, default_conf) def merge_default_configs(base_config, *other_configs): """Merge several default config dictionaries, from left to right""" full_config = base_config.copy() for config in other_configs: full_config.update(config) return full_config def merge_configs(base: Optional[Dict[str, Any]], other: Optional[Dict[str, Any]]): """Merge two config dictionaries This merges config dicts recursively, applying the following rules for every value of the dicts ('val' standing for any non-dict value): - None + type -> type - type + None -> None - dict + dict -> dict (merged) - val + dict -> TypeError - dict + val -> TypeError - val + val -> val (other) for instance: >>> d1 = { ... 'key1': { ... 'skey1': 'value1', ... 'skey2': {'sskey1': 'value2'}, ... }, ... 'key2': 'value3', ... } with >>> d2 = { ... 'key1': { ... 'skey1': 'value4', ... 'skey2': {'sskey2': 'value5'}, ... }, ... 'key3': 'value6', ... } will give: >>> d3 = { ... 'key1': { ... 'skey1': 'value4', # <-- note this ... 'skey2': { ... 'sskey1': 'value2', ... 'sskey2': 'value5', ... }, ... }, ... 'key2': 'value3', ... 'key3': 'value6', ... } >>> assert merge_configs(d1, d2) == d3 Note that no type checking is done for anything but dicts. """ if not isinstance(base, dict) or not isinstance(other, dict): raise TypeError("Cannot merge a %s with a %s" % (type(base), type(other))) output = {} allkeys = set(chain(base.keys(), other.keys())) for k in allkeys: vb = base.get(k) vo = other.get(k) if isinstance(vo, dict): output[k] = merge_configs(vb is not None and vb or {}, vo) elif isinstance(vb, dict) and k in other and other[k] is not None: output[k] = merge_configs(vb, vo is not None and vo or {}) elif k in other: output[k] = deepcopy(vo) else: output[k] = deepcopy(vb) return output def swh_config_paths(base_filename: str) -> List[str]: """Return the Software Heritage specific configuration paths for the given filename.""" return [os.path.join(dirname, base_filename) for dirname in SWH_CONFIG_DIRECTORIES] def prepare_folders(conf, *keys): """Prepare the folders mentioned in config under the given keys.
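    For instance (key names illustrative), prepare_folders(conf, 'log_dir',
    'cache_dir') creates the directories named by conf['log_dir'] and
    conf['cache_dir'] if they do not exist yet.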
""" def makedir(folder): if not os.path.exists(folder): os.makedirs(folder) for key in keys: makedir(conf[key]) def load_global_config(): """Load the global Software Heritage config""" return priority_read( swh_config_paths(SWH_GLOBAL_CONFIG), SWH_DEFAULT_GLOBAL_CONFIG, ) def load_named_config(name, default_conf=None, global_conf=True): """Load the config named `name` from the Software Heritage configuration paths. If global_conf is True (default), read the global configuration too. """ conf = {} if global_conf: conf.update(load_global_config()) conf.update(priority_read(swh_config_paths(name), default_conf)) return conf class SWHConfig: """Mixin to add configuration parsing abilities to classes The class should override the class attributes: - DEFAULT_CONFIG (default configuration to be parsed) - CONFIG_BASE_FILENAME (the filename of the configuration to be used) This class defines one classmethod, parse_config_file, which parses a configuration file using the default config as set in the class attribute. """ DEFAULT_CONFIG = {} # type: Dict[str, Tuple[str, Any]] CONFIG_BASE_FILENAME = "" # type: Optional[str] @classmethod def parse_config_file( cls, base_filename=None, config_filename=None, additional_configs=None, global_config=True, ): """Parse the configuration file associated to the current class. By default, parse_config_file will load the configuration cls.CONFIG_BASE_FILENAME from one of the Software Heritage configuration directories, in order, unless it is overridden by base_filename or config_filename (which shortcuts the file lookup completely). Args: - base_filename (str): overrides the default cls.CONFIG_BASE_FILENAME - config_filename (str): sets the file to parse instead of the defaults set from cls.CONFIG_BASE_FILENAME - additional_configs: (list of default configuration dicts) allows to override or extend the configuration set in cls.DEFAULT_CONFIG. 
- global_config (bool): Load the global configuration (default: True) """ if config_filename: config_filenames = [config_filename] elif "SWH_CONFIG_FILENAME" in os.environ: config_filenames = [os.environ["SWH_CONFIG_FILENAME"]] else: if not base_filename: base_filename = cls.CONFIG_BASE_FILENAME config_filenames = swh_config_paths(base_filename) if not additional_configs: additional_configs = [] full_default_config = merge_default_configs( cls.DEFAULT_CONFIG, *additional_configs ) config = {} if global_config: config = load_global_config() config.update(priority_read(config_filenames, full_default_config)) return config diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py index a98d5e9..eb1341f 100644 --- a/swh/core/db/__init__.py +++ b/swh/core/db/__init__.py @@ -1,299 +1,297 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from contextlib import contextmanager import datetime import enum import json import logging import os import sys import threading from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type, TypeVar -from contextlib import contextmanager - import psycopg2 import psycopg2.extras import psycopg2.pool - logger = logging.getLogger(__name__) psycopg2.extras.register_uuid() def render_array(data) -> str: """Render the data as a postgresql array""" # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO # "The external text representation of an array value consists of items that are # interpreted according to the I/O conversion rules for the array's element type, # plus decoration that indicates the array structure. The decoration consists of # curly braces ({ and }) around the array value plus delimiter characters between # adjacent items. The delimiter character is usually a comma (,)" return "{%s}" % ",".join(render_array_element(e) for e in data) def render_array_element(element) -> str: """Render an element from an array.""" if element is None: # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO # "If the value written for an element is NULL (in any case variant), the # element is taken to be NULL." return "NULL" elif isinstance(element, (list, tuple)): # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-INPUT # "Each val is either a constant of the array element type, or a subarray." return render_array(element) else: # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO # "When writing an array value you can use double quotes around any individual # array element. [...] Empty strings and strings matching the word NULL must be # quoted, too. To put a double quote or backslash in a quoted array element # value, precede it with a backslash." ret = value_as_pg_text(element) return '"%s"' % ret.replace("\\", "\\\\").replace('"', '\\"') def value_as_pg_text(data: Any) -> str: """Render the given data in the postgresql text format. NULL values are handled **outside** of this function (either by :func:`render_array_element`, or by :meth:`BaseDb.copy_to`.) 
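    For instance, bytes are rendered in hex form (b"abc" becomes \x616263),
    datetimes as their ISO 8601 representation, and dicts as JSON.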
""" if data is None: raise ValueError("value_as_pg_text doesn't handle NULLs") if isinstance(data, bytes): return "\\x%s" % data.hex() elif isinstance(data, datetime.datetime): return data.isoformat() elif isinstance(data, dict): return json.dumps(data) elif isinstance(data, (list, tuple)): return render_array(data) elif isinstance(data, psycopg2.extras.Range): return "%s%s,%s%s" % ( "[" if data.lower_inc else "(", "-infinity" if data.lower_inf else value_as_pg_text(data.lower), "infinity" if data.upper_inf else value_as_pg_text(data.upper), "]" if data.upper_inc else ")", ) elif isinstance(data, enum.IntEnum): return str(int(data)) else: return str(data) def escape_copy_column(column: str) -> str: """Escape the text representation of a column for use by COPY.""" # From https://www.postgresql.org/docs/11/sql-copy.html # File Formats > Text Format # "Backslash characters (\) can be used in the COPY data to quote data characters # that might otherwise be taken as row or column delimiters. In particular, the # following characters must be preceded by a backslash if they appear as part of a # column value: backslash itself, newline, carriage return, and the current # delimiter character." ret = ( column.replace("\\", "\\\\") .replace("\n", "\\n") .replace("\r", "\\r") .replace("\t", "\\t") ) return ret def typecast_bytea(value, cur): if value is not None: data = psycopg2.BINARY(value, cur) return data.tobytes() BaseDbType = TypeVar("BaseDbType", bound="BaseDb") class BaseDb: """Base class for swh.*.*Db. cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb """ @staticmethod def adapt_conn(conn: psycopg2.extensions.connection): """Makes psycopg2 use 'bytes' to decode bytea instead of 'memoryview', for this connection.""" t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea) psycopg2.extensions.register_type(t_bytes, conn) t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes) psycopg2.extensions.register_type(t_bytes_array, conn) @classmethod def connect(cls: Type[BaseDbType], *args, **kwargs) -> BaseDbType: """factory method to create a DB proxy Accepts all arguments of psycopg2.connect; only some specific possibilities are reported below. 
Args: connstring: libpq2 connection string """ conn = psycopg2.connect(*args, **kwargs) return cls(conn) @classmethod def from_pool( cls: Type[BaseDbType], pool: psycopg2.pool.AbstractConnectionPool ) -> BaseDbType: conn = pool.getconn() return cls(conn, pool=pool) def __init__( self, conn: psycopg2.extensions.connection, pool: Optional[psycopg2.pool.AbstractConnectionPool] = None, ): """create a DB proxy Args: conn: psycopg2 connection to the SWH DB pool: psycopg2 pool of connections """ self.adapt_conn(conn) self.conn = conn self.pool = pool def put_conn(self) -> None: if self.pool: self.pool.putconn(self.conn) def cursor( self, cur_arg: Optional[psycopg2.extensions.cursor] = None ) -> psycopg2.extensions.cursor: """get a cursor: from cur_arg if given, or a fresh one otherwise meant to avoid boilerplate if/then/else in methods that proxy stored procedures """ if cur_arg is not None: return cur_arg else: return self.conn.cursor() _cursor = cursor # for bw compat @contextmanager def transaction(self) -> Iterator[psycopg2.extensions.cursor]: """context manager to execute within a DB transaction Yields: a psycopg2 cursor """ with self.conn.cursor() as cur: try: yield cur self.conn.commit() except Exception: if not self.conn.closed: self.conn.rollback() raise def copy_to( self, items: Iterable[Mapping[str, Any]], tblname: str, columns: Iterable[str], cur: Optional[psycopg2.extensions.cursor] = None, item_cb: Optional[Callable[[Any], Any]] = None, default_values: Optional[Mapping[str, Any]] = None, ) -> None: """Run the COPY command to insert the `columns` of each element of `items` into `tblname`. Args: items: dictionaries of data to copy into `tblname`. tblname: name of the destination table. columns: columns of the destination table. Elements of `items` must have these set as keys. default_values: dictionary of default values to use when inserting entries in `tblname`. cur: a db cursor; if not given, a new cursor will be created. item_cb: optional callback, run on each element of `items`, when it is copied. """ if default_values is None: default_values = {} read_file, write_file = os.pipe() exc_info = None def writer(): nonlocal exc_info cursor = self.cursor(cur) with open(read_file, "r") as f: try: cursor.copy_expert( "COPY %s (%s) FROM STDIN" % (tblname, ", ".join(columns)), f ) except Exception: # Tell the main thread about the exception exc_info = sys.exc_info() write_thread = threading.Thread(target=writer) write_thread.start() try: with open(write_file, "w") as f: # From https://www.postgresql.org/docs/11/sql-copy.html # File Formats > Text Format # "When the text format is used, the data read or written is a text file # with one line per table row. Columns in a row are separated by the # delimiter character." # NULL # "The default is \N (backslash-N) in text format." # DELIMITER # "The default is a tab character in text format." for d in items: if item_cb is not None: item_cb(d) line = [] for k in columns: value = d.get(k, default_values.get(k)) try: if value is None: line.append("\\N") else: line.append(escape_copy_column(value_as_pg_text(value))) except Exception as e: logger.error( "Could not escape value `%r` for column `%s`:" "Received exception: `%s`", value, k, e, ) raise e from None f.write("\t".join(line)) f.write("\n") finally: # No problem bubbling up exceptions, but we still need to make sure # we finish copying, even though we're probably going to cancel the # transaction. write_thread.join() if exc_info: # postgresql returned an error, let's raise it. 
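                # re-raising with with_traceback preserves the original
                # traceback from the writer thread, rather than pointing
                # at this re-raise site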
raise exc_info[1].with_traceback(exc_info[2]) def mktemp(self, tblname: str, cur: Optional[psycopg2.extensions.cursor] = None): self.cursor(cur).execute("SELECT swh_mktemp(%s)", (tblname,)) diff --git a/swh/core/db/common.py b/swh/core/db/common.py index 17c46be..65e6d75 100644 --- a/swh/core/db/common.py +++ b/swh/core/db/common.py @@ -1,103 +1,103 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import inspect import functools +import inspect def remove_kwargs(names): def decorator(f): sig = inspect.signature(f) params = sig.parameters params = [param for param in params.values() if param.name not in names] sig = sig.replace(parameters=params) f.__signature__ = sig return f return decorator def apply_options(cursor, options): """Applies the given postgresql client options to the given cursor. Returns a dictionary with the old values if they changed.""" old_options = {} for option, value in options.items(): cursor.execute("SHOW %s" % option) old_value = cursor.fetchall()[0][0] if old_value != value: cursor.execute("SET LOCAL %s TO %%s" % option, (value,)) old_options[option] = old_value return old_options def db_transaction(**client_options): """decorator to execute Backend methods within DB transactions The decorated method must accept a `cur` and `db` keyword argument Client options are passed as `set` options to the postgresql server """ def decorator(meth, __client_options=client_options): if inspect.isgeneratorfunction(meth): raise ValueError("Use db_transaction_generator for generator functions.") @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): if "cur" in kwargs and kwargs["cur"]: cur = kwargs["cur"] old_options = apply_options(cur, __client_options) ret = meth(self, *args, **kwargs) apply_options(cur, old_options) return ret else: db = self.get_db() try: with db.transaction() as cur: apply_options(cur, __client_options) return meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) return _meth return decorator def db_transaction_generator(**client_options): """decorator to execute Backend methods within DB transactions, while returning a generator The decorated method must accept a `cur` and `db` keyword argument Client options are passed as `set` options to the postgresql server """ def decorator(meth, __client_options=client_options): if not inspect.isgeneratorfunction(meth): raise ValueError("Use db_transaction for non-generator functions.") @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): if "cur" in kwargs and kwargs["cur"]: cur = kwargs["cur"] old_options = apply_options(cur, __client_options) yield from meth(self, *args, **kwargs) apply_options(cur, old_options) else: db = self.get_db() try: with db.transaction() as cur: apply_options(cur, __client_options) yield from meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) return _meth return decorator diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index aa97a93..d656a5d 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,151 +1,151 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # # This code has been imported from psycopg2, version 2.7.4, # 
https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. -import re import functools +import re import psycopg2.extensions def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] tokens = re.split(br"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. 
__: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode(psycopg2.extensions.encodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur diff --git a/swh/core/db/tests/db_testing.py b/swh/core/db/tests/db_testing.py index 89bd234..9f6c01b 100644 --- a/swh/core/db/tests/db_testing.py +++ b/swh/core/db/tests/db_testing.py @@ -1,344 +1,342 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import os import glob +import os import subprocess +from typing import Dict, Iterable, Optional, Tuple, Union import psycopg2 -from typing import Dict, Iterable, Optional, Tuple, Union - from swh.core.utils import numfile_sortkey as sortkey - DB_DUMP_TYPES = {".sql": "psql", ".dump": "pg_dump"} # type: Dict[str, str] def swh_db_version(dbname_or_service): """Retrieve the swh version if any. In case of the db not initialized, this returns None. Otherwise, this returns the db's version. Args: dbname_or_service (str): The db's name or service Returns: Optional[Int]: Either the db's version or None """ query = "select version from dbversion order by dbversion desc limit 1" cmd = [ "psql", "--tuples-only", "--no-psqlrc", "--quiet", "-v", "ON_ERROR_STOP=1", "--command=%s" % query, dbname_or_service, ] try: r = subprocess.run( cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, ) result = int(r.stdout.strip()) except Exception: # db not initialized result = None return result def pg_restore(dbname, dumpfile, dumptype="pg_dump"): """ Args: dbname: name of the DB to restore into dumpfile: path of the dump file dumptype: one of 'pg_dump' (for binary dumps), 'psql' (for SQL dumps) """ assert dumptype in ["pg_dump", "psql"] if dumptype == "pg_dump": subprocess.check_call( [ "pg_restore", "--no-owner", "--no-privileges", "--dbname", dbname, dumpfile, ] ) elif dumptype == "psql": subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-f", dumpfile, dbname, ] ) def pg_dump(dbname, dumpfile): subprocess.check_call( ["pg_dump", "--no-owner", "--no-privileges", "-Fc", "-f", dumpfile, dbname] ) def pg_dropdb(dbname): subprocess.check_call(["dropdb", dbname]) def pg_createdb(dbname, check=True): """Create a db. If check is True and the db already exists, this will raise an exception (original behavior). If check is False and the db already exists, this will fail silently. If the db does not exist, the db will be created. """ subprocess.run(["createdb", dbname], check=check) def db_create(dbname, dumps=None): """create the test DB and load the test data dumps into it dumps is an iterable of couples (dump_file, dump_type). 
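    e.g. [('schema.sql', 'psql'), ('data.dump', 'pg_dump')] (file names
    illustrative; the DB_DUMP_TYPES mapping above relates file extensions
    to dump types).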
context: setUpClass """ try: pg_createdb(dbname) except subprocess.CalledProcessError: # try recovering once, in case pg_dropdb(dbname) # the db already existed pg_createdb(dbname) for dump, dtype in dumps: pg_restore(dbname, dump, dtype) return dbname def db_destroy(dbname): """destroy the test DB context: tearDownClass """ pg_dropdb(dbname) def db_connect(dbname): """connect to the test DB and open a cursor context: setUp """ conn = psycopg2.connect("dbname=" + dbname) return {"conn": conn, "cursor": conn.cursor()} def db_close(conn): """rollback current transaction and disconnect from the test DB context: tearDown """ if not conn.closed: conn.rollback() conn.close() class DbTestConn: def __init__(self, dbname): self.dbname = dbname def __enter__(self): self.db_setup = db_connect(self.dbname) self.conn = self.db_setup["conn"] self.cursor = self.db_setup["cursor"] return self def __exit__(self, *_): db_close(self.conn) class DbTestContext: def __init__(self, name="softwareheritage-test", dumps=None): self.dbname = name self.dumps = dumps def __enter__(self): db_create(dbname=self.dbname, dumps=self.dumps) return self def __exit__(self, *_): db_destroy(self.dbname) class DbTestFixture: """Mix this in a test subject class to get DB testing support. Use the class method add_db() to add a new database to be tested. Using this will create a DbTestConn entry in the `test_db` dictionary for all the tests, indexed by the name of the database. Example: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): cls.add_db('db_name', DUMP) super().setUpClass() def setUp(self): db = self.test_db['db_name'] print('conn: {}, cursor: {}'.format(db.conn, db.cursor)) To ensure test isolation, each test method of the test case class will execute in its own connection, cursor, and transaction. Note that if you want to define setup/teardown methods, you need to explicitly call super() to ensure that the fixture setup/teardown methods are invoked. 
Here is an example where all setup/teardown methods are defined in a test case: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): # your add_db() calls here super().setUpClass() # your class setup code here def setUp(self): super().setUp() # your instance setup code here def tearDown(self): # your instance teardown code here super().tearDown() @classmethod def tearDownClass(cls): # your class teardown code here super().tearDownClass() """ _DB_DUMP_LIST = {} # type: Dict[str, Iterable[Tuple[str, str]]] _DB_LIST = {} # type: Dict[str, DbTestContext] DB_TEST_FIXTURE_IMPORTED = True @classmethod def add_db(cls, name="softwareheritage-test", dumps=None): cls._DB_DUMP_LIST[name] = dumps @classmethod def setUpClass(cls): for name, dumps in cls._DB_DUMP_LIST.items(): cls._DB_LIST[name] = DbTestContext(name, dumps) cls._DB_LIST[name].__enter__() super().setUpClass() @classmethod def tearDownClass(cls): super().tearDownClass() for name, context in cls._DB_LIST.items(): context.__exit__() def setUp(self, *args, **kwargs): self.test_db = {} for name in self._DB_LIST.keys(): self.test_db[name] = DbTestConn(name) self.test_db[name].__enter__() super().setUp(*args, **kwargs) def tearDown(self): super().tearDown() for name in self._DB_LIST.keys(): self.test_db[name].__exit__() def reset_db_tables(self, name, excluded=None): db = self.test_db[name] conn = db.conn cursor = db.cursor cursor.execute( """SELECT table_name FROM information_schema.tables WHERE table_schema = %s""", ("public",), ) tables = set(table for (table,) in cursor.fetchall()) if excluded is not None: tables -= set(excluded) for table in tables: cursor.execute("truncate table %s cascade" % table) conn.commit() class SingleDbTestFixture(DbTestFixture): """Simplified fixture like DbTest but that can only handle a single DB. Gives access to shortcuts like self.cursor and self.conn. DO NOT use this with other fixtures that need to access databases, like StorageTestFixture. The class can override the following class attributes: TEST_DB_NAME: name of the DB used for testing TEST_DB_DUMP: DB dump to be restored before running test methods; can be set to None if no restore from dump is required. If the dump file name endswith" - '.sql' it will be loaded via psql, - '.dump' it will be loaded via pg_restore. Other file extensions will be ignored. Can be a string or a list of strings; each path will be expanded using glob pattern matching. The test case class will then have the following attributes, accessible via self: dbname: name of the test database conn: psycopg2 connection object cursor: open psycopg2 cursor to the DB """ TEST_DB_NAME = "softwareheritage-test" TEST_DB_DUMP = None # type: Optional[Union[str, Iterable[str]]] @classmethod def setUpClass(cls): cls.dbname = cls.TEST_DB_NAME # XXX to kill? 
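        # TEST_DB_DUMP may be a single glob pattern or an iterable of
        # patterns; each one is expanded below and the matching files are
        # sorted numerically before being registered as dumps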
dump_files = cls.TEST_DB_DUMP if dump_files is None: dump_files = [] elif isinstance(dump_files, str): dump_files = [dump_files] all_dump_files = [] for files in dump_files: all_dump_files.extend(sorted(glob.glob(files), key=sortkey)) all_dump_files = [ (x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) for x in all_dump_files ] cls.add_db(name=cls.TEST_DB_NAME, dumps=all_dump_files) super().setUpClass() def setUp(self, *args, **kwargs): super().setUp(*args, **kwargs) db = self.test_db[self.TEST_DB_NAME] self.conn = db.conn self.cursor = db.cursor diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py index d815271..87b40eb 100644 --- a/swh/core/db/tests/test_cli.py +++ b/swh/core/db/tests/test_cli.py @@ -1,57 +1,56 @@ # from click.testing import CliRunner from swh.core.cli.db import db as swhdb - help_msg = """Usage: swh [OPTIONS] COMMAND [ARGS]... Command line interface for Software Heritage. Options: -l, --log-level [NOTSET|DEBUG|INFO|WARNING|ERROR|CRITICAL] Log level (defaults to INFO). --log-config FILENAME Python yaml logging configuration file. --sentry-dsn TEXT DSN of the Sentry instance to report to -h, --help Show this message and exit. Notes: If both options are present, --log-level will override the root logger configuration set in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. Commands: db Software Heritage database generic tools. """ def test_swh_help(swhmain): swhmain.add_command(swhdb) runner = CliRunner() result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert result.output == help_msg help_db_msg = """Usage: swh db [OPTIONS] COMMAND [ARGS]... Software Heritage database generic tools. Options: -C, --config-file FILE Configuration file. -h, --help Show this message and exit. Commands: init Initialize the database for every Software Heritage module found in... """ def test_swh_db_help(swhmain): swhmain.add_command(swhdb) runner = CliRunner() result = runner.invoke(swhmain, ["db", "-h"]) assert result.exit_code == 0 assert result.output == help_db_msg diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py index 93385d4..bb7f80b 100644 --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -1,438 +1,434 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass import datetime from enum import IntEnum import inspect import os.path from string import printable import tempfile from typing import Any -from typing_extensions import Protocol import unittest -from unittest.mock import Mock, MagicMock +from unittest.mock import MagicMock, Mock import uuid -from hypothesis import strategies, given +from hypothesis import given, strategies from hypothesis.extra.pytz import timezones import psycopg2 import pytest +from typing_extensions import Protocol from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator -from .db_testing import ( - SingleDbTestFixture, - db_create, - db_destroy, - db_close, -) + +from .db_testing import SingleDbTestFixture, db_close, db_create, db_destroy # workaround mypy bug https://github.com/python/mypy/issues/5485 class Converter(Protocol): def __call__(self, x: Any) -> Any: ... 
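# Any one-argument callable (a lambda, int, TestIntEnum, ...) matches the
# Converter protocol structurally; it only types the in_wrapper and
# out_converter fields of Field below.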
@dataclass class Field: name: str """Column name""" pg_type: str """Type of the PostgreSQL column""" example: Any """Example value for the static tests""" strategy: strategies.SearchStrategy """Hypothesis strategy to generate these values""" in_wrapper: Converter = lambda x: x """Wrapper to convert this data type for the static tests""" out_converter: Converter = lambda x: x """Converter from the raw PostgreSQL column value to this data type""" # Limit PostgreSQL integer values pg_int = strategies.integers(-2147483648, +2147483647) pg_text = strategies.text( alphabet=strategies.characters( blacklist_categories=["Cs"], # surrogates blacklist_characters=[ "\x00", # pgsql does not support the null codepoint "\r", # pgsql normalizes those ], ), ) pg_bytea = strategies.binary() def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[]""" return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size) def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[][]. The inner lists must all have the same size.""" return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap( lambda n: strategies.lists( pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size ) ) def pg_tstz() -> strategies.SearchStrategy: """Generate values that fit in a PostgreSQL timestamptz. Notes: We're forbidding old datetimes, because until 1956, many timezones had seconds in their "UTC offsets" (see ), which is not representable by PostgreSQL. """ min_value = datetime.datetime(1960, 1, 1, 0, 0, 0) return strategies.datetimes(min_value=min_value, timezones=timezones()) def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate values representable as a PostgreSQL jsonb object (dict).""" return strategies.dictionaries( strategies.text(printable), strategies.recursive( # should use floats() instead of integers(), but PostgreSQL # coerces large integers into floats, making the tests fail. We # only store ints in our generated data anyway. 
strategies.none() | strategies.booleans() | strategies.integers(-2147483648, +2147483647) | strategies.text(printable), lambda children: strategies.lists(children, max_size=max_size) | strategies.dictionaries( strategies.text(printable), children, max_size=max_size ), ), min_size=min_size, max_size=max_size, ) def tuple_2d_to_list_2d(v): """Convert a 2D tuple to a 2D list""" return [list(inner) for inner in v] def list_2d_to_tuple_2d(v): """Convert a 2D list to a 2D tuple""" return tuple(tuple(inner) for inner in v) class TestIntEnum(IntEnum): foo = 1 bar = 2 def now(): return datetime.datetime.now(tz=datetime.timezone.utc) FIELDS = ( Field("i", "int", 1, pg_int), Field("txt", "text", "foo", pg_text), Field("bytes", "bytea", b"bar", strategies.binary()), Field( "bytes_array", "bytea[]", [b"baz1", b"baz2"], pg_bytea_a(min_size=0, max_size=5), ), Field( "bytes_tuple", "bytea[]", (b"baz1", b"baz2"), pg_bytea_a(min_size=0, max_size=5).map(tuple), in_wrapper=list, out_converter=tuple, ), Field( "bytes_2d", "bytea[][]", [[b"quux1"], [b"quux2"]], pg_bytea_a_a(min_size=0, max_size=5), ), Field( "bytes_2d_tuple", "bytea[][]", ((b"quux1",), (b"quux2",)), pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d), in_wrapper=tuple_2d_to_list_2d, out_converter=list_2d_to_tuple_2d, ), Field("ts", "timestamptz", now(), pg_tstz(),), Field( "dict", "jsonb", {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}}, pg_jsonb(min_size=0, max_size=5), in_wrapper=psycopg2.extras.Json, ), Field( "intenum", "int", TestIntEnum.foo, strategies.sampled_from(TestIntEnum), in_wrapper=int, out_converter=TestIntEnum, ), Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()), Field( "text_list", "text[]", # All the funky corner cases ["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"], strategies.lists(pg_text, min_size=0, max_size=5), ), Field( "tstz_list", "timestamptz[]", [now(), now() + datetime.timedelta(days=1)], strategies.lists(pg_tstz(), min_size=0, max_size=5), ), Field( "tstz_range", "tstzrange", psycopg2.extras.DateTimeTZRange( lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)", ), strategies.tuples( # generate two sorted timestamptzs for use as bounds strategies.tuples(pg_tstz(), pg_tstz()).map(sorted), # and a set of bounds strategies.sampled_from(["[]", "()", "[)", "(]"]), ).map( # and build the actual DateTimeTZRange object from these args lambda args: psycopg2.extras.DateTimeTZRange( lower=args[0][0], upper=args[0][1], bounds=args[1], ) ), ), ) INIT_SQL = "create table test_table (%s)" % ", ".join( f"{field.name} {field.pg_type}" for field in FIELDS ) COLUMNS = tuple(field.name for field in FIELDS) INSERT_SQL = "insert into test_table (%s) values (%s)" % ( ", ".join(COLUMNS), ", ".join("%s" for i in range(len(COLUMNS))), ) STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS) EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS) db_rows = strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS))) def convert_lines(cur): return [ tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur ] @pytest.mark.db def test_connect(): db_name = db_create("test-db2", dumps=[]) try: db = BaseDb.connect("dbname=%s" % db_name) with db.cursor() as cur: psycopg2.extras.register_default_jsonb(cur) cur.execute(INIT_SQL) cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] finally: 
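        # whatever happened above, release the connection and drop the
        # scratch database so later runs start from a clean slate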
db_close(db.conn) db_destroy(db_name) @pytest.mark.db class TestDb(SingleDbTestFixture, unittest.TestCase): TEST_DB_NAME = "test-db" @classmethod def setUpClass(cls): with tempfile.TemporaryDirectory() as td: with open(os.path.join(td, "init.sql"), "a") as fd: fd.write(INIT_SQL) cls.TEST_DB_DUMP = os.path.join(td, "*.sql") super().setUpClass() def setUp(self): super().setUp() self.db = BaseDb(self.conn) def test_initialized(self): cur = self.db.cursor() psycopg2.extras.register_default_jsonb(cur) cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] def test_reset_tables(self): cur = self.db.cursor() cur.execute(INSERT_SQL, STATIC_ROW_IN) self.reset_db_tables("test-db") cur.execute("select * from test_table;") assert convert_lines(cur) == [] def test_copy_to_static(self): items = [{field.name: field.example for field in FIELDS}] self.db.copy_to(items, "test_table", COLUMNS) cur = self.db.cursor() cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] @given(db_rows) def test_copy_to(self, data): try: # the table is not reset between runs by hypothesis self.reset_db_tables("test-db") items = [dict(zip(COLUMNS, item)) for item in data] self.db.copy_to(items, "test_table", COLUMNS) cur = self.db.cursor() cur.execute("select * from test_table;") assert convert_lines(cur) == data finally: self.db.conn.rollback() def test_copy_to_thread_exception(self): data = [(2 ** 65, "foo", b"bar")] items = [dict(zip(COLUMNS, item)) for item in data] with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): self.db.copy_to(items, "test_table", COLUMNS) def test_db_transaction(mocker): expected_cur = object() called = False class Storage: @db_transaction() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) storage.endpoint() assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction__with_generator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction() def endpoint(self, cur=None, db=None): yield None def test_db_transaction_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass expected_sig = inspect.signature(f) @db_transaction() def g(self, foo, *, bar=None, db=None, cur=None): pass actual_sig = inspect.signature(g) assert actual_sig == expected_sig def test_db_transaction_generator(mocker): expected_cur = object() called = False class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur yield None storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) list(storage.endpoint()) 
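    # consuming the generator with list() is what actually runs the
    # decorated body; merely calling endpoint() only creates the generator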
assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction_generator__with_nongenerator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): pass def test_db_transaction_generator_signature(): """Checks db_transaction_generator removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass expected_sig = inspect.signature(f) @db_transaction_generator() def g(self, foo, *, bar=None, db=None, cur=None): yield None actual_sig = inspect.signature(g) assert actual_sig == expected_sig diff --git a/swh/core/logger.py b/swh/core/logger.py index f0163a6..1c88883 100644 --- a/swh/core/logger.py +++ b/swh/core/logger.py @@ -1,117 +1,118 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import logging from typing import Any, Generator, List, Tuple -from systemd.journal import JournalHandler as _JournalHandler, send +from systemd.journal import JournalHandler as _JournalHandler +from systemd.journal import send try: from celery import current_task except ImportError: current_task = None EXTRA_LOGDATA_PREFIX = "swh_" def db_level_of_py_level(lvl): """convert a log level of the logging module to a log level suitable for the logging Postgres DB """ return logging.getLevelName(lvl).lower() def get_extra_data(record, task_args=True): """Get the extra data to insert into the database from the logging record""" log_data = record.__dict__ extra_data = { k[len(EXTRA_LOGDATA_PREFIX) :]: v for k, v in log_data.items() if k.startswith(EXTRA_LOGDATA_PREFIX) } args = log_data.get("args") if args: extra_data["logging_args"] = args # Retrieve Celery task info if current_task and current_task.request: extra_data["task"] = { "id": current_task.request.id, "name": current_task.name, } if task_args: extra_data["task"].update( { "kwargs": current_task.request.kwargs, "args": current_task.request.args, } ) return extra_data def flatten(data: Any, separator: str = "_") -> Generator[Tuple[str, Any], None, None]: """Flatten the data dictionary into a flat structure""" def inner_flatten( data: Any, prefix: List[str] ) -> Generator[Tuple[List[str], Any], None, None]: if isinstance(data, dict): if all(isinstance(key, str) for key in data): for key, value in data.items(): yield from inner_flatten(value, prefix + [key]) else: yield prefix, str(data) elif isinstance(data, (list, tuple)): for key, value in enumerate(data): yield from inner_flatten(value, prefix + [str(key)]) else: yield prefix, data for path, value in inner_flatten(data, []): yield separator.join(path), value def stringify(value): """Convert value to string""" if isinstance(value, datetime.datetime): return value.isoformat() return str(value) class JournalHandler(_JournalHandler): def emit(self, record): """Write `record` as a journal event. MESSAGE is taken from the message provided by the user, and PRIORITY, LOGGER, THREAD_NAME, CODE_{FILE,LINE,FUNC} fields are appended automatically. In addition, record.MESSAGE_ID will be used if present.
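For example (an illustrative sketch; the logger name and the extra field are made up, not part of this diff), a record attribute prefixed with ``swh_`` ends up as an uppercased journal field::

    import logging

    log = logging.getLogger("swh.example")
    log.addHandler(JournalHandler())
    # forwarded to the journal as SWH_TASK_ID="42", next to MESSAGE, PRIORITY, etc.
    log.info("hello", extra={"swh_task_id": 42})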
""" try: extra_data = flatten(get_extra_data(record, task_args=False)) extra_data = { (EXTRA_LOGDATA_PREFIX + key).upper(): stringify(value) for key, value in extra_data } msg = self.format(record) pri = self.mapPriority(record.levelno) send( msg, PRIORITY=format(pri), LOGGER=record.name, THREAD_NAME=record.threadName, CODE_FILE=record.pathname, CODE_LINE=record.lineno, CODE_FUNC=record.funcName, **extra_data, ) except Exception: self.handleError(record) diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py index 233d554..7a45019 100644 --- a/swh/core/pytest_plugin.py +++ b/swh/core/pytest_plugin.py @@ -1,319 +1,317 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import logging -import re -import pytest -import requests - from functools import partial +import logging from os import path +import re from typing import Dict, List, Optional -from urllib.parse import urlparse, unquote +from urllib.parse import unquote, urlparse +import pytest +import requests from requests.adapters import BaseAdapter from requests.structures import CaseInsensitiveDict from requests.utils import get_encoding_from_headers - logger = logging.getLogger(__name__) # Check get_local_factory function # Maximum number of iteration checks to generate requests responses MAX_VISIT_FILES = 10 def get_response_cb( request: requests.Request, context, datadir, ignore_urls: List[str] = [], visits: Optional[Dict] = None, ): """Mount point callback to fetch on disk the request's content. The request urls provided are url decoded first to resolve the associated file on disk. This is meant to be used as 'body' argument of the requests_mock.get() method. It will look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Eg. if you use the requests_mock fixture in your test file as: requests_mock.get('https?://nowhere.com', body=get_response_cb) # or even requests_mock.get(re.compile('https?://'), body=get_response_cb) then a call requests.get like: requests.get('https://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/https_nowhere.com/path_to_resource,a=b,c=d or a call requests.get like: requests.get('http://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/http_nowhere.com/path_to_resource,a=b,c=d Args: request: Object requests context (requests.Context): Object holding response metadata information (status_code, headers, etc...) datadir: Data files path ignore_urls: urls whose status response should be 404 even if the local file exists visits: Dict of url, number of visits. 
If None, disable multi visit support (default) Returns: Optional[FileDescriptor] of the on-disk file to read from within the test context """ logger.debug("get_response_cb(%s, %s)", request, context) logger.debug("url: %s", request.url) logger.debug("ignore_urls: %s", ignore_urls) unquoted_url = unquote(request.url) if unquoted_url in ignore_urls: context.status_code = 404 return None url = urlparse(unquoted_url) # http://pypi.org ~> http_pypi.org # https://files.pythonhosted.org ~> https_files.pythonhosted.org dirname = "%s_%s" % (url.scheme, url.hostname) # url.path: pypi//json -> local file: pypi__json filename = url.path[1:] if filename.endswith("/"): filename = filename[:-1] filename = filename.replace("/", "_") if url.query: filename += "," + url.query.replace("&", ",") filepath = path.join(datadir, dirname, filename) if visits is not None: visit = visits.get(url, 0) visits[url] = visit + 1 if visit: filepath = filepath + "_visit%s" % visit if not path.isfile(filepath): logger.debug("not found filepath: %s", filepath) context.status_code = 404 return None fd = open(filepath, "rb") context.headers["content-length"] = str(path.getsize(filepath)) return fd @pytest.fixture def datadir(request): """By default, returns the test directory's data directory. This can be overridden on a per file tree basis. Add an override definition in the local conftest, for example:: import pytest from os import path @pytest.fixture def datadir(): return path.join(path.abspath(path.dirname(__file__)), 'resources') """ return path.join(path.dirname(str(request.fspath)), "data") def requests_mock_datadir_factory( ignore_urls: List[str] = [], has_multi_visit: bool = False ): """This factory generates a fixture which allows looking for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Multiple implementations are possible, for example: - requests_mock_datadir_factory([]): This computes the file name from the query and always returns the same result. - requests_mock_datadir_factory(has_multi_visit=True): This computes the file name from the query and returns the content of the file the first time, the next calls returning the content of files suffixed with _visit1, and so on. If the file is not found, returns a 404. - requests_mock_datadir_factory(ignore_urls=['url1', 'url2']): This will ignore any files corresponding to url1 and url2, always returning 404. Args: ignore_urls: List of urls that always return 404 (whether the file exists or not) has_multi_visit: Whether to activate the multiple visits behavior """ @pytest.fixture def requests_mock_datadir(requests_mock, datadir): if not has_multi_visit: cb = partial(get_response_cb, ignore_urls=ignore_urls, datadir=datadir) requests_mock.get(re.compile("https?://"), body=cb) else: visits = {} requests_mock.get( re.compile("https?://"), body=partial( get_response_cb, ignore_urls=ignore_urls, visits=visits, datadir=datadir, ), ) return requests_mock return requests_mock_datadir # Default `requests_mock_datadir` implementation requests_mock_datadir = requests_mock_datadir_factory([]) # Implementation for multiple visits behavior: # - first time, it checks for a file named `filename` # - second time, it checks for a file named `filename`_visit1 # etc...
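# For illustration, a hedged sketch of a test using the
# `requests_mock_datadir_visits` fixture defined just below (the URL and the
# backing data files are assumptions, not part of this diff):
#
#     def test_visits(requests_mock_datadir_visits):
#         # served from datadir/https_example.com/file.json
#         first = requests.get("https://example.com/file.json")
#         # served from datadir/https_example.com/file.json_visit1
#         second = requests.get("https://example.com/file.json")
#         assert first.ok and second.ok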
requests_mock_datadir_visits = requests_mock_datadir_factory(has_multi_visit=True) @pytest.fixture def swh_rpc_client(swh_rpc_client_class, swh_rpc_adapter): """This fixture generates an RPCClient instance that uses the class generated by the swh_rpc_client_class fixture as backend. Since it uses the swh_rpc_adapter, HTTP queries will be intercepted and routed directly to the current Flask app (as provided by the `app` fixture). So this stack of fixtures allows testing the RPCClient -> RPCServerApp communication path using a real RPCClient instance and a real Flask (RPCServerApp) app instance. To use this fixture: - ensure an `app` fixture exists and generates a Flask application, - implement a `swh_rpc_client_class` fixture that returns the RPCClient-based class to use as client side for the tests, - implement your tests using this `swh_rpc_client` fixture. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ url = "mock://example.com" cli = swh_rpc_client_class(url=url) # we need to clear the list of existing adapters here so we ensure we # have one and only one adapter which is then used for all the requests. cli.session.adapters.clear() cli.session.mount("mock://", swh_rpc_adapter) return cli @pytest.fixture def swh_rpc_adapter(app): """Fixture that generates a requests.Adapter instance that can be used to test client/server code based on swh.core.api classes. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ with app.test_client() as client: yield RPCTestAdapter(client) class RPCTestAdapter(BaseAdapter): def __init__(self, client): self._client = client def build_response(self, req, resp): response = requests.Response() # Fallback to None if there's no status_code, for whatever reason. response.status_code = resp.status_code # Make headers case-insensitive. response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) # Set encoding. response.encoding = get_encoding_from_headers(response.headers) response.raw = resp response.reason = response.raw.status if isinstance(req.url, bytes): response.url = req.url.decode("utf-8") else: response.url = req.url # Give the Response some context. response.request = req response.connection = self response._content = resp.data return response def send(self, request, **kw): """ Overrides ``requests.adapters.BaseAdapter.send`` """ resp = self._client.open( request.url, method=request.method, headers=request.headers.items(), data=request.body, ) return self.build_response(request, resp) @pytest.fixture def flask_app_client(app): with app.test_client() as client: yield client # stolen from pytest-flask, required to have url_for() working within tests # using flask_app_client fixture. @pytest.fixture(autouse=True) def _push_request_context(request): """During test execution a request context is pushed, so that e.g. `url_for`, `session`, etc.
can be used in tests as is:: def test_app(app, client): assert client.get(url_for('myview')).status_code == 200 """ if "app" not in request.fixturenames: return app = request.getfixturevalue("app") ctx = app.test_request_context() ctx.push() def teardown(): ctx.pop() request.addfinalizer(teardown) diff --git a/swh/core/sentry.py b/swh/core/sentry.py index 2af66a6..75e3e39 100644 --- a/swh/core/sentry.py +++ b/swh/core/sentry.py @@ -1,35 +1,36 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import pkg_resources import os +import pkg_resources + def get_sentry_release(): main_package = os.environ.get("SWH_MAIN_PACKAGE") if main_package: version = pkg_resources.get_distribution(main_package).version return f"{main_package}@{version}" else: return None def init_sentry(sentry_dsn, *, debug=None, integrations=[], extra_kwargs={}): if debug is None: debug = bool(os.environ.get("SWH_SENTRY_DEBUG")) sentry_dsn = sentry_dsn or os.environ.get("SWH_SENTRY_DSN") environment = os.environ.get("SWH_SENTRY_ENVIRONMENT") if sentry_dsn: import sentry_sdk sentry_sdk.init( release=get_sentry_release(), environment=environment, dsn=sentry_dsn, integrations=integrations, debug=debug, **extra_kwargs, ) diff --git a/swh/core/statsd.py b/swh/core/statsd.py index c212e71..1747328 100644 --- a/swh/core/statsd.py +++ b/swh/core/statsd.py @@ -1,441 +1,440 @@ # Copyright (C) 2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # Initially imported from https://github.com/DataDog/datadogpy/ # at revision 62b3a3e89988dc18d78c282fe3ff5d1813917436 # # Copyright (c) 2015, Datadog # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of Datadog nor the names of its contributors may be # used to endorse or promote products derived from this software without # specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# # # Vastly adapted for integration in swh.core: # # - Removed python < 3.5 compat code # - trimmed the imports down to be a single module # - adjust some options: # - drop unix socket connection option # - add environment variable support for setting the statsd host and # port (pulled the idea from the main python statsd module) # - only send timer metrics in milliseconds (that's what # prometheus-statsd-exporter expects) # - drop DataDog-specific metric types (that are unsupported in # prometheus-statsd-exporter) # - made the tags a dict instead of a list (prometheus-statsd-exporter only # supports tags with a value, mirroring prometheus) # - switch from time.time to time.monotonic # - improve unit test coverage # - documentation cleanup from asyncio import iscoroutinefunction from functools import wraps -from random import random -from time import monotonic import itertools import logging import os +from random import random import socket import threading +from time import monotonic import warnings - log = logging.getLogger("swh.core.statsd") class TimedContextManagerDecorator(object): """ A context manager and a decorator which will report the elapsed time in the context OR in a function call. Attributes: elapsed (float): the elapsed time at the point of completion """ def __init__( self, statsd, metric=None, error_metric=None, tags=None, sample_rate=1 ): self.statsd = statsd self.metric = metric self.error_metric = error_metric self.tags = tags self.sample_rate = sample_rate self.elapsed = None # this is for testing purposes def __call__(self, func): """ Decorator which returns the elapsed time of the function call. Defaults to the function name if no metric was provided. """ if not self.metric: self.metric = "%s.%s" % (func.__module__, func.__name__) # Coroutines if iscoroutinefunction(func): @wraps(func) async def wrapped_co(*args, **kwargs): start = monotonic() try: result = await func(*args, **kwargs) except: # noqa self._send_error() raise self._send(start) return result return wrapped_co # Others @wraps(func) def wrapped(*args, **kwargs): start = monotonic() try: result = func(*args, **kwargs) except: # noqa self._send_error() raise self._send(start) return result return wrapped def __enter__(self): if not self.metric: raise TypeError("Cannot use timed without a metric!") self._start = monotonic() return self def __exit__(self, type, value, traceback): # Report the elapsed time of the context manager if no error. if type is None: self._send(self._start) else: self._send_error() def _send(self, start): elapsed = (monotonic() - start) * 1000 self.statsd.timing( self.metric, elapsed, tags=self.tags, sample_rate=self.sample_rate ) self.elapsed = elapsed def _send_error(self): if self.error_metric is None: self.error_metric = self.metric + "_error_count" self.statsd.increment(self.error_metric, tags=self.tags) def start(self): """Start the timer""" self.__enter__() def stop(self): """Stop the timer, send the metric value""" self.__exit__(None, None, None) class Statsd(object): """Initialize a client to send metrics to a StatsD server. Arguments: host (str): the host of the StatsD server. Defaults to localhost. port (int): the port of the StatsD server. Defaults to 8125.
max_buffer_size (int): Maximum number of metrics to buffer before sending to the server if sending metrics in batch namespace (str): Namespace to prefix all metric names constant_tags (Dict[str, str]): Tags to attach to all metrics Note: This class also supports the following environment variables: STATSD_HOST Override the default host of the statsd server STATSD_PORT Override the default port of the statsd server STATSD_TAGS Tags to attach to every metric reported. Example value: "label:value,other_label:other_value" """ def __init__( self, host=None, port=None, max_buffer_size=50, namespace=None, constant_tags=None, ): # Connection if host is None: host = os.environ.get("STATSD_HOST") or "localhost" self.host = host if port is None: port = os.environ.get("STATSD_PORT") or 8125 self.port = int(port) # Socket self._socket = None self.lock = threading.Lock() self.max_buffer_size = max_buffer_size self._send = self._send_to_server self.encoding = "utf-8" # Tags self.constant_tags = {} tags_envvar = os.environ.get("STATSD_TAGS", "") for tag in tags_envvar.split(","): if not tag: continue if ":" not in tag: warnings.warn( "STATSD_TAGS needs to be in key:value format, " "%s invalid" % tag, UserWarning, ) continue k, v = tag.split(":", 1) self.constant_tags[k] = v if constant_tags: self.constant_tags.update( {str(k): str(v) for k, v in constant_tags.items()} ) # Namespace if namespace is not None: namespace = str(namespace) self.namespace = namespace def __enter__(self): self.open_buffer(self.max_buffer_size) return self def __exit__(self, type, value, traceback): self.close_buffer() def gauge(self, metric, value, tags=None, sample_rate=1): """ Record the value of a gauge, optionally setting a list of tags and a sample rate. >>> statsd.gauge('users.online', 123) >>> statsd.gauge('active.connections', 1001, tags={"protocol": "http"}) """ return self._report(metric, "g", value, tags, sample_rate) def increment(self, metric, value=1, tags=None, sample_rate=1): """ Increment a counter, optionally setting a value, tags and a sample rate. >>> statsd.increment('page.views') >>> statsd.increment('files.transferred', 124) """ self._report(metric, "c", value, tags, sample_rate) def decrement(self, metric, value=1, tags=None, sample_rate=1): """ Decrement a counter, optionally setting a value, tags and a sample rate. >>> statsd.decrement('files.remaining') >>> statsd.decrement('active.connections', 2) """ metric_value = -value if value else value self._report(metric, "c", metric_value, tags, sample_rate) def histogram(self, metric, value, tags=None, sample_rate=1): """ Sample a histogram value, optionally setting tags and a sample rate. >>> statsd.histogram('uploaded.file.size', 1445) >>> statsd.histogram('file.count', 26, tags={"filetype": "python"}) """ self._report(metric, "h", value, tags, sample_rate) def timing(self, metric, value, tags=None, sample_rate=1): """ Record a timing, optionally setting tags and a sample rate. >>> statsd.timing("query.response.time", 1234) """ self._report(metric, "ms", value, tags, sample_rate) def timed(self, metric=None, error_metric=None, tags=None, sample_rate=1): """ A decorator or context manager that will measure the distribution of a function's/context's run time. Optionally specify a list of tags or a sample rate. If the metric is not defined as a decorator, the module name and function name will be used. The metric is required as a context manager. :: @statsd.timed('user.query.time', sample_rate=0.5) def get_user(user_id): # Do what you need to ... 
pass # Is equivalent to ... with statsd.timed('user.query.time', sample_rate=0.5): # Do what you need to ... pass # Is equivalent to ... start = time.monotonic() try: get_user(user_id) finally: statsd.timing('user.query.time', time.monotonic() - start) """ return TimedContextManagerDecorator( statsd=self, metric=metric, error_metric=error_metric, tags=tags, sample_rate=sample_rate, ) def set(self, metric, value, tags=None, sample_rate=1): """ Sample a set value. >>> statsd.set('visitors.uniques', 999) """ self._report(metric, "s", value, tags, sample_rate) @property def socket(self): """ Return a connected socket. Note: connect the socket before assigning it to the class instance to avoid bad thread race conditions. """ with self.lock: if not self._socket: sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.connect((self.host, self.port)) self._socket = sock return self._socket def open_buffer(self, max_buffer_size=50): """ Open a buffer to send a batch of metrics in one packet. You can also use this as a context manager. >>> with Statsd() as batch: ... batch.gauge('users.online', 123) ... batch.gauge('active.connections', 1001) """ self.max_buffer_size = max_buffer_size self.buffer = [] self._send = self._send_to_buffer def close_buffer(self): """ Flush the buffer and switch back to single metric packets. """ self._send = self._send_to_server if self.buffer: # Only send packets if there are packets to send self._flush_buffer() def close_socket(self): """ Closes connected socket if connected. """ with self.lock: if self._socket: self._socket.close() self._socket = None def _report(self, metric, metric_type, value, tags, sample_rate): """ Create a metric packet and send it. """ if value is None: return if sample_rate != 1 and random() > sample_rate: return # Resolve the full tag list tags = self._add_constant_tags(tags) # Create/format the metric packet payload = "%s%s:%s|%s%s%s" % ( (self.namespace + ".") if self.namespace else "", metric, value, metric_type, ("|@" + str(sample_rate)) if sample_rate != 1 else "", ("|#" + ",".join("%s:%s" % (k, v) for (k, v) in sorted(tags.items()))) if tags else "", ) # Send it self._send(payload) def _send_to_server(self, packet): try: # If set, use socket directly self.socket.send(packet.encode("utf-8")) except socket.timeout: return except socket.error: log.debug( "Error submitting statsd packet." " Dropping the packet and closing the socket." ) self.close_socket() def _send_to_buffer(self, packet): self.buffer.append(packet) if len(self.buffer) >= self.max_buffer_size: self._flush_buffer() def _flush_buffer(self): self._send_to_server("\n".join(self.buffer)) self.buffer = [] def _add_constant_tags(self, tags): return { str(k): str(v) for k, v in itertools.chain( self.constant_tags.items(), (tags if tags else {}).items(), ) } statsd = Statsd() diff --git a/swh/core/tarball.py b/swh/core/tarball.py index d557efd..3f3c29a 100644 --- a/swh/core/tarball.py +++ b/swh/core/tarball.py @@ -1,148 +1,147 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import stat +from subprocess import run import tarfile import zipfile -from subprocess import run - from . import utils def _unpack_tar(tarpath: str, extract_dir: str) -> str: """Unpack tarballs unsupported by the standard python library. 
Examples include tar.Z, tar.lz, tar.x, etc.... As this implementation relies on the `tar` command, this function supports the same compressions as the tar command. This expects the `extract_dir` to exist. Raises shutil.ReadError in case of issue uncompressing the archive (tarpath does not exist, extract_dir does not exist, etc...) Returns full path to the uncompressed directory. """ try: run(["tar", "xf", tarpath, "-C", extract_dir], check=True) return extract_dir except Exception as e: raise shutil.ReadError( f"Unable to uncompress {tarpath} to {extract_dir}. Reason: {e}" ) def register_new_archive_formats(): """Register new archive formats to uncompress """ registered_formats = [f[0] for f in shutil.get_unpack_formats()] for name, extensions, function in ADDITIONAL_ARCHIVE_FORMATS: if name in registered_formats: continue shutil.register_unpack_format(name, extensions, function) def uncompress(tarpath: str, dest: str): """Uncompress tarpath to dest folder if tarball is supported. Note that this fixes permissions after successfully uncompressing the archive. Args: tarpath: path to tarball to uncompress dest: the destination folder where to uncompress the tarball Returns: The nature of the tarball, zip or tar. Raises: ValueError when a problem occurs during unpacking """ try: shutil.unpack_archive(tarpath, extract_dir=dest) except shutil.ReadError as e: raise ValueError(f"Problem during unpacking {tarpath}. Reason: {e}") # Fix permissions for dirpath, _, fnames in os.walk(dest): os.chmod(dirpath, 0o755) for fname in fnames: fpath = os.path.join(dirpath, fname) if not os.path.islink(fpath): fpath_exec = os.stat(fpath).st_mode & stat.S_IXUSR if not fpath_exec: os.chmod(fpath, 0o644) def _ls(rootdir): """Generator of filepath, filename from rootdir. """ for dirpath, dirnames, fnames in os.walk(rootdir): for fname in dirnames + fnames: fpath = os.path.join(dirpath, fname) fname = utils.commonname(rootdir, fpath) yield fpath, fname def _compress_zip(tarpath, files): """Compress the given files into a zip archive at tarpath. """ with zipfile.ZipFile(tarpath, "w") as z: for fpath, fname in files: z.write(fpath, arcname=fname) def _compress_tar(tarpath, files): """Compress the given files into a bz2 tarball at tarpath. """ with tarfile.open(tarpath, "w:bz2") as t: for fpath, fname in files: t.add(fpath, arcname=fname, recursive=False) def compress(tarpath, nature, dirpath_or_files): """Create a tarball at tarpath whose format is given by the nature argument. The content of the tarball is either the content of the directory dirpath_or_files (if it is a path), or the (filepath, filename) pairs it yields (if it is an iterable).
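For example (a sketch; the paths below are hypothetical)::

    # compress a whole directory tree into a zip archive
    compress("/tmp/archive.zip", "zip", "/tmp/mydir")
    # or compress an explicit iterable of (filepath, filename) pairs
    compress("/tmp/archive.tar", "tar", [("/tmp/mydir/a.txt", "a.txt")])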
""" if isinstance(dirpath_or_files, str): files = _ls(dirpath_or_files) else: # iterable of 'filepath, filename' files = dirpath_or_files if nature == "zip": _compress_zip(tarpath, files) else: _compress_tar(tarpath, files) return tarpath # Additional uncompression archive format support ADDITIONAL_ARCHIVE_FORMATS = [ # name , extensions, function ("tar.Z|x", [".tar.Z", ".tar.x"], _unpack_tar), # FIXME: make this optional depending on the runtime lzip package install ("tar.lz", [".tar.lz"], _unpack_tar), ] register_new_archive_formats() diff --git a/swh/core/tests/__init__.py b/swh/core/tests/__init__.py index e70ce2a..688a7de 100644 --- a/swh/core/tests/__init__.py +++ b/swh/core/tests/__init__.py @@ -1,5 +1,5 @@ from os import path -import swh.core +import swh.core SQL_DIR = path.join(path.dirname(swh.core.__file__), "sql") diff --git a/swh/core/tests/fixture/conftest.py b/swh/core/tests/fixture/conftest.py index 412d102..0c3cb69 100644 --- a/swh/core/tests/fixture/conftest.py +++ b/swh/core/tests/fixture/conftest.py @@ -1,16 +1,15 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import pytest - from os import path +import pytest DATADIR = path.join(path.abspath(path.dirname(__file__)), "data") @pytest.fixture def datadir(): return DATADIR diff --git a/swh/core/tests/test_cli.py b/swh/core/tests/test_cli.py index 089eb93..6e71060 100644 --- a/swh/core/tests/test_cli.py +++ b/swh/core/tests/test_cli.py @@ -1,283 +1,282 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging -import pkg_resources import textwrap from unittest.mock import patch import click from click.testing import CliRunner +import pkg_resources import pytest - help_msg = """Usage: swh [OPTIONS] COMMAND [ARGS]... Command line interface for Software Heritage. Options: -l, --log-level [NOTSET|DEBUG|INFO|WARNING|ERROR|CRITICAL] Log level (defaults to INFO). --log-config FILENAME Python yaml logging configuration file. --sentry-dsn TEXT DSN of the Sentry instance to report to -h, --help Show this message and exit. Notes: If both options are present, --log-level will override the root logger configuration set in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. """ def test_swh_help(swhmain): runner = CliRunner() result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert result.output.startswith(help_msg) result = runner.invoke(swhmain, ["--help"]) assert result.exit_code == 0 assert result.output.startswith(help_msg) def test_command(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: result = runner.invoke(swhmain, ["test"]) sentry_sdk_init.assert_not_called() assert result.exit_code == 0 assert result.output.strip() == "Hello SWH!" 
def test_loglevel_default(caplog, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 20 click.echo("Hello SWH!") runner = CliRunner() result = runner.invoke(swhmain, ["test"]) assert result.exit_code == 0 print(result.output) assert result.output.strip() == """Hello SWH!""" def test_loglevel_error(caplog, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 40 click.echo("Hello SWH!") runner = CliRunner() result = runner.invoke(swhmain, ["-l", "ERROR", "test"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" def test_loglevel_debug(caplog, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 10 click.echo("Hello SWH!") runner = CliRunner() result = runner.invoke(swhmain, ["-l", "DEBUG", "test"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" def test_sentry(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: result = runner.invoke(swhmain, ["--sentry-dsn", "test_dsn", "test"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( dsn="test_dsn", debug=False, integrations=[], release=None, environment=None, ) def test_sentry_debug(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: result = runner.invoke( swhmain, ["--sentry-dsn", "test_dsn", "--sentry-debug", "test"] ) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( dsn="test_dsn", debug=True, integrations=[], release=None, environment=None, ) def test_sentry_env(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: env = { "SWH_SENTRY_DSN": "test_dsn", "SWH_SENTRY_DEBUG": "1", } result = runner.invoke(swhmain, ["test"], env=env, auto_envvar_prefix="SWH") assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( dsn="test_dsn", debug=True, integrations=[], release=None, environment=None, ) def test_sentry_env_main_package(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: env = { "SWH_SENTRY_DSN": "test_dsn", "SWH_MAIN_PACKAGE": "swh.core", "SWH_SENTRY_ENVIRONMENT": "tests", } result = runner.invoke(swhmain, ["test"], env=env, auto_envvar_prefix="SWH") assert result.exit_code == 0 version = pkg_resources.get_distribution("swh.core").version assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( dsn="test_dsn", debug=False, integrations=[], release="swh.core@" + version, environment="tests", ) @pytest.fixture def log_config_path(tmp_path): log_config = textwrap.dedent( """\ --- version: 1 formatters: formatter: format: 'custom format:%(name)s:%(levelname)s:%(message)s' handlers: console: class: logging.StreamHandler stream: ext://sys.stdout formatter: formatter level: DEBUG root: level: DEBUG handlers: - console loggers: dontshowdebug: level: INFO """ ) (tmp_path / "log_config.yml").write_text(log_config) yield 
str(tmp_path / "log_config.yml") def test_log_config(caplog, log_config_path, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): logging.debug("Root log debug") logging.info("Root log info") logging.getLogger("dontshowdebug").debug("Not shown") logging.getLogger("dontshowdebug").info("Shown") runner = CliRunner() result = runner.invoke(swhmain, ["--log-config", log_config_path, "test",],) assert result.exit_code == 0 assert result.output.strip() == "\n".join( [ "custom format:root:DEBUG:Root log debug", "custom format:root:INFO:Root log info", "custom format:dontshowdebug:INFO:Shown", ] ) def test_log_config_log_level_interaction(caplog, log_config_path, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): logging.debug("Root log debug") logging.info("Root log info") logging.getLogger("dontshowdebug").debug("Not shown") logging.getLogger("dontshowdebug").info("Shown") runner = CliRunner() result = runner.invoke( swhmain, ["--log-config", log_config_path, "--log-level", "INFO", "test",], ) assert result.exit_code == 0 assert result.output.strip() == "\n".join( [ "custom format:root:INFO:Root log info", "custom format:dontshowdebug:INFO:Shown", ] ) def test_aliased_command(swhmain): @swhmain.command(name="canonical-test") @click.pass_context def swhtest(ctx): "A test command." click.echo("Hello SWH!") swhmain.add_alias(swhtest, "othername") runner = CliRunner() # check we have only 'canonical-test' listed in the usage help msg result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert "canonical-test A test command." in result.output assert "othername" not in result.output # check we can execute the cmd with 'canonical-test' result = runner.invoke(swhmain, ["canonical-test"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" # check we can also execute the cmd with the alias 'othername' result = runner.invoke(swhmain, ["othername"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" diff --git a/swh/core/tests/test_collections.py b/swh/core/tests/test_collections.py index c40a121..22efbc0 100644 --- a/swh/core/tests/test_collections.py +++ b/swh/core/tests/test_collections.py @@ -1,73 +1,71 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest - from swh.core.collections import SortedList - parametrize = pytest.mark.parametrize( "items", [ [1, 2, 3, 4, 5, 6, 10, 100], [10, 100, 6, 5, 4, 3, 2, 1], [10, 4, 5, 6, 1, 2, 3, 100], ], ) @parametrize def test_sorted_list_iter(items): list1 = SortedList() for item in items: list1.add(item) assert list(list1) == sorted(items) list2 = SortedList(items) assert list(list2) == sorted(items) @parametrize def test_sorted_list_iter__key(items): list1 = SortedList(key=lambda item: -item) for item in items: list1.add(item) assert list(list1) == list(reversed(sorted(items))) list2 = SortedList(items, key=lambda item: -item) assert list(list2) == list(reversed(sorted(items))) @parametrize def test_sorted_list_iter_from(items): list_ = SortedList(items) for split in items: expected = sorted(item for item in items if item >= split) assert list(list_.iter_from(split)) == expected, f"split: {split}" @parametrize def test_sorted_list_iter_from__key(items): list_ = SortedList(items, key=lambda item: -item) for split in items: expected = 
reversed(sorted(item for item in items if item <= split)) assert list(list_.iter_from(-split)) == list(expected), f"split: {split}" @parametrize def test_sorted_list_iter_after(items): list_ = SortedList(items) for split in items: expected = sorted(item for item in items if item > split) assert list(list_.iter_after(split)) == expected, f"split: {split}" @parametrize def test_sorted_list_iter_after__key(items): list_ = SortedList(items, key=lambda item: -item) for split in items: expected = reversed(sorted(item for item in items if item < split)) assert list(list_.iter_after(-split)) == list(expected), f"split: {split}" diff --git a/swh/core/tests/test_config.py b/swh/core/tests/test_config.py index 6ee50d5..3227ebc 100644 --- a/swh/core/tests/test_config.py +++ b/swh/core/tests/test_config.py @@ -1,328 +1,328 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os -import pytest import shutil -import pkg_resources.extern.packaging.version +import pkg_resources.extern.packaging.version +import pytest import yaml from swh.core import config pytest_v = pkg_resources.get_distribution("pytest").parsed_version if pytest_v < pkg_resources.extern.packaging.version.parse("3.9"): @pytest.fixture def tmp_path(request): - import tempfile import pathlib + import pathlib import tempfile with tempfile.TemporaryDirectory() as tmpdir: yield pathlib.Path(tmpdir) default_conf = { "a": ("int", 2), "b": ("string", "default-string"), "c": ("bool", True), "d": ("int", 10), "e": ("int", None), "f": ("bool", None), "g": ("string", None), "h": ("bool", True), "i": ("bool", True), "ls": ("list[str]", ["a", "b", "c"]), "li": ("list[int]", [42, 43]), } other_default_conf = { "a": ("int", 3), } full_default_conf = default_conf.copy() full_default_conf["a"] = other_default_conf["a"] parsed_default_conf = {key: value for key, (type, value) in default_conf.items()} parsed_conffile = { "a": 1, "b": "this is a string", "c": True, "d": 10, "e": None, "f": None, "g": None, "h": False, "i": True, "ls": ["list", "of", "strings"], "li": [1, 2, 3, 4], } @pytest.fixture def swh_config(tmp_path): # create a temporary folder conffile = tmp_path / "config.yml" conf_contents = """ a: 1 b: this is a string c: true h: false ls: list, of, strings li: 1, 2, 3, 4 """ conffile.open("w").write(conf_contents) return conffile @pytest.fixture def swh_config_unreadable(swh_config): # Create an unreadable, proper configuration file os.chmod(str(swh_config), 0o000) yield swh_config # Make the broken perms file readable again to be able to remove them os.chmod(str(swh_config), 0o644) @pytest.fixture def swh_config_unreadable_dir(swh_config): # Create a proper configuration file in an unreadable directory perms_broken_dir = swh_config.parent / "unreadabledir" perms_broken_dir.mkdir() shutil.move(str(swh_config), str(perms_broken_dir)) os.chmod(str(perms_broken_dir), 0o000) yield perms_broken_dir / swh_config.name # Make the broken perms items readable again to be able to remove them os.chmod(str(perms_broken_dir), 0o755) @pytest.fixture def swh_config_empty(tmp_path): # create a temporary folder conffile = tmp_path / "config.yml" conffile.touch() return conffile def test_read(swh_config): # when res = config.read(str(swh_config), default_conf) # then assert res == parsed_conffile def test_read_no_default_conf(swh_config): """If no default config is provided to
read, this should directly parse the yaml config file """ config_path = str(swh_config) actual_config = config.read(config_path) with open(config_path) as f: expected_config = yaml.safe_load(f) assert actual_config == expected_config def test_read_empty_file(): # when res = config.read(None, default_conf) # then assert res == parsed_default_conf def test_support_non_existing_conffile(tmp_path): # when res = config.read(str(tmp_path / "void.yml"), default_conf) # then assert res == parsed_default_conf def test_support_empty_conffile(swh_config_empty): # when res = config.read(str(swh_config_empty), default_conf) # then assert res == parsed_default_conf def test_raise_on_broken_directory_perms(swh_config_unreadable_dir): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable_dir), default_conf) def test_raise_on_broken_file_perms(swh_config_unreadable): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable), default_conf) def test_merge_default_configs(): # when res = config.merge_default_configs(default_conf, other_default_conf) # then assert res == full_default_conf def test_priority_read_nonexist_conf(swh_config): noexist = str(swh_config.parent / "void.yml") # when res = config.priority_read([noexist, str(swh_config)], default_conf) # then assert res == parsed_conffile def test_priority_read_conf_nonexist_empty(swh_config): noexist = swh_config.parent / "void.yml" empty = swh_config.parent / "empty.yml" empty.touch() # when res = config.priority_read( [str(p) for p in (swh_config, noexist, empty)], default_conf ) # then assert res == parsed_conffile def test_priority_read_empty_conf_nonexist(swh_config): noexist = swh_config.parent / "void.yml" empty = swh_config.parent / "empty.yml" empty.touch() # when res = config.priority_read( [str(p) for p in (empty, swh_config, noexist)], default_conf ) # then assert res == parsed_default_conf def test_swh_config_paths(): res = config.swh_config_paths("foo/bar.yml") assert res == [ "~/.config/swh/foo/bar.yml", "~/.swh/foo/bar.yml", "/etc/softwareheritage/foo/bar.yml", ] def test_prepare_folder(tmp_path): # given conf = { "path1": str(tmp_path / "path1"), "path2": str(tmp_path / "path2" / "depth1"), } # the folders do not exist yet assert not os.path.exists(conf["path1"]), "path1 should not exist." assert not os.path.exists(conf["path2"]), "path2 should not exist." # when config.prepare_folders(conf, "path1") # path1 exists but not path2 assert os.path.exists(conf["path1"]), "path1 should now exist!" assert not os.path.exists(conf["path2"]), "path2 should not exist." # path1 already exists, skips it but creates path2 config.prepare_folders(conf, "path1", "path2") assert os.path.exists(conf["path1"]), "path1 should still exist!" assert os.path.exists(conf["path2"]), "path2 should now exist." def test_merge_config(): cfg_a = { "a": 42, "b": [1, 2, 3], "c": None, "d": {"gheez": 27}, "e": { "ea": "Mr.
Bungle", "eb": None, "ec": [11, 12, 13], "ed": {"eda": "Secret Chief 3", "edb": "Faith No More"}, "ee": 451, }, "f": "Janis", } cfg_b = { "a": 43, "b": [41, 42, 43], "c": "Tom Waits", "d": None, "e": { "ea": "Igorrr", "ec": [51, 52], "ed": {"edb": "Sleepytime Gorilla Museum", "edc": "Nils Peter Molvaer"}, }, "g": "Hüsker Dü", } # merge A, B cfg_m = config.merge_configs(cfg_a, cfg_b) assert cfg_m == { "a": 43, # b takes precedence "b": [41, 42, 43], # b takes precedence "c": "Tom Waits", # b takes precedence "d": None, # b['d'] takes precedence (explicit None) "e": { "ea": "Igorrr", # a takes precedence "eb": None, # only in a "ec": [51, 52], # b takes precedence "ed": { "eda": "Secret Chief 3", # only in a "edb": "Sleepytime Gorilla Museum", # b takes precedence "edc": "Nils Peter Molvaer", }, # only defined in b "ee": 451, }, "f": "Janis", # only defined in a "g": "Hüsker Dü", # only defined in b } # merge B, A cfg_m = config.merge_configs(cfg_b, cfg_a) assert cfg_m == { "a": 42, # a takes precedence "b": [1, 2, 3], # a takes precedence "c": None, # a takes precedence "d": {"gheez": 27}, # a takes precedence "e": { "ea": "Mr. Bungle", # a takes precedence "eb": None, # only defined in a "ec": [11, 12, 13], # a takes precedence "ed": { "eda": "Secret Chief 3", # only in a "edb": "Faith No More", # a takes precedence "edc": "Nils Peter Molvaer", }, # only in b "ee": 451, }, "f": "Janis", # only in a "g": "Hüsker Dü", # only in b } def test_merge_config_type_error(): for v in (1, "str", None): with pytest.raises(TypeError): config.merge_configs(v, {}) with pytest.raises(TypeError): config.merge_configs({}, v) for v in (1, "str"): with pytest.raises(TypeError): config.merge_configs({"a": v}, {"a": {}}) with pytest.raises(TypeError): config.merge_configs({"a": {}}, {"a": v}) diff --git a/swh/core/tests/test_logger.py b/swh/core/tests/test_logger.py index 6980e6e..677082e 100644 --- a/swh/core/tests/test_logger.py +++ b/swh/core/tests/test_logger.py @@ -1,130 +1,130 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime -import logging -import pytz import inspect - +import logging from unittest.mock import patch +import pytz + from swh.core import logger def lineno(): """Returns the current line number in our program.""" return inspect.currentframe().f_back.f_lineno def test_db_level(): assert logger.db_level_of_py_level(10) == "debug" assert logger.db_level_of_py_level(20) == "info" assert logger.db_level_of_py_level(30) == "warning" assert logger.db_level_of_py_level(40) == "error" assert logger.db_level_of_py_level(50) == "critical" def test_flatten_scalar(): assert list(logger.flatten("")) == [("", "")] assert list(logger.flatten("toto")) == [("", "toto")] assert list(logger.flatten(10)) == [("", 10)] assert list(logger.flatten(10.5)) == [("", 10.5)] def test_flatten_list(): assert list(logger.flatten([])) == [] assert list(logger.flatten([1])) == [("0", 1)] assert list(logger.flatten([1, 2, ["a", "b"]])) == [ ("0", 1), ("1", 2), ("2_0", "a"), ("2_1", "b"), ] assert list(logger.flatten([1, 2, ["a", ("x", 1)]])) == [ ("0", 1), ("1", 2), ("2_0", "a"), ("2_1_0", "x"), ("2_1_1", 1), ] def test_flatten_dict(): assert list(logger.flatten({})) == [] assert list(logger.flatten({"a": 1})) == [("a", 1)] assert sorted(logger.flatten({"a": 1, "b": (2, 3,), "c": {"d": 4, "e": "f"}})) 
== [ ("a", 1), ("b_0", 2), ("b_1", 3), ("c_d", 4), ("c_e", "f"), ] def test_flatten_dict_binary_keys(): d = {b"a": "a"} str_d = str(d) assert list(logger.flatten(d)) == [("", str_d)] assert list(logger.flatten({"a": d})) == [("a", str_d)] assert list(logger.flatten({"a": [d, d]})) == [("a_0", str_d), ("a_1", str_d)] def test_stringify(): assert logger.stringify(None) == "None" assert logger.stringify(123) == "123" assert logger.stringify("abc") == "abc" date = datetime(2019, 9, 1, 16, 32) assert logger.stringify(date) == "2019-09-01T16:32:00" tzdate = datetime(2019, 9, 1, 16, 32, tzinfo=pytz.utc) assert logger.stringify(tzdate) == "2019-09-01T16:32:00+00:00" @patch("swh.core.logger.send") def test_journal_handler(send): log = logging.getLogger("test_logger") log.addHandler(logger.JournalHandler()) log.setLevel(logging.DEBUG) _, ln = log.info("hello world"), lineno() send.assert_called_with( "hello world", CODE_FILE=__file__, CODE_FUNC="test_journal_handler", CODE_LINE=ln, LOGGER="test_logger", PRIORITY="6", THREAD_NAME="MainThread", ) @patch("swh.core.logger.send") def test_journal_handler_w_data(send): log = logging.getLogger("test_logger") log.addHandler(logger.JournalHandler()) log.setLevel(logging.DEBUG) _, ln = ( log.debug("something cool %s", ["with", {"extra": "data"}]), lineno() - 1, ) send.assert_called_with( "something cool ['with', {'extra': 'data'}]", CODE_FILE=__file__, CODE_FUNC="test_journal_handler_w_data", CODE_LINE=ln, LOGGER="test_logger", PRIORITY="7", THREAD_NAME="MainThread", SWH_LOGGING_ARGS_0_0="with", SWH_LOGGING_ARGS_0_1_EXTRA="data", ) diff --git a/swh/core/tests/test_pytest_plugin.py b/swh/core/tests/test_pytest_plugin.py index 7c72df3..f8d23ef 100644 --- a/swh/core/tests/test_pytest_plugin.py +++ b/swh/core/tests/test_pytest_plugin.py @@ -1,117 +1,117 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import requests - from os import path from urllib.parse import unquote +import requests + from swh.core.pytest_plugin import requests_mock_datadir_factory def test_get_response_cb_with_encoded_url(requests_mock_datadir): # The following urls (quoted, unquoted) will be resolved as the same file for encoded_url, expected_response in [ ("https://forge.s.o/api/diffusion?attachments%5Buris%5D=1", "something"), ( "https://www.reference.com/web?q=What+Is+an+Example+of+a+URL?&qo=contentPageRelatedSearch&o=600605&l=dir&sga=1", # noqa "something else", ), ]: for url in [encoded_url, unquote(encoded_url)]: response = requests.get(url) assert response.ok assert response.json() == expected_response def test_get_response_cb_with_visits_nominal(requests_mock_datadir_visits): response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "you"} response = requests.get("http://example.com/something.json") assert response.ok assert response.json() == "something" response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "world"} response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_get_response_cb_with_visits(requests_mock_datadir_visits): response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "you"} response = requests.get("https://example.com/other.json") 
assert response.ok assert response.json() == "foobar" response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "world"} response = requests.get("https://example.com/other.json") assert not response.ok assert response.status_code == 404 response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_get_response_cb_no_visit(requests_mock_datadir): response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "you"} response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "you"} def test_get_response_cb_query_params(requests_mock_datadir): response = requests.get("https://example.com/file.json?toto=42") assert not response.ok assert response.status_code == 404 response = requests.get("https://example.com/file.json?name=doe&firstname=jane") assert response.ok assert response.json() == {"hello": "jane doe"} requests_mock_datadir_ignore = requests_mock_datadir_factory( ignore_urls=["https://example.com/file.json"], has_multi_visit=False, ) def test_get_response_cb_ignore_url(requests_mock_datadir_ignore): response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 requests_mock_datadir_ignore_and_visit = requests_mock_datadir_factory( ignore_urls=["https://example.com/file.json"], has_multi_visit=True, ) def test_get_response_cb_ignore_url_with_visit(requests_mock_datadir_ignore_and_visit): response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_data_dir(datadir): expected_datadir = path.join(path.abspath(path.dirname(__file__)), "data") assert datadir == expected_datadir diff --git a/swh/core/tests/test_tarball.py b/swh/core/tests/test_tarball.py index 3c0449f..add1bec 100644 --- a/swh/core/tests/test_tarball.py +++ b/swh/core/tests/test_tarball.py @@ -1,223 +1,224 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os -import pytest import shutil +import pytest + from swh.core import tarball @pytest.fixture def prepare_shutil_state(): """Reset any shutil modification in its current state """ import shutil registered_formats = [f[0] for f in shutil.get_unpack_formats()] for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS: name = format_id[0] if name in registered_formats: shutil.unregister_unpack_format(name) return shutil def test_compress_uncompress_zip(tmp_path): tocompress = tmp_path / "compressme" tocompress.mkdir() for i in range(10): fpath = tocompress / ("file%s.txt" % i) fpath.write_text("content of file %s" % i) zipfile = tmp_path / "archive.zip" tarball.compress(str(zipfile), "zip", str(tocompress)) destdir = tmp_path / "destdir" tarball.uncompress(str(zipfile), str(destdir)) lsdir = sorted(x.name for x in destdir.iterdir()) assert ["file%s.txt" % i for i in range(10)] == lsdir @pytest.mark.xfail( reason=( "Python's zipfile library doesn't support Info-ZIP's " "extension for file permissions." 
) ) def test_compress_uncompress_zip_modes(tmp_path): tocompress = tmp_path / "compressme" tocompress.mkdir() fpath = tocompress / "text.txt" fpath.write_text("echo foo") fpath.chmod(0o644) fpath = tocompress / "executable.sh" fpath.write_text("echo foo") fpath.chmod(0o755) zipfile = tmp_path / "archive.zip" tarball.compress(str(zipfile), "zip", str(tocompress)) destdir = tmp_path / "destdir" tarball.uncompress(str(zipfile), str(destdir)) (executable_path, text_path) = sorted(destdir.iterdir()) assert text_path.stat().st_mode == 0o100644 # succeeds, it's the default assert executable_path.stat().st_mode == 0o100755 # fails def test_compress_uncompress_tar(tmp_path): tocompress = tmp_path / "compressme" tocompress.mkdir() for i in range(10): fpath = tocompress / ("file%s.txt" % i) fpath.write_text("content of file %s" % i) tarfile = tmp_path / "archive.tar" tarball.compress(str(tarfile), "tar", str(tocompress)) destdir = tmp_path / "destdir" tarball.uncompress(str(tarfile), str(destdir)) lsdir = sorted(x.name for x in destdir.iterdir()) assert ["file%s.txt" % i for i in range(10)] == lsdir def test_compress_uncompress_tar_modes(tmp_path): tocompress = tmp_path / "compressme" tocompress.mkdir() fpath = tocompress / "text.txt" fpath.write_text("echo foo") fpath.chmod(0o644) fpath = tocompress / "executable.sh" fpath.write_text("echo foo") fpath.chmod(0o755) tarfile = tmp_path / "archive.tar" tarball.compress(str(tarfile), "tar", str(tocompress)) destdir = tmp_path / "destdir" tarball.uncompress(str(tarfile), str(destdir)) (executable_path, text_path) = sorted(destdir.iterdir()) assert text_path.stat().st_mode == 0o100644 assert executable_path.stat().st_mode == 0o100755 def test__unpack_tar_failure(tmp_path, datadir): """Unpacking a nonexistent tarball should fail """ tarpath = os.path.join(datadir, "archives", "inexistent-archive.tar.Z") assert not os.path.exists(tarpath) with pytest.raises( shutil.ReadError, match=f"Unable to uncompress {tarpath} to {tmp_path}" ): tarball._unpack_tar(tarpath, tmp_path) def test__unpack_tar_failure2(tmp_path, datadir): """Unpacking an existing tarball into a nonexistent folder should fail """ filename = "groff-1.02.tar.Z" tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) extract_dir = os.path.join(tmp_path, "dir", "inexistent") with pytest.raises( shutil.ReadError, match=f"Unable to uncompress {tarpath} to {tmp_path}" ): tarball._unpack_tar(tarpath, extract_dir) def test__unpack_tar_failure3(tmp_path, datadir): """Unpacking an unsupported tarball should fail """ filename = "hello.zip" tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) with pytest.raises( shutil.ReadError, match=f"Unable to uncompress {tarpath} to {tmp_path}" ): tarball._unpack_tar(tarpath, tmp_path) def test__unpack_tar(tmp_path, datadir): """Unpacking a supported tarball into an existing folder should be ok """ filename = "groff-1.02.tar.Z" tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) extract_dir = os.path.join(tmp_path, filename) os.makedirs(extract_dir, exist_ok=True) output_directory = tarball._unpack_tar(tarpath, extract_dir) assert extract_dir == output_directory assert len(os.listdir(extract_dir)) > 0 def test_register_new_archive_formats(prepare_shutil_state): """Registering new archive formats should be fine """ unpack_formats_v1 = [f[0] for f in shutil.get_unpack_formats()] for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS: assert format_id[0] not in unpack_formats_v1 # when
 
 
 def test_register_new_archive_formats(prepare_shutil_state):
     """Registering new archive formats should be fine
 
     """
     unpack_formats_v1 = [f[0] for f in shutil.get_unpack_formats()]
     for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS:
         assert format_id[0] not in unpack_formats_v1
 
     # when
     tarball.register_new_archive_formats()
 
     # then
     unpack_formats_v2 = [f[0] for f in shutil.get_unpack_formats()]
     for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS:
         assert format_id[0] in unpack_formats_v2
 
 
 def test_uncompress_tarpaths(tmp_path, datadir, prepare_shutil_state):
     """High-level uncompress call on both supported and unsupported tarballs
 
     """
     archive_dir = os.path.join(datadir, "archives")
     tarfiles = os.listdir(archive_dir)
     tarpaths = [os.path.join(archive_dir, tarfile) for tarfile in tarfiles]
 
     unsupported_tarpaths = []
     for t in tarpaths:
         if t.endswith(".Z") or t.endswith(".x") or t.endswith(".lz"):
             unsupported_tarpaths.append(t)
 
     # not supported yet
     for tarpath in unsupported_tarpaths:
         with pytest.raises(ValueError, match=f"Problem during unpacking {tarpath}."):
             tarball.uncompress(tarpath, dest=tmp_path)
 
     # register those unsupported formats
     tarball.register_new_archive_formats()
 
     # unsupported formats are now supported
     for n, tarpath in enumerate(tarpaths, start=1):
         tarball.uncompress(tarpath, dest=tmp_path)
 
     assert n == len(tarpaths)
diff --git a/swh/core/utils.py b/swh/core/utils.py
index 0b6fcbd..a14daa5 100644
--- a/swh/core/utils.py
+++ b/swh/core/utils.py
@@ -1,123 +1,122 @@
 # Copyright (C) 2016-2017 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import os
-import itertools
 import codecs
-import re
-
 from contextlib import contextmanager
+import itertools
+import os
+import re
 
 
 @contextmanager
 def cwd(path):
     """Contextually change the working directory to do thy bidding.
 
     Then gets back to the original location.
 
     """
     prev_cwd = os.getcwd()
     os.chdir(path)
     try:
         yield
     finally:
         os.chdir(prev_cwd)
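# Editor's sketch, not part of the patch: typical use of the cwd() helper
# above. The target directory is an arbitrary example.
import os

from swh.core.utils import cwd

with cwd("/tmp"):
    print(os.getcwd())  # now inside /tmp (or its resolved path)
# on exit the previous working directory is restored, even if the body raised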
""" args = [iter(iterable)] * n stop_value = object() for _data in itertools.zip_longest(*args, fillvalue=stop_value): yield (d for d in _data if d is not stop_value) def backslashescape_errors(exception): if isinstance(exception, UnicodeDecodeError): bad_data = exception.object[exception.start : exception.end] escaped = "".join(r"\x%02x" % x for x in bad_data) return escaped, exception.end return codecs.backslashreplace_errors(exception) codecs.register_error("backslashescape", backslashescape_errors) def encode_with_unescape(value): """Encode an unicode string containing \\x backslash escapes""" slices = [] start = 0 odd_backslashes = False i = 0 while i < len(value): if value[i] == "\\": odd_backslashes = not odd_backslashes else: if odd_backslashes: if value[i] != "x": raise ValueError( "invalid escape for %r at position %d" % (value, i - 1) ) slices.append( value[start : i - 1].replace("\\\\", "\\").encode("utf-8") ) slices.append(bytes.fromhex(value[i + 1 : i + 3])) odd_backslashes = False start = i = i + 3 continue i += 1 slices.append(value[start:i].replace("\\\\", "\\").encode("utf-8")) return b"".join(slices) def decode_with_escape(value): """Decode a bytestring as utf-8, escaping the bytes of invalid utf-8 sequences as \\x. We also escape NUL bytes as they are invalid in JSON strings. """ # escape backslashes value = value.replace(b"\\", b"\\\\") value = value.replace(b"\x00", b"\\x00") return value.decode("utf-8", "backslashescape") def commonname(path0, path1, as_str=False): """Compute the commonname between the path0 and path1. """ return path1.split(path0)[1] def numfile_sortkey(fname): """Simple function to sort filenames of the form: nnxxx.ext where nn is a number according to the numbers. Typically used to sort sql/nn-swh-xxx.sql files. """ num, rem = re.match(r"(\d*)(.*)", fname).groups() return (num and int(num) or 99, rem)