diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index 5e416cc..a9d1c06 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,500 +1,501 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc import functools import inspect import logging import pickle from typing import ( Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union, ) from deprecated import deprecated from flask import Flask, Request, Response, abort, request import requests from werkzeug.exceptions import HTTPException from .negotiation import Formatter as FormatterBase from .negotiation import Negotiator as NegotiatorBase from .negotiation import negotiate as _negotiate from .serializers import ( exception_to_dict, json_dumps, json_loads, msgpack_dumps, msgpack_loads, ) from .serializers import decode_response from .serializers import encode_data_client as encode_data logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, "application/json" ) def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) def configure(self, extra_encoders=None): self.extra_encoders = extra_encoders class JSONFormatter(Formatter): format = "json" mimetypes = ["application/json"] def render(self, obj): return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): format = "msgpack" mimetypes = ["application/x-msgpack"] def render(self, obj): return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception Attributes: response: HTTP response corresponding to the failure """ def __init__( self, payload: Optional[Any] = None, response: Optional[requests.Response] = None, ): if payload is not None: super().__init__(payload) else: super().__init__() self.response = response def __str__(self): if ( self.args and isinstance(self.args[0], dict) and "type" in self.args[0] and "args" in self.args[0] ): return ( f"' ) else: return super().__str__() F = TypeVar("F", bound=Callable) def remote_api_endpoint(path: str, method: str = "POST") -> Callable[[F], F]: def dec(f: F) -> F: f._endpoint_path = path # type: ignore f._method = method # type: ignore return f return dec class APIError(Exception): """API Error""" def __str__(self): return "An unexpected error occurred in the backend: {}".format(self.args) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get("backend_class", None) for base in bases: if backend_class is not None: break backend_class = getattr(base, "backend_class", None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop("self") post_data.pop("cur", None) post_data.pop("db", None) # Send the request. return self._post(meth._endpoint_path, post_data) if meth_name not in attributes: attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): - """Proxy to an internal SWH RPC - - """ + """Proxy to an internal SWH RPC""" backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" reraise_exceptions: ClassVar[List[Type[Exception]]] = [] """On server errors, if any of the exception classes in this list has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__( self, url, api_exception=None, timeout=None, chunk_size=4096, reraise_exceptions=None, **kwargs, ): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions base_url = url if url.endswith("/") else url + "/" self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get("max_retries", 3), pool_connections=kwargs.get("pool_connections", 20), pool_maxsize=kwargs.get("pool_maxsize", 100), ) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return "%s%s" % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if "chunk_size" in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. opts["stream"] = True if self.timeout and "timeout" not in opts: opts["timeout"] = self.timeout try: return getattr(self.session, verb)(self._url(endpoint), **opts) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def _post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): data = (self._encode_data(x) for x in data) else: data = self._encode_data(data) chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "post", endpoint, data=data, headers={ "content-type": "application/x-msgpack", "accept": "application/x-msgpack", }, **opts, ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def _encode_data(self, data): return encode_data(data, extra_encoders=self.extra_type_encoders) _post_stream = _post @deprecated(version="2.1.0", reason="Use _post instead") def post(self, *args, **kwargs): return self._post(*args, **kwargs) @deprecated(version="2.1.0", reason="Use _post_stream instead") def post_stream(self, *args, **kwargs): return self._post_stream(*args, **kwargs) def _get(self, endpoint, **opts): chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def _get_stream(self, endpoint, **opts): return self._get(endpoint, stream=True, **opts) @deprecated(version="2.1.0", reason="Use _get instead") def get(self, *args, **kwargs): return self._get(*args, **kwargs) @deprecated(version="2.1.0", reason="Use _get_stream instead") def get_stream(self, *args, **kwargs): return self._get_stream(*args, **kwargs) def raise_for_status(self, response) -> None: """check response HTTP status code and raise an exception if it denotes an error; do nothing otherwise """ status_code = response.status_code status_class = response.status_code // 100 if status_code == 404: raise RemoteException(payload="404 not found", response=response) exception = None # TODO: only old servers send pickled error; stop trying to unpickle # after they are all upgraded try: if status_class == 4: data = self._decode_response(response, check_status=False) if isinstance(data, dict): # TODO: remove "exception" key check once all servers # are using new schema exc_data = data["exception"] if "exception" in data else data for exc_type in self.reraise_exceptions: if exc_type.__name__ == exc_data["type"]: exception = exc_type(*exc_data["args"]) break else: exception = RemoteException(payload=exc_data, response=response) else: exception = pickle.loads(data) elif status_class == 5: data = self._decode_response(response, check_status=False) if "exception_pickled" in data: exception = pickle.loads(data["exception_pickled"]) else: # TODO: remove "exception" key check once all servers # are using new schema exc_data = data["exception"] if "exception" in data else data exception = RemoteException(payload=exc_data, response=response) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) if exception: raise exception from None if status_class != 2: raise RemoteException( payload=f"API HTTP error: {status_code} {response.content}", response=response, ) def _decode_response(self, response, check_status=True): if check_status: self.raise_for_status(response) return decode_response(response, extra_decoders=self.extra_type_decoders) def __repr__(self): return "<{} url={}>".format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = "utf-8" encoding_errors = "surrogateescape" ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { "application/x-msgpack": msgpack_dumps, "application/json": json_dumps, } def encode_data_server( data, content_type="application/x-msgpack", extra_type_encoders=None ): encoded_data = ENCODERS[content_type](data, extra_encoders=extra_type_encoders) - return Response(encoded_data, mimetype=content_type,) + return Response( + encoded_data, + mimetype=content_type, + ) def decode_request(request, extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == "application/x-msgpack": r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == "application/json": # XXX this .decode() is needed for py35. # Should not be needed any more with py37 r = json_loads(data.decode("utf-8"), extra_decoders=extra_decoders) else: raise ValueError("Wrong content type `%s` for API request" % content_type) return r def error_handler(exception, encoder, status_code=500): logging.exception(exception) response = encoder(exception_to_dict(exception)) if isinstance(exception, HTTPException): response.status_code = exception.code else: # TODO: differentiate between server errors and client errors response.status_code = status_code return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. For each method 'do_x()' of the ``backend_factory``, subclasses may implement two methods: ``pre_do_x(self, kw)`` and ``post_do_x(self, ret, kw)`` that will be called respectively before and after ``do_x(**kw)``. ``kw`` is the dict of request parameters, and ``ret`` is the return value of ``do_x(**kw)``. """ request_class = BytesRequest extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" method_decorators: List[Callable[[Callable], Callable]] = [] """List of decorators to all methods generated from the ``backend_class``.""" def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) self.add_backend_class(backend_class, backend_factory) def add_backend_class(self, backend_class=None, backend_factory=None): if backend_class is None and backend_factory is not None: raise ValueError( "backend_factory should only be provided if backend_class is" ) if backend_class is not None: backend_factory = backend_factory or backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def f(): # Call the actual code pre_hook = getattr(self, f"pre_{meth_name}", None) post_hook = getattr(self, f"post_{meth_name}", None) obj_meth = getattr(backend_factory(), meth_name) kw = decode_request(request, extra_decoders=self.extra_type_decoders) if pre_hook is not None: pre_hook(kw) ret = obj_meth(**kw) if post_hook is not None: post_hook(ret, kw) return ret for decorator in self.method_decorators: f = decorator(f) self.route("/" + meth._endpoint_path, methods=["POST"])(f) diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py index 65f3916..4e2188d 100644 --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -1,186 +1,183 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import OrderedDict import functools import logging from typing import Callable, Dict, List, Optional, Tuple, Type, Union import aiohttp.web from aiohttp_utils import Response, negotiation from deprecated import deprecated import multidict from .serializers import ( exception_to_dict, json_dumps, json_loads, msgpack_dumps, msgpack_loads, ) def encode_msgpack(data, **kwargs): return aiohttp.web.Response( body=msgpack_dumps(data), headers=multidict.MultiDict({"Content-Type": "application/x-msgpack"}), **kwargs, ) encode_data_server = Response def render_msgpack(request, data, extra_encoders=None): return msgpack_dumps(data, extra_encoders=extra_encoders) def render_json(request, data, extra_encoders=None): return json_dumps(data, extra_encoders=extra_encoders) def decode_data(data, content_type, extra_decoders=None): - """Decode data according to content type, eventually using some extra decoders. - - """ + """Decode data according to content type, eventually using some extra decoders.""" if not data: return {} if content_type == "application/x-msgpack": r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == "application/json": r = json_loads(data, extra_decoders=extra_decoders) else: raise ValueError(f"Wrong content type `{content_type}` for API request") return r async def decode_request(request, extra_decoders=None): - """Decode asynchronously the request - - """ + """Decode asynchronously the request""" data = await request.read() return decode_data(data, request.content_type, extra_decoders=extra_decoders) async def error_middleware(app, handler): async def middleware_handler(request): try: return await handler(request) except Exception as e: if isinstance(e, aiohttp.web.HTTPException): raise logging.exception(e) res = exception_to_dict(e) if isinstance(e, app.client_exception_classes): status = 400 else: status = 500 return encode_data_server(res, status=status) return middleware_handler class RPCServerApp(aiohttp.web.Application): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ client_exception_classes: Tuple[Type[Exception], ...] = () """Exceptions that should be handled as a client error (eg. object not found, invalid argument)""" extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__( self, app_name: Optional[str] = None, backend_class: Optional[Callable] = None, backend_factory: Optional[Union[Callable, str]] = None, middlewares=(), **kwargs, ): nego_middleware = negotiation.negotiation_middleware( renderers=self._renderers(), force_rendering=True ) - middlewares = (nego_middleware, error_middleware,) + middlewares + middlewares = ( + nego_middleware, + error_middleware, + ) + middlewares super().__init__(middlewares=middlewares, **kwargs) # swh decorations starts here self.app_name = app_name if backend_class is None and backend_factory is not None: raise ValueError( "backend_factory should only be provided if backend_class is" ) self.backend_class = backend_class if backend_class is not None: backend_factory = backend_factory or backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): path = meth._endpoint_path http_method = meth._method path = path if path.startswith("/") else f"/{path}" self.router.add_route( http_method, path, self._endpoint(meth_name, meth, backend_factory), ) def _renderers(self): """Return an ordered list of renderers in order of increasing desirability (!) See mimetype.best_match() docstring """ return OrderedDict( [ ( "application/json", lambda request, data: render_json( request, data, extra_encoders=self.extra_type_encoders ), ), ( "application/x-msgpack", lambda request, data: render_msgpack( request, data, extra_encoders=self.extra_type_encoders ), ), ] ) def _endpoint(self, meth_name, meth, backend_factory): - """Create endpoint out of the method `meth`. - - """ + """Create endpoint out of the method `meth`.""" @functools.wraps(meth) # Copy signature and doc async def decorated_meth(request, *args, **kwargs): obj_meth = getattr(backend_factory(), meth_name) data = await request.read() kw = decode_data( data, request.content_type, extra_decoders=self.extra_type_decoders ) result = obj_meth(**kw) return encode_data_server(result) return decorated_meth @deprecated(version="0.0.64", reason="Use the RPCServerApp instead") class SWHRemoteAPI(RPCServerApp): pass diff --git a/swh/core/api/classes.py b/swh/core/api/classes.py index 17ed731..0d6a716 100644 --- a/swh/core/api/classes.py +++ b/swh/core/api/classes.py @@ -1,61 +1,57 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass, field import itertools from typing import Callable, Generic, Iterable, List, Optional, TypeVar TResult = TypeVar("TResult") TToken = TypeVar("TToken") @dataclass(eq=True) class PagedResult(Generic[TResult, TToken]): """Represents a page of results; with a token to get the next page""" results: List[TResult] = field(default_factory=list) next_page_token: Optional[TToken] = field(default=None) def _stream_results(f, *args, page_token, **kwargs): """Helper for stream_results() and stream_results_optional()""" while True: page_result = f(*args, page_token=page_token, **kwargs) yield from page_result.results page_token = page_result.next_page_token if page_token is None: break def stream_results( f: Callable[..., PagedResult[TResult, TToken]], *args, **kwargs ) -> Iterable[TResult]: - """Consume the paginated result and stream the page results - - """ + """Consume the paginated result and stream the page results""" if "page_token" in kwargs: raise TypeError('stream_results has no argument "page_token".') yield from _stream_results(f, *args, page_token=None, **kwargs) def stream_results_optional( f: Callable[..., Optional[PagedResult[TResult, TToken]]], *args, **kwargs ) -> Optional[Iterable[TResult]]: - """Like stream_results(), but for functions ``f`` that return an Optional. - - """ + """Like stream_results(), but for functions ``f`` that return an Optional.""" if "page_token" in kwargs: raise TypeError('stream_results_optional has no argument "page_token".') res = f(*args, page_token=None, **kwargs) if res is None: return None else: if res.next_page_token is None: return iter(res.results) else: return itertools.chain( res.results, _stream_results(f, *args, page_token=res.next_page_token, **kwargs), ) diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index 2c60885..37dad4b 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,322 +1,325 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime from enum import Enum import json import traceback import types from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from uuid import UUID import iso8601 import msgpack from requests import Response from swh.core.api.classes import PagedResult def encode_datetime(dt: datetime.datetime) -> str: """Wrapper of datetime.datetime.isoformat() that forbids naive datetimes.""" if dt.tzinfo is None: raise TypeError("can not serialize naive 'datetime.datetime' object") return dt.isoformat() def _encode_paged_result(obj: PagedResult) -> Dict[str, Any]: """Serialize PagedResult to a Dict.""" return { "results": obj.results, "next_page_token": obj.next_page_token, } def _decode_paged_result(obj: Dict[str, Any]) -> PagedResult: """Deserialize Dict into PagedResult""" - return PagedResult(results=obj["results"], next_page_token=obj["next_page_token"],) + return PagedResult( + results=obj["results"], + next_page_token=obj["next_page_token"], + ) def exception_to_dict(exception: Exception) -> Dict[str, Any]: tb = traceback.format_exception(None, exception, exception.__traceback__) exc_type = type(exception) return { "type": exc_type.__name__, "module": exc_type.__module__, "args": exception.args, "message": str(exception), "traceback": tb, } def dict_to_exception(exc_dict: Dict[str, Any]) -> Exception: temp = __import__(exc_dict["module"], fromlist=[exc_dict["type"]]) try: return getattr(temp, exc_dict["type"])(*exc_dict["args"]) except Exception: # custom Exception type cannot be rebuilt, fallback to base Exception type return Exception(exc_dict["message"]) def encode_timedelta(td: datetime.timedelta) -> Dict[str, int]: return { "days": td.days, "seconds": td.seconds, "microseconds": td.microseconds, } ENCODERS: List[Tuple[type, str, Callable]] = [ (UUID, "uuid", str), (datetime.timedelta, "timedelta", encode_timedelta), (PagedResult, "paged_result", _encode_paged_result), (Exception, "exception", exception_to_dict), ] JSON_ENCODERS: List[Tuple[type, str, Callable]] = [ (datetime.datetime, "datetime", encode_datetime), (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")), ] DECODERS: Dict[str, Callable] = { "timedelta": lambda d: datetime.timedelta(**d), "uuid": UUID, "paged_result": _decode_paged_result, "exception": dict_to_exception, # for BW compat, to be moved in JSON_DECODERS ASAP "datetime": lambda d: iso8601.parse_date(d, default_timezone=None), } JSON_DECODERS: Dict[str, Callable] = { "bytes": base64.b85decode, } def get_encoders( extra_encoders: Optional[List[Tuple[Type, str, Callable]]], with_json: bool = False ) -> List[Tuple[Type, str, Callable]]: encoders = ENCODERS if with_json: encoders = [*encoders, *JSON_ENCODERS] if extra_encoders: encoders = [*encoders, *extra_encoders] return encoders def get_decoders( extra_decoders: Optional[Dict[str, Callable]], with_json: bool = False ) -> Dict[str, Callable]: decoders = DECODERS if with_json: decoders = {**decoders, **JSON_DECODERS} if extra_decoders is not None: decoders = {**decoders, **extra_decoders} return decoders class MsgpackExtTypeCodes(Enum): LONG_INT = 1 LONG_NEG_INT = 2 def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: raise ValueError("Limits were reached. Please, check your input.\n" + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers["content-type"] if content_type.startswith("application/x-msgpack"): r = msgpack_loads(response.content, extra_decoders=extra_decoders) elif content_type.startswith("application/json"): r = json_loads(response.text, extra_decoders=extra_decoders) elif content_type.startswith("text/"): r = response.text else: raise ValueError("Wrong content type `%s` for API response" % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ def __init__(self, extra_encoders=None, **kwargs): super().__init__(**kwargs) self.encoders = get_encoders(extra_encoders, with_json=True) def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): return { "swhtype": type_name, "d": encoder(o), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. """ def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) self.decoders = get_decoders(extra_decoders, with_json=True) def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {"d", "swhtype"}: if o["swhtype"] == "bytes": return base64.b85decode(o["d"]) decoder = self.decoders.get(o["swhtype"]) if decoder: return decoder(self.decode_data(o["d"])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index def json_dumps(data: Any, extra_encoders=None) -> str: return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = get_encoders(extra_encoders) def encode_types(obj): if isinstance(obj, int): # integer overflowed while packing. Handle it as an extended type if obj > 0: code = MsgpackExtTypeCodes.LONG_INT.value else: code = MsgpackExtTypeCodes.LONG_NEG_INT.value obj = -obj length, rem = divmod(obj.bit_length(), 8) if rem: length += 1 return msgpack.ExtType(code, int.to_bytes(obj, length, "big")) if isinstance(obj, types.GeneratorType): return list(obj) for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { b"swhtype": type_name, b"d": encoder(obj), } return obj return msgpack.packb( data, use_bin_type=True, datetime=True, # encode datetime as msgpack.Timestamp default=encode_types, ) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream. .. Caution:: This function is used by swh.journal to decode the contents of the journal. This function **must** be kept backwards-compatible. """ decoders = get_decoders(extra_decoders) def ext_hook(code, data): if code == MsgpackExtTypeCodes.LONG_INT.value: return int.from_bytes(data, "big") elif code == MsgpackExtTypeCodes.LONG_NEG_INT.value: return -int.from_bytes(data, "big") raise ValueError("Unknown msgpack extended code %s" % code) def decode_types(obj): # Support for current encodings if set(obj.keys()) == {b"d", b"swhtype"}: decoder = decoders.get(obj[b"swhtype"]) if decoder: return decoder(obj[b"d"]) # Fallthrough return obj try: try: return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook, strict_map_key=False, timestamp=3, # convert Timestamp in datetime objects (tz UTC) ) except TypeError: # msgpack < 0.6.0 return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook ) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb( data, encoding="utf-8", object_hook=decode_types, ext_hook=ext_hook ) diff --git a/swh/core/api/tests/server_testing.py b/swh/core/api/tests/server_testing.py index 0088a36..4b9aafe 100644 --- a/swh/core/api/tests/server_testing.py +++ b/swh/core/api/tests/server_testing.py @@ -1,146 +1,142 @@ # Copyright (C) 2015-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import multiprocessing import time import unittest.mock from urllib.request import urlopen import aiohttp import aiohttp.test_utils class ServerTestFixtureBaseClass(metaclass=abc.ABCMeta): """Base class for http client/server testing implementations. Override this class to implement the following methods: - process_config: to do something needed for the server configuration (e.g propagate the configuration to other part) - define_worker_function: define the function that will actually run the server. To ensure test isolation, each test will run in a different server and a different folder. In order to correctly work, the subclass must call the parents class's setUp() and tearDown() methods. """ def setUp(self): super().setUp() self.start_server() def tearDown(self): self.stop_server() super().tearDown() def url(self): return "http://127.0.0.1:%d/" % self.port def process_config(self): """Process the server's configuration. Do something useful for example, pass along the self.config dictionary inside the self.app. By default, do nothing. """ pass @abc.abstractmethod def define_worker_function(self, app, port): - """Define how the actual implementation server will run. - - """ + """Define how the actual implementation server will run.""" pass def start_server(self): - """ Spawn the API server using multiprocessing. - """ + """Spawn the API server using multiprocessing.""" self.process = None self.process_config() self.port = aiohttp.test_utils.unused_port() worker_fn = self.define_worker_function() self.process = multiprocessing.Process( target=worker_fn, args=(self.app, self.port) ) self.process.start() # Wait max 5 seconds for server to spawn i = 0 while i < 500: try: urlopen(self.url()) except Exception: i += 1 time.sleep(0.01) else: return def stop_server(self): - """ Terminate the API server's process. - """ + """Terminate the API server's process.""" if self.process: self.process.terminate() class ServerTestFixture(ServerTestFixtureBaseClass): """Base class for http client/server testing (e.g flask). Mix this in a test class in order to have access to an http server running in background. Note that the subclass should define a dictionary in self.config that contains the server config. And an application in self.app that corresponds to the type of server the tested client needs. To ensure test isolation, each test will run in a different server and a different folder. In order to correctly work, the subclass must call the parents class's setUp() and tearDown() methods. """ def process_config(self): # WSGI app configuration for key, value in self.config.items(): self.app.config[key] = value def define_worker_function(self): def worker(app, port): # Make Flask 1.0 stop printing its server banner with unittest.mock.patch("flask.cli.show_server_banner"): return app.run(port=port, use_reloader=False) return worker class ServerTestFixtureAsync(ServerTestFixtureBaseClass): """Base class for http client/server async testing (e.g aiohttp). Mix this in a test class in order to have access to an http server running in background. Note that the subclass should define an application in self.app that corresponds to the type of server the tested client needs. To ensure test isolation, each test will run in a different server and a different folder. In order to correctly work, the subclass must call the parents class's setUp() and tearDown() methods. """ def define_worker_function(self): def worker(app, port): return aiohttp.web.run_app(app, port=int(port), print=lambda *_: None) return worker diff --git a/swh/core/api/tests/test_gunicorn.py b/swh/core/api/tests/test_gunicorn.py index 92a3284..683a6e8 100644 --- a/swh/core/api/tests/test_gunicorn.py +++ b/swh/core/api/tests/test_gunicorn.py @@ -1,117 +1,121 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os from unittest.mock import patch import pkg_resources import swh.core.api.gunicorn_config as gunicorn_config def test_post_fork_default(): with patch("sentry_sdk.init") as sentry_sdk_init: gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_not_called() def test_post_fork_with_dsn_env(): flask_integration = object() # unique object to check for equality with patch( "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration ): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_called_once_with( dsn="test_dsn", integrations=[flask_integration], debug=False, release=None, environment=None, ) def test_post_fork_with_package_env(): flask_integration = object() # unique object to check for equality with patch( "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration ): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict( os.environ, { "SWH_SENTRY_DSN": "test_dsn", "SWH_SENTRY_ENVIRONMENT": "tests", "SWH_MAIN_PACKAGE": "swh.core", }, ): gunicorn_config.post_fork(None, None) version = pkg_resources.get_distribution("swh.core").version sentry_sdk_init.assert_called_once_with( dsn="test_dsn", integrations=[flask_integration], debug=False, release="swh.core@" + version, environment="tests", ) def test_post_fork_debug(): flask_integration = object() # unique object to check for equality with patch( "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration ): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict( os.environ, {"SWH_SENTRY_DSN": "test_dsn", "SWH_SENTRY_DEBUG": "1"} ): gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_called_once_with( dsn="test_dsn", integrations=[flask_integration], debug=True, release=None, environment=None, ) def test_post_fork_no_flask(): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork(None, None, flask=False) sentry_sdk_init.assert_called_once_with( - dsn="test_dsn", integrations=[], debug=False, release=None, environment=None, + dsn="test_dsn", + integrations=[], + debug=False, + release=None, + environment=None, ) def test_post_fork_extras(): flask_integration = object() # unique object to check for equality with patch( "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration ): with patch("sentry_sdk.init") as sentry_sdk_init: with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork( None, None, sentry_integrations=["foo"], extra_sentry_kwargs={"bar": "baz"}, ) sentry_sdk_init.assert_called_once_with( dsn="test_dsn", integrations=["foo", flask_integration], debug=False, bar="baz", release=None, environment=None, ) diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py index 747ee67..6376f87 100644 --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -1,177 +1,182 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re import pytest from requests.exceptions import ConnectionError from swh.core.api import APIError, RemoteException, RPCClient, remote_api_endpoint from swh.core.api.serializers import exception_to_dict, msgpack_dumps from .test_serializers import ExtraType, extra_decoders, extra_encoders class ReraiseException(Exception): pass @pytest.fixture def rpc_client(requests_mock): class TestStorage: @remote_api_endpoint("test_endpoint_url") def test_endpoint(self, test_data, db=None, cur=None): ... @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): ... @remote_api_endpoint("serializer_test") def serializer_test(self, data, db=None, cur=None): ... @remote_api_endpoint("overridden/endpoint") def overridden_method(self, data): return "foo" class Testclient(RPCClient): backend_class = TestStorage extra_type_encoders = extra_encoders extra_type_decoders = extra_decoders reraise_exceptions = [ReraiseException] def overridden_method(self, data): return "bar" def callback(request, context): assert request.headers["Content-Type"] == "application/x-msgpack" context.headers["Content-Type"] = "application/x-msgpack" if request.path == "/test_endpoint_url": context.content = b"\xa3egg" elif request.path == "/path/to/endpoint": context.content = b"\xa4spam" elif request.path == "/serializer_test": context.content = ( b"\x82\xc4\x07swhtype\xa9extratype" b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux" ) else: assert False return context.content requests_mock.post(re.compile("mock://example.com/"), content=callback) return Testclient(url="mock://example.com") def test_client(rpc_client): assert hasattr(rpc_client, "test_endpoint") assert hasattr(rpc_client, "something") res = rpc_client.test_endpoint("spam") assert res == "egg" res = rpc_client.test_endpoint(test_data="spam") assert res == "egg" res = rpc_client.something("whatever") assert res == "spam" res = rpc_client.something(data="whatever") assert res == "spam" def test_client_extra_serializers(rpc_client): res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")]) assert res == ExtraType({"spam": "egg"}, "qux") def test_client_overridden_method(rpc_client): res = rpc_client.overridden_method("foo") assert res == "bar" def test_client_connexion_error(rpc_client, requests_mock): """ ConnectionError should be wrapped and raised as an APIError. """ error_message = "unreachable host" requests_mock.post( re.compile("mock://example.com/connection_error"), exc=ConnectionError(error_message), ) with pytest.raises(APIError) as exc_info: rpc_client._post("connection_error", data={}) assert type(exc_info.value.args[0]) == ConnectionError assert str(exc_info.value.args[0]) == error_message def _exception_response(exception, status_code, old_exception_schema=False): def callback(request, context): assert request.headers["Content-Type"] == "application/x-msgpack" context.headers["Content-Type"] = "application/x-msgpack" exc_dict = exception_to_dict(exception) if old_exception_schema: exc_dict = {"exception": exc_dict} context.content = msgpack_dumps(exc_dict) context.status_code = status_code return context.content return callback @pytest.mark.parametrize("old_exception_schema", [False, True]) def test_client_reraise_exception(rpc_client, requests_mock, old_exception_schema): """ Exception caught server-side and whitelisted will be raised again client-side. """ error_message = "something went wrong" endpoint = "reraise_exception" requests_mock.post( re.compile(f"mock://example.com/{endpoint}"), content=_exception_response( exception=ReraiseException(error_message), status_code=400, old_exception_schema=old_exception_schema, ), ) with pytest.raises(ReraiseException) as exc_info: rpc_client._post(endpoint, data={}) assert str(exc_info.value) == error_message @pytest.mark.parametrize( "status_code, old_exception_schema", - [(400, False), (500, False), (400, True), (500, True),], + [ + (400, False), + (500, False), + (400, True), + (500, True), + ], ) def test_client_raise_remote_exception( rpc_client, requests_mock, status_code, old_exception_schema ): """ Exception caught server-side and not whitelisted will be wrapped and raised as a RemoteException client-side. """ error_message = "something went wrong" endpoint = "raise_remote_exception" requests_mock.post( re.compile(f"mock://example.com/{endpoint}"), content=_exception_response( exception=Exception(error_message), status_code=status_code, old_exception_schema=old_exception_schema, ), ) with pytest.raises(RemoteException) as exc_info: rpc_client._post(endpoint, data={}) assert str(exc_info.value.args[0]["type"]) == "Exception" assert str(exc_info.value.args[0]["message"]) == error_message diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py index 81bb573..dfe02af 100644 --- a/swh/core/api/tests/test_rpc_server.py +++ b/swh/core/api/tests/test_rpc_server.py @@ -1,147 +1,150 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json from flask import url_for import msgpack import pytest from swh.core.api import ( JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate, remote_api_endpoint, ) from .test_serializers import ExtraType, extra_decoders, extra_encoders class MyRPCServerApp(RPCServerApp): extra_type_encoders = extra_encoders extra_type_decoders = extra_decoders class TestStorage: @remote_api_endpoint("test_endpoint_url") def endpoint_test(self, test_data, db=None, cur=None): assert test_data == "spam" return "egg" @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data @remote_api_endpoint("serializer_test") def serializer_test(self, data, db=None, cur=None): assert data == ["foo", ExtraType("bar", b"baz")] return ExtraType({"spam": "egg"}, "qux") @pytest.fixture def app(): return MyRPCServerApp("testapp", backend_class=TestStorage) def test_api_rpc_server_app_ok(app): assert isinstance(app, MyRPCServerApp) actual_rpc_server2 = MyRPCServerApp( "app2", backend_class=TestStorage, backend_factory=TestStorage ) assert isinstance(actual_rpc_server2, MyRPCServerApp) actual_rpc_server3 = MyRPCServerApp("app3") assert isinstance(actual_rpc_server3, MyRPCServerApp) def test_api_rpc_server_app_misconfigured(): expected_error = "backend_factory should only be provided if backend_class is" with pytest.raises(ValueError, match=expected_error): MyRPCServerApp("failed-app", backend_factory="something-to-make-it-raise") def test_api_endpoint(flask_app_client): res = flask_app_client.post( url_for("something"), headers=[("Content-Type", "application/json"), ("Accept", "application/json")], data=json.dumps({"data": "toto"}), ) assert res.status_code == 200 assert res.mimetype == "application/json" def test_api_nego_default(flask_app_client): res = flask_app_client.post( url_for("something"), headers=[("Content-Type", "application/json")], data=json.dumps({"data": "toto"}), ) assert res.status_code == 200 assert res.mimetype == "application/json" assert res.data == b'"toto"' def test_api_nego_accept(flask_app_client): res = flask_app_client.post( url_for("something"), headers=[ ("Accept", "application/x-msgpack"), ("Content-Type", "application/x-msgpack"), ], data=msgpack.dumps({"data": "toto"}), ) assert res.status_code == 200 assert res.mimetype == "application/x-msgpack" assert res.data == b"\xa4toto" def test_rpc_server(flask_app_client): res = flask_app_client.post( url_for("endpoint_test"), headers=[ ("Content-Type", "application/x-msgpack"), ("Accept", "application/x-msgpack"), ], data=b"\x81\xa9test_data\xa4spam", ) assert res.status_code == 200 assert res.mimetype == "application/x-msgpack" assert res.data == b"\xa3egg" def test_rpc_server_extra_serializers(flask_app_client): res = flask_app_client.post( url_for("serializer_test"), headers=[ ("Content-Type", "application/x-msgpack"), ("Accept", "application/x-msgpack"), ], data=b"\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype" b"\xc4\x01d\x92\xa3bar\xc4\x03baz", ) assert res.status_code == 200 assert res.mimetype == "application/x-msgpack" assert res.data == ( b"\x82\xc4\x07swhtype\xa9extratype\xc4" b"\x01d\x92\x81\xa4spam\xa3egg\xa3qux" ) def test_api_negotiate_no_extra_encoders(app, flask_app_client): url = "/test/negotiate/no/extra/encoders" @app.route(url, methods=["POST"]) @negotiate(MsgpackFormatter) @negotiate(JSONFormatter) def endpoint(): return "test" - res = flask_app_client.post(url, headers=[("Content-Type", "application/json")],) + res = flask_app_client.post( + url, + headers=[("Content-Type", "application/json")], + ) assert res.status_code == 200 assert res.mimetype == "application/json" assert res.data == b'"test"' diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index a7e8285..d340d11 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,284 +1,308 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json from typing import Any, Callable, List, Tuple, Union from uuid import UUID import msgpack import pytest import requests from requests.exceptions import ConnectionError from swh.core.api.classes import PagedResult from swh.core.api.serializers import ( ENCODERS, decode_response, json_dumps, json_loads, msgpack_dumps, msgpack_loads, ) class ExtraType: def __init__(self, arg1, arg2): self.arg1 = arg1 self.arg2 = arg2 def __repr__(self): return f"ExtraType({self.arg1}, {self.arg2})" def __eq__(self, other): return isinstance(other, ExtraType) and (self.arg1, self.arg2) == ( other.arg1, other.arg2, ) extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ (ExtraType, "extratype", lambda o: (o.arg1, o.arg2)) ] extra_decoders = { "extratype": lambda o: ExtraType(*o), } TZ = datetime.timezone(datetime.timedelta(minutes=118)) DATA_BYTES = b"123456789\x99\xaf\xff\x00\x12" ENCODED_DATA_BYTES = {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"} -DATA_DATETIME = datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=TZ,) +DATA_DATETIME = datetime.datetime( + 2015, + 3, + 4, + 18, + 25, + 13, + 1234, + tzinfo=TZ, +) ENCODED_DATA_DATETIME = { "swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+01:58", } DATA_TIMEDELTA = datetime.timedelta(64) ENCODED_DATA_TIMEDELTA = { "swhtype": "timedelta", "d": {"days": 64, "seconds": 0, "microseconds": 0}, } DATA_UUID = UUID("cdd8f804-9db6-40c3-93ab-5955d3836234") ENCODED_DATA_UUID = {"swhtype": "uuid", "d": "cdd8f804-9db6-40c3-93ab-5955d3836234"} # For test demonstration purposes TestPagedResultStr = PagedResult[ Union[UUID, datetime.datetime, datetime.timedelta], str ] DATA_PAGED_RESULT = TestPagedResultStr( - results=[DATA_UUID, DATA_DATETIME, DATA_TIMEDELTA], next_page_token="10", + results=[DATA_UUID, DATA_DATETIME, DATA_TIMEDELTA], + next_page_token="10", ) ENCODED_DATA_PAGED_RESULT = { "d": { - "results": [ENCODED_DATA_UUID, ENCODED_DATA_DATETIME, ENCODED_DATA_TIMEDELTA,], + "results": [ + ENCODED_DATA_UUID, + ENCODED_DATA_DATETIME, + ENCODED_DATA_TIMEDELTA, + ], "next_page_token": "10", }, "swhtype": "paged_result", } TestPagedResultTuple = PagedResult[ Union[str, bytes, datetime.datetime], List[Union[str, UUID]] ] DATA_PAGED_RESULT2 = TestPagedResultTuple( - results=["data0", DATA_BYTES, DATA_DATETIME], next_page_token=["10", DATA_UUID], + results=["data0", DATA_BYTES, DATA_DATETIME], + next_page_token=["10", DATA_UUID], ) ENCODED_DATA_PAGED_RESULT2 = { "d": { - "results": ["data0", ENCODED_DATA_BYTES, ENCODED_DATA_DATETIME,], + "results": [ + "data0", + ENCODED_DATA_BYTES, + ENCODED_DATA_DATETIME, + ], "next_page_token": ["10", ENCODED_DATA_UUID], }, "swhtype": "paged_result", } DATA = { "bytes": DATA_BYTES, "datetime_tz": DATA_DATETIME, "datetime_utc": datetime.datetime( 2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc ), "datetime_delta": DATA_TIMEDELTA, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": DATA_UUID, "paged-result": DATA_PAGED_RESULT, "paged-result2": DATA_PAGED_RESULT2, } ENCODED_DATA = { "bytes": ENCODED_DATA_BYTES, "datetime_tz": ENCODED_DATA_DATETIME, - "datetime_utc": {"swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+00:00",}, + "datetime_utc": { + "swhtype": "datetime", + "d": "2015-03-04T18:25:13.001234+00:00", + }, "datetime_delta": ENCODED_DATA_TIMEDELTA, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": ENCODED_DATA_UUID, "paged-result": ENCODED_DATA_PAGED_RESULT, "paged-result2": ENCODED_DATA_PAGED_RESULT2, } class ComplexExceptionType(Exception): def __init__(self, error_type, message): self.error_type = error_type super().__init__(f"{error_type}: {message}") def test_serializers_round_trip_json(): json_data = json_dumps(DATA) actual_data = json_loads(json_data) assert actual_data == DATA def test_serializers_round_trip_json_extra_types(): expected_original_data = [ExtraType("baz", DATA), "qux"] data = json_dumps(expected_original_data, extra_encoders=extra_encoders) actual_data = json_loads(data, extra_decoders=extra_decoders) assert actual_data == expected_original_data def test_exception_serializer_round_trip_json(): error_message = "unreachable host" - json_data = json_dumps({"exception": ConnectionError(error_message)},) + json_data = json_dumps( + {"exception": ConnectionError(error_message)}, + ) actual_data = json_loads(json_data) assert "exception" in actual_data assert type(actual_data["exception"]) == ConnectionError assert str(actual_data["exception"]) == error_message def test_complex_exception_serializer_round_trip_json(): exception = ComplexExceptionType("NotFound", "the object is missing") json_data = json_dumps({"exception": exception}) actual_data = json_loads(json_data) assert "exception" in actual_data assert type(actual_data["exception"]) == Exception assert str(actual_data["exception"]) == str(exception) def test_serializers_encode_swh_json(): json_str = json_dumps(DATA) actual_data = json.loads(json_str) assert actual_data == ENCODED_DATA def test_serializers_round_trip_msgpack(): expected_original_data = { **DATA, "none_dict_key": {None: 42}, "long_int_is_loooong": 10000000000000000000000000000000, "long_negative_int_is_loooong": -10000000000000000000000000000000, } data = msgpack_dumps(expected_original_data) actual_data = msgpack_loads(data) assert actual_data == expected_original_data def test_serializers_round_trip_msgpack_extra_types(): original_data = [ExtraType("baz", DATA), "qux"] data = msgpack_dumps(original_data, extra_encoders=extra_encoders) actual_data = msgpack_loads(data, extra_decoders=extra_decoders) assert actual_data == original_data def test_exception_serializer_round_trip_msgpack(): error_message = "unreachable host" data = msgpack_dumps({"exception": ConnectionError(error_message)}) actual_data = msgpack_loads(data) assert "exception" in actual_data assert type(actual_data["exception"]) == ConnectionError assert str(actual_data["exception"]) == error_message def test_complex_exception_serializer_round_trip_msgpack(): exception = ComplexExceptionType("NotFound", "the object is missing") data = msgpack_dumps({"exception": exception}) actual_data = msgpack_loads(data) assert "exception" in actual_data assert type(actual_data["exception"]) == Exception assert str(actual_data["exception"]) == str(exception) def test_serializers_generator_json(): data = json_dumps((i for i in range(5))) assert json_loads(data) == [i for i in range(5)] def test_serializers_generator_msgpack(): data = msgpack_dumps((i for i in range(5))) assert msgpack_loads(data) == [i for i in range(5)] def test_serializers_decode_response_json(requests_mock): requests_mock.get( "https://example.org/test/data", json=ENCODED_DATA, headers={"content-type": "application/json"}, ) response = requests.get("https://example.org/test/data") assert decode_response(response) == DATA def test_serializers_encode_datetime_msgpack(): dt = datetime.datetime.now(tz=datetime.timezone.utc) encmsg = msgpack_dumps(dt) decmsg = msgpack.loads(encmsg, timestamp=0) assert isinstance(decmsg, msgpack.Timestamp) assert decmsg.to_datetime() == dt def test_serializers_decode_datetime_compat_msgpack(): dt = datetime.datetime.now(tz=datetime.timezone.utc) encmsg = msgpack_dumps({b"swhtype": "datetime", b"d": dt.isoformat()}) decmsg = msgpack_loads(encmsg) assert decmsg == dt def test_serializers_encode_native_datetime_msgpack(): dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) with pytest.raises((TypeError, ValueError), match="datetime"): msgpack_dumps(dt) def test_serializers_encode_native_datetime_json(): dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) with pytest.raises(TypeError, match="datetime"): json_dumps(dt) def test_serializers_decode_naive_datetime(): expected_dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) # Current encoding assert ( msgpack_loads( b"\x82\xc4\x07swhtype\xa8datetime\xc4\x01d\xba" b"2015-01-01T12:04:42.231455" ) == expected_dt ) def test_msgpack_extra_encoders_mutation(): data = msgpack_dumps({}, extra_encoders=extra_encoders) assert data is not None assert ENCODERS[-1][0] != ExtraType def test_json_extra_encoders_mutation(): data = json_dumps({}, extra_encoders=extra_encoders) assert data is not None assert ENCODERS[-1][0] != ExtraType diff --git a/swh/core/cli/__init__.py b/swh/core/cli/__init__.py index 882bd10..c064f91 100644 --- a/swh/core/cli/__init__.py +++ b/swh/core/cli/__init__.py @@ -1,189 +1,188 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import logging.config from typing import Optional import warnings import click import pkg_resources LOG_LEVEL_NAMES = ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) logger = logging.getLogger(__name__) class AliasedGroup(click.Group): """A simple Group that supports command aliases, as well as notes related to options""" def __init__(self, name=None, commands=None, **attrs): self.option_notes = attrs.pop("option_notes", None) self.aliases = {} super().__init__(name, commands, **attrs) def get_command(self, ctx, cmd_name): return super().get_command(ctx, self.aliases.get(cmd_name, cmd_name)) def add_alias(self, name, alias): if not isinstance(name, str): name = name.name self.aliases[alias] = name def format_options(self, ctx, formatter): click.Command.format_options(self, ctx, formatter) if self.option_notes: with formatter.section("Notes"): formatter.write_text(self.option_notes) self.format_commands(ctx, formatter) def clean_exit_on_signal(signal, frame): """Raise a SystemExit exception to let command-line clients wind themselves down on exit""" raise SystemExit(0) def validate_loglevel_params(ctx, param, value): """Validate the --log-level parameters, with multiple values""" if value is None: return None return [validate_loglevel(ctx, param, v) for v in value] def validate_loglevel(ctx, param, value): """Validate a single loglevel specification, of the form LOGLEVEL or module:LOGLEVEL.""" if ":" in value: try: module, log_level = value.split(":") except ValueError: raise click.BadParameter( "Invalid log level specification `%s`, " "needs to be in format `module:LOGLEVEL`" % value ) else: module = None log_level = value if log_level not in LOG_LEVEL_NAMES: raise click.BadParameter( "Log level %s unknown (in `%s`) needs to be one of %s" % (log_level, value, ", ".join(LOG_LEVEL_NAMES)) ) return (module, log_level) @click.group( context_settings=CONTEXT_SETTINGS, cls=AliasedGroup, option_notes="""\ If both options are present, --log-level values will override the configuration in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. """, ) @click.option( "--log-level", "-l", "log_levels", default=None, callback=validate_loglevel_params, multiple=True, help=( "Log level (defaults to INFO). " "Can override the log level for a specific module, by using the " "``specific.module:LOGLEVEL`` syntax (e.g. ``--log-level swh.core:DEBUG`` " "will enable DEBUG logging for swh.core)." ), ) @click.option( "--log-config", default=None, type=click.File("r"), help="Python yaml logging configuration file.", ) @click.option( "--sentry-dsn", default=None, help="DSN of the Sentry instance to report to" ) @click.option( "--sentry-debug/--no-sentry-debug", default=False, hidden=True, help="Enable debugging of sentry", ) @click.pass_context def swh(ctx, log_levels, log_config, sentry_dsn, sentry_debug): - """Command line interface for Software Heritage. - """ + """Command line interface for Software Heritage.""" import signal import yaml from ..sentry import init_sentry signal.signal(signal.SIGTERM, clean_exit_on_signal) signal.signal(signal.SIGINT, clean_exit_on_signal) init_sentry(sentry_dsn, debug=sentry_debug) set_default_loglevel: Optional[str] = None if log_config: logging.config.dictConfig(yaml.safe_load(log_config.read())) set_default_loglevel = logging.root.getEffectiveLevel() if not log_levels: log_levels = [] for module, log_level in log_levels: logger = logging.getLogger(module) log_level = logging.getLevelName(log_level) logger.setLevel(log_level) if module is None: set_default_loglevel = log_level if not set_default_loglevel: logging.root.setLevel("INFO") set_default_loglevel = "INFO" ctx.ensure_object(dict) ctx.obj["log_level"] = set_default_loglevel def main(): # Even though swh() sets up logging, we need an earlier basic logging setup # for the next few logging statements logging.basicConfig() # load plugins that define cli sub commands for entry_point in pkg_resources.iter_entry_points("swh.cli.subcommands"): try: cmd = entry_point.load() if isinstance(cmd, click.BaseCommand): # for BW compat, auto add click commands warnings.warn( f"{entry_point.name}: automagic addition of click commands " f"to the main swh group is deprecated", DeprecationWarning, ) swh.add_command(cmd, name=entry_point.name) # otherwise it's expected to be a module which has been loaded # it's the responsibility of the click commands/groups in this # module to transitively have the main swh group as parent. except Exception as e: logger.warning("Could not load subcommand %s: %r", entry_point.name, e) return swh(auto_envvar_prefix="SWH") if __name__ == "__main__": main() diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py index f23580e..51aea86 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -1,420 +1,423 @@ #!/usr/bin/env python3 # Copyright (C) 2018-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging from os import environ import warnings import click from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group warnings.filterwarnings("ignore") # noqa prevent psycopg from telling us sh*t logger = logging.getLogger(__name__) @swh_cli_group.group(name="db", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False), help="Configuration file.", ) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools.""" from swh.core.config import read as config_read ctx.ensure_object(dict) if config_file is None: config_file = environ.get("SWH_CONFIG_FILENAME") cfg = config_read(config_file) ctx.obj["config"] = cfg @db.command(name="create", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--dbname", "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) @click.option( "--template", "-T", help="Template database from which to build this database.", default="template1", show_default=True, ) def db_create(module, dbname, template): """Create a database for the Software Heritage . and potentially execute superuser-level initialization steps. Example:: swh db create -d swh-test storage If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables or by the mean of a properly crafted libpq connection URI. See psql(1) man page (section ENVIRONMENTS) for details. Note: this command requires a postgresql connection with superuser permissions. Example:: PGPORT=5434 swh db create indexer swh db create -d postgresql://superuser:passwd@pghost:5433/swh-storage storage """ from swh.core.db.db_utils import create_database_for_package logger.debug("db_create %s dn_name=%s", module, dbname) create_database_for_package(module, dbname, template) @db.command(name="init-admin", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--dbname", "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) def db_init_admin(module: str, dbname: str) -> None: """Execute superuser-level initialization steps (e.g pg extensions, admin functions, ...) Example:: PGPASSWORD=... swh db init-admin -d swh-test scheduler If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables or by the mean of a properly crafted libpq connection URI. See psql(1) man page (section ENVIRONMENTS) for details. Note: this command requires a postgresql connection with superuser permissions (e.g postgres, swh-admin, ...) Example:: PGPORT=5434 swh db init-admin scheduler swh db init-admin -d postgresql://superuser:passwd@pghost:5433/swh-scheduler \ scheduler """ from swh.core.db.db_utils import init_admin_extensions logger.debug("db_init_admin %s dbname=%s", module, dbname) init_admin_extensions(module, dbname) @db.command(name="init", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--dbname", "--db-name", "-d", help="Database name or connection URI.", default=None, show_default=False, ) @click.option( - "--flavor", help="Database flavor.", default=None, + "--flavor", + help="Database flavor.", + default=None, ) @click.option( "--initial-version", help="Database initial version.", default=1, show_default=True ) @click.pass_context def db_init(ctx, module, dbname, flavor, initial_version): """Initialize a database for the Software Heritage . The database connection string comes from the configuration file (see option ``--config-file`` in ``swh db --help``) in the section named after the MODULE argument. Example:: $ cat conf.yml storage: cls: postgresql db: postgresql://user:passwd@pghost:5433/swh-storage objstorage: cls: memory $ swh db -C conf.yml init storage # or $ SWH_CONFIG_FILENAME=conf.yml swh db init storage Note that the connection string can also be passed directly using the '--db-name' option, but this usage is about to be deprecated. """ from swh.core.db.db_utils import ( get_database_info, import_swhmodule, populate_database_for_package, swh_set_db_version, ) cfg = None if dbname is None: # use the db cnx from the config file; the expected config entry is the # given module name cfg = ctx.obj["config"].get(module, {}) dbname = get_dburl_from_config(cfg) if not dbname: raise click.BadParameter( "Missing the postgresql connection configuration. Either fix your " "configuration file or use the --dbname option." ) logger.debug("db_init %s flavor=%s dbname=%s", module, flavor, dbname) initialized, dbversion, dbflavor = populate_database_for_package( module, dbname, flavor ) if dbversion is None: if cfg is not None: # db version has not been populated by sql init scripts (new style), # let's do it; instantiate the data source to retrieve the current # (expected) db version datastore_factory = getattr(import_swhmodule(module), "get_datastore", None) if datastore_factory: datastore = datastore_factory(**cfg) try: get_current_version = datastore.get_current_version except AttributeError: logger.warning( "Datastore %s does not implement the " "'get_current_version()' method", datastore, ) else: code_version = get_current_version() logger.info( "Initializing database version to %s from the %s datastore", code_version, module, ) swh_set_db_version(dbname, code_version, desc="DB initialization") dbversion = get_database_info(dbname)[1] if dbversion is None: logger.info( "Initializing database version to %s " "from the command line option --initial-version", initial_version, ) swh_set_db_version(dbname, initial_version, desc="DB initialization") dbversion = get_database_info(dbname)[1] assert dbversion is not None # TODO: Ideally migrate the version from db_version to the latest # db version click.secho( "DONE database for {} {}{} at version {}".format( module, "initialized" if initialized else "exists", f" (flavor {dbflavor})" if dbflavor is not None else "", dbversion, ), fg="green", bold=True, ) if flavor is not None and dbflavor != flavor: click.secho( f"WARNING requested flavor '{flavor}' != recorded flavor '{dbflavor}'", fg="red", bold=True, ) @db.command(name="version", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--all/--no-all", "show_all", help="Show version history.", default=False, show_default=True, ) @click.pass_context def db_version(ctx, module, show_all): """Print the database version for the Software Heritage. Example:: swh db version -d swh-test """ from swh.core.db.db_utils import get_database_info, import_swhmodule # use the db cnx from the config file; the expected config entry is the # given module name cfg = ctx.obj["config"].get(module, {}) dbname = get_dburl_from_config(cfg) if not dbname: raise click.BadParameter( "Missing the postgresql connection configuration. Either fix your " "configuration file or use the --dbname option." ) logger.debug("db_version dbname=%s", dbname) db_module, db_version, db_flavor = get_database_info(dbname) if db_module is None: click.secho( "WARNING the database does not have a dbmodule table.", fg="red", bold=True ) db_module = module assert db_module == module, f"{db_module} (in the db) != {module} (given)" click.secho(f"module: {db_module}", fg="green", bold=True) if db_flavor is not None: click.secho(f"flavor: {db_flavor}", fg="green", bold=True) # instantiate the data source to retrieve the current (expected) db version datastore_factory = getattr(import_swhmodule(db_module), "get_datastore", None) if datastore_factory: datastore = datastore_factory(**cfg) code_version = datastore.get_current_version() click.secho( f"current code version: {code_version}", fg="green" if code_version == db_version else "red", bold=True, ) if not show_all: click.secho(f"version: {db_version}", fg="green", bold=True) else: from swh.core.db.db_utils import swh_db_versions versions = swh_db_versions(dbname) for version, tstamp, desc in versions: click.echo(f"{version} [{tstamp}] {desc}") @db.command(name="upgrade", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--to-version", type=int, help="Upgrade up to version VERSION", metavar="VERSION", default=None, ) @click.option( "--interactive/--non-interactive", help="Do not ask questions (use default answer to all questions)", default=True, ) @click.pass_context def db_upgrade(ctx, module, to_version, interactive): """Upgrade the database for given module (to a given version if specified). Examples:: swh db upgrade storage swg db upgrade scheduler --to-version=10 """ from swh.core.db.db_utils import ( get_database_info, import_swhmodule, swh_db_upgrade, swh_set_db_module, ) # use the db cnx from the config file; the expected config entry is the # given module name cfg = ctx.obj["config"].get(module, {}) dbname = get_dburl_from_config(cfg) if not dbname: raise click.BadParameter( "Missing the postgresql connection configuration. Either fix your " "configuration file or use the --dbname option." ) logger.debug("db_version dbname=%s", dbname) db_module, db_version, db_flavor = get_database_info(dbname) if db_module is None: click.secho( "Warning: the database does not have a dbmodule table.", fg="yellow", bold=True, ) if interactive and not click.confirm( f"Write the module information ({module}) in the database?", default=True ): raise click.BadParameter("Migration aborted.") swh_set_db_module(dbname, module) db_module = module if db_module != module: raise click.BadParameter( f"Error: the given module ({module}) does not match the value " f"stored in the database ({db_module})." ) # instantiate the data source to retrieve the current (expected) db version datastore_factory = getattr(import_swhmodule(db_module), "get_datastore", None) if not datastore_factory: raise click.UsageError( "You cannot use this command on old-style datastore backend {db_module}" ) datastore = datastore_factory(**cfg) ds_version = datastore.get_current_version() if to_version is None: to_version = ds_version if to_version > ds_version: raise click.UsageError( f"The target version {to_version} is larger than the current version " f"{ds_version} of the datastore backend {db_module}" ) if to_version == db_version: click.secho( - f"No migration needed: the current version is {db_version}", fg="yellow", + f"No migration needed: the current version is {db_version}", + fg="yellow", ) else: new_db_version = swh_db_upgrade(dbname, module, to_version) click.secho(f"Migration to version {new_db_version} done", fg="green") if new_db_version < ds_version: click.secho( "Warning: migration was not complete: " f"the current version is {ds_version}", fg="yellow", ) def get_dburl_from_config(cfg): if cfg.get("cls") != "postgresql": raise click.BadParameter( "Configuration cls must be set to 'postgresql' for this command." ) if "args" in cfg: # for bw compat cfg = cfg["args"] return cfg.get("db") diff --git a/swh/core/config.py b/swh/core/config.py index 119eadc..a537471 100644 --- a/swh/core/config.py +++ b/swh/core/config.py @@ -1,308 +1,308 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from copy import deepcopy from itertools import chain import logging import os from typing import Any, Callable, Dict, List, Optional, Tuple import yaml logger = logging.getLogger(__name__) SWH_CONFIG_DIRECTORIES = [ "~/.config/swh", "~/.swh", "/etc/softwareheritage", ] SWH_GLOBAL_CONFIG = "global.yml" SWH_DEFAULT_GLOBAL_CONFIG = { "max_content_size": ("int", 100 * 1024 * 1024), } SWH_CONFIG_EXTENSIONS = [ ".yml", ] # conversion per type _map_convert_fn: Dict[str, Callable] = { "int": int, "bool": lambda x: x.lower() == "true", "list[str]": lambda x: [value.strip() for value in x.split(",")], "list[int]": lambda x: [int(value.strip()) for value in x.split(",")], } _map_check_fn: Dict[str, Callable] = { "int": lambda x: isinstance(x, int), "bool": lambda x: isinstance(x, bool), "list[str]": lambda x: (isinstance(x, list) and all(isinstance(y, str) for y in x)), "list[int]": lambda x: (isinstance(x, list) and all(isinstance(y, int) for y in x)), } def exists_accessible(filepath: str) -> bool: """Check whether a file exists, and is accessible. Returns: True if the file exists and is accessible False if the file does not exist Raises: PermissionError if the file cannot be read. """ try: os.stat(filepath) except PermissionError: raise except FileNotFoundError: return False else: if os.access(filepath, os.R_OK): return True else: raise PermissionError("Permission denied: {filepath!r}") def config_basepath(config_path: str) -> str: """Return the base path of a configuration file""" if config_path.endswith(".yml"): return config_path[:-4] return config_path def read_raw_config(base_config_path: str) -> Dict[str, Any]: """Read the raw config corresponding to base_config_path. Can read yml files. """ yml_file = f"{base_config_path}.yml" if exists_accessible(yml_file): logger.debug("Loading config file %s", yml_file) with open(yml_file) as f: return yaml.safe_load(f) return {} def config_exists(config_path): """Check whether the given config exists""" basepath = config_basepath(config_path) return any( exists_accessible(basepath + extension) for extension in SWH_CONFIG_EXTENSIONS ) def read( conf_file: Optional[str] = None, default_conf: Optional[Dict[str, Tuple[str, Any]]] = None, ) -> Dict[str, Any]: """Read the user's configuration file. Fill in the gap using `default_conf`. `default_conf` is similar to this:: DEFAULT_CONF = { 'a': ('str', '/tmp/swh-loader-git/log'), 'b': ('str', 'dbname=swhloadergit') 'c': ('bool', true) 'e': ('bool', None) 'd': ('int', 10) } If conf_file is None, return the default config. """ conf: Dict[str, Any] = {} if conf_file: base_config_path = config_basepath(os.path.expanduser(conf_file)) conf = read_raw_config(base_config_path) or {} if not default_conf: return conf # remaining missing default configuration key are set # also type conversion is enforced for underneath layer for key, (nature_type, default_value) in default_conf.items(): val = conf.get(key, None) if val is None: # fallback to default value conf[key] = default_value elif not _map_check_fn.get(nature_type, lambda x: True)(val): # value present but not in the proper format, force type conversion conf[key] = _map_convert_fn.get(nature_type, lambda x: x)(val) return conf def priority_read( conf_filenames: List[str], default_conf: Optional[Dict[str, Tuple[str, Any]]] = None ): """Try reading the configuration files from conf_filenames, in order, - and return the configuration from the first one that exists. + and return the configuration from the first one that exists. - default_conf has the same specification as it does in read. + default_conf has the same specification as it does in read. """ # Try all the files in order for filename in conf_filenames: full_filename = os.path.expanduser(filename) if config_exists(full_filename): return read(full_filename, default_conf) # Else, return the default configuration return read(None, default_conf) def merge_default_configs(base_config, *other_configs): """Merge several default config dictionaries, from left to right""" full_config = base_config.copy() for config in other_configs: full_config.update(config) return full_config def merge_configs(base: Optional[Dict[str, Any]], other: Optional[Dict[str, Any]]): """Merge two config dictionaries This does merge config dicts recursively, with the rules, for every value of the dicts (with 'val' not being a dict): - None + type -> type - type + None -> None - dict + dict -> dict (merged) - val + dict -> TypeError - dict + val -> TypeError - val + val -> val (other) for instance: >>> d1 = { ... 'key1': { ... 'skey1': 'value1', ... 'skey2': {'sskey1': 'value2'}, ... }, ... 'key2': 'value3', ... } with >>> d2 = { ... 'key1': { ... 'skey1': 'value4', ... 'skey2': {'sskey2': 'value5'}, ... }, ... 'key3': 'value6', ... } will give: >>> d3 = { ... 'key1': { ... 'skey1': 'value4', # <-- note this ... 'skey2': { ... 'sskey1': 'value2', ... 'sskey2': 'value5', ... }, ... }, ... 'key2': 'value3', ... 'key3': 'value6', ... } >>> assert merge_configs(d1, d2) == d3 Note that no type checking is done for anything but dicts. """ if not isinstance(base, dict) or not isinstance(other, dict): raise TypeError("Cannot merge a %s with a %s" % (type(base), type(other))) output = {} allkeys = set(chain(base.keys(), other.keys())) for k in allkeys: vb = base.get(k) vo = other.get(k) if isinstance(vo, dict): output[k] = merge_configs(vb is not None and vb or {}, vo) elif isinstance(vb, dict) and k in other and other[k] is not None: output[k] = merge_configs(vb, vo is not None and vo or {}) elif k in other: output[k] = deepcopy(vo) else: output[k] = deepcopy(vb) return output def swh_config_paths(base_filename: str) -> List[str]: """Return the Software Heritage specific configuration paths for the given - filename.""" + filename.""" return [os.path.join(dirname, base_filename) for dirname in SWH_CONFIG_DIRECTORIES] def prepare_folders(conf, *keys): - """Prepare the folder mentioned in config under keys. - """ + """Prepare the folder mentioned in config under keys.""" def makedir(folder): if not os.path.exists(folder): os.makedirs(folder) for key in keys: makedir(conf[key]) def load_global_config(): """Load the global Software Heritage config""" return priority_read( - swh_config_paths(SWH_GLOBAL_CONFIG), SWH_DEFAULT_GLOBAL_CONFIG, + swh_config_paths(SWH_GLOBAL_CONFIG), + SWH_DEFAULT_GLOBAL_CONFIG, ) def load_named_config(name, default_conf=None, global_conf=True): """Load the config named `name` from the Software Heritage - configuration paths. + configuration paths. - If global_conf is True (default), read the global configuration - too. + If global_conf is True (default), read the global configuration + too. """ conf = {} if global_conf: conf.update(load_global_config()) conf.update(priority_read(swh_config_paths(name), default_conf)) return conf def load_from_envvar(default_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """Load configuration yaml file from the environment variable SWH_CONFIG_FILENAME, eventually enriched with default configuration key/value from the default_config dict if provided. Returns: Configuration dict Raises: AssertionError if SWH_CONFIG_FILENAME is undefined """ assert ( "SWH_CONFIG_FILENAME" in os.environ ), "SWH_CONFIG_FILENAME environment variable is undefined." cfg_path = os.environ["SWH_CONFIG_FILENAME"] cfg = read_raw_config(config_basepath(cfg_path)) cfg = merge_configs(default_config or {}, cfg) return cfg diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index 94568ba..916e67b 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,675 +1,676 @@ # Copyright (C) 2015-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import functools from importlib import import_module import logging from os import path import pathlib import re import subprocess from typing import Collection, Dict, List, Optional, Tuple, Union, cast import psycopg2 import psycopg2.errors import psycopg2.extensions from psycopg2.extensions import connection as pgconnection from psycopg2.extensions import encodings as pgencodings from psycopg2.extensions import make_dsn from psycopg2.extensions import parse_dsn as _parse_dsn from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) def now(): return datetime.now(tz=timezone.utc) def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value def connect_to_conninfo(db_or_conninfo: Union[str, pgconnection]) -> pgconnection: """Connect to the database passed in argument Args: db_or_conninfo: A database connection, or a database connection info string Returns: a connected database handle Raises: psycopg2.Error if the database doesn't exist """ if isinstance(db_or_conninfo, pgconnection): return db_or_conninfo if "=" not in db_or_conninfo and "//" not in db_or_conninfo: # Database name db_or_conninfo = f"dbname={db_or_conninfo}" db = psycopg2.connect(db_or_conninfo) return db def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]: """Retrieve the swh version of the database. If the database is not initialized, this logs a warning and returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version of the database, or None if it couldn't be detected """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None try: with db.cursor() as c: query = "select version from dbversion order by dbversion desc limit 1" try: c.execute(query) result = c.fetchone() if result: return result[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get version from `%s`", db_or_conninfo) return None def swh_db_versions( db_or_conninfo: Union[str, pgconnection] ) -> Optional[List[Tuple[int, datetime, str]]]: """Retrieve the swh version history of the database. If the database is not initialized, this logs a warning and returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version of the database, or None if it couldn't be detected """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None try: with db.cursor() as c: query = ( "select version, release, description " "from dbversion order by dbversion desc" ) try: c.execute(query) return cast(List[Tuple[int, datetime, str]], c.fetchall()) except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get versions from `%s`", db_or_conninfo) return None def swh_db_upgrade( conninfo: str, modname: str, to_version: Optional[int] = None ) -> int: """Upgrade the database at `conninfo` for module `modname` This will run migration scripts found in the `sql/upgrades` subdirectory of the module `modname`. By default, this will upgrade to the latest declared version. Args: conninfo: A database connection, or a database connection info string modname: datastore module the database stores content for to_version: if given, update the database to this version rather than the latest """ if to_version is None: to_version = 99999999 db_module, db_version, db_flavor = get_database_info(conninfo) if db_version is None: raise ValueError("Unable to retrieve the current version of the database") if db_module is None: raise ValueError("Unable to retrieve the module of the database") if db_module != modname: raise ValueError( "The stored module of the database is different than the given one" ) sqlfiles = [ fname for fname in get_sql_for_package(modname, upgrade=True) if db_version < int(fname.stem) <= to_version ] if not sqlfiles: return db_version for sqlfile in sqlfiles: new_version = int(path.splitext(path.basename(sqlfile))[0]) logger.info("Executing migration script '%s'", sqlfile) if db_version is not None and (new_version - db_version) > 1: logger.error( f"There are missing migration steps between {db_version} and " f"{new_version}. It might be expected but it most unlikely is not. " "Will stop here." ) return db_version execute_sqlfiles([sqlfile], conninfo, db_flavor) # check if the db version has been updated by the upgrade script db_version = swh_db_version(conninfo) assert db_version is not None if db_version == new_version: # nothing to do, upgrade script did the job pass elif db_version == new_version - 1: # it has not (new style), so do it swh_set_db_version( conninfo, new_version, desc=f"Upgraded to version {new_version} using {sqlfile}", ) db_version = swh_db_version(conninfo) else: # upgrade script did it wrong logger.error( f"The upgrade script {sqlfile} did not update the dbversion table " f"consistently ({db_version} vs. expected {new_version}). " "Will stop migration here. Please check your migration scripts." ) return db_version return new_version def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh module used to create the database. If the database is not initialized, this logs a warning and returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the module of the database, or None if it couldn't be detected """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None try: with db.cursor() as c: query = "select dbmodule from dbmodule limit 1" try: c.execute(query) resp = c.fetchone() if resp: return resp[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get module from `%s`", db_or_conninfo) return None def swh_set_db_module( db_or_conninfo: Union[str, pgconnection], module: str, force=False ) -> None: """Set the swh module used to create the database. Fails if the dbmodule is already set or the table does not exist. Args: db_or_conninfo: A database connection, or a database connection info string module: the swh module to register (without the leading 'swh.') """ update = False if module.startswith("swh."): module = module[4:] current_module = swh_db_module(db_or_conninfo) if current_module is not None: if current_module == module: logger.warning("The database module is already set to %s", module) return if not force: raise ValueError( "The database module is already set to a value %s " "different than given %s", current_module, module, ) # force is True update = True try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None sqlfiles = [ fname for fname in get_sql_for_package("swh.core.db") if "dbmodule" in fname.stem ] execute_sqlfiles(sqlfiles, db_or_conninfo) with db.cursor() as c: if update: query = "update dbmodule set dbmodule = %s" else: query = "insert into dbmodule(dbmodule) values (%s)" c.execute(query, (module,)) db.commit() def swh_set_db_version( db_or_conninfo: Union[str, pgconnection], version: int, ts: Optional[datetime] = None, desc: str = "Work in progress", ) -> None: """Set the version of the database. Fails if the dbversion table does not exists. Args: db_or_conninfo: A database connection, or a database connection info string version: the version to add """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None if ts is None: ts = now() with db.cursor() as c: query = ( "insert into dbversion(version, release, description) values (%s, %s, %s)" ) c.execute(query, (version, ts, desc)) db.commit() def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh flavor of the database. If the database is not initialized, or the database doesn't support flavors, this returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: The flavor of the database, or None if it could not be detected. """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None try: with db.cursor() as c: query = "select swh_get_dbflavor()" try: c.execute(query) return c.fetchone()[0] except psycopg2.errors.UndefinedFunction: # function not found: no flavor return None except Exception: logger.exception("Could not get flavor from `%s`", db_or_conninfo) return None # The following code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] - tokens = re.split(br"(%.)", sql) + tokens = re.split(rb"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode(pgencodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur def import_swhmodule(modname): if not modname.startswith("swh."): modname = f"swh.{modname}" try: m = import_module(modname) except ImportError as exc: logger.error(f"Could not load the {modname} module: {exc}") return None return m def get_sql_for_package(modname: str, upgrade: bool = False) -> List[pathlib.Path]: """Return the (sorted) list of sql script files for the given swh module If upgrade is True, return the list of available migration scripts, otherwise, return the list of initialization scripts. """ m = import_swhmodule(modname) if m is None: raise ValueError(f"Module {modname} cannot be loaded") sqldir = pathlib.Path(m.__file__).parent / "sql" if upgrade: sqldir /= "upgrades" if not sqldir.is_dir(): raise ValueError( "Module {} does not provide a db schema (no sql/ dir)".format(modname) ) return sorted(sqldir.glob("*.sql"), key=lambda x: sortkey(x.name)) def populate_database_for_package( modname: str, conninfo: str, flavor: Optional[str] = None ) -> Tuple[bool, Optional[int], Optional[str]]: """Populate the database, pointed at with ``conninfo``, using the SQL files found in the package ``modname``. Also fill the 'dbmodule' table with the given ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string for the SQL database flavor: the module-specific flavor which we want to initialize the database under Returns: Tuple with three elements: whether the database has been initialized; the current version of the database; if it exists, the flavor of the database. """ current_version = swh_db_version(conninfo) if current_version is not None: dbflavor = swh_db_flavor(conninfo) return False, current_version, dbflavor def globalsortkey(key): "like sortkey but only on basenames" return sortkey(path.basename(key)) sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db") sqlfiles = sorted(sqlfiles, key=lambda x: sortkey(x.stem)) sqlfiles = [fpath for fpath in sqlfiles if "-superuser-" not in fpath.stem] execute_sqlfiles(sqlfiles, conninfo, flavor) # populate the dbmodule table swh_set_db_module(conninfo, modname) current_db_version = swh_db_version(conninfo) dbflavor = swh_db_flavor(conninfo) return True, current_db_version, dbflavor def get_database_info( conninfo: str, ) -> Tuple[Optional[str], Optional[int], Optional[str]]: """Get version, flavor and module of the db""" dbmodule = swh_db_module(conninfo) dbversion = swh_db_version(conninfo) dbflavor = None if dbversion is not None: dbflavor = swh_db_flavor(conninfo) return (dbmodule, dbversion, dbflavor) def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]: """Parse a psycopg2 dsn, falling back to supporting plain database names as well""" try: return _parse_dsn(dsn_or_dbname) except psycopg2.ProgrammingError: # psycopg2 failed to parse the DSN; it's probably a database name, # handle it as such return _parse_dsn(f"dbname={dsn_or_dbname}") def init_admin_extensions(modname: str, conninfo: str) -> None: """The remaining initialization process -- running -superuser- SQL files -- is done using the given conninfo, thus connecting to the newly created database """ sqlfiles = get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname.stem] execute_sqlfiles(sqlfiles, conninfo) def create_database_for_package( modname: str, conninfo: str, template: str = "template1" ): """Create the database pointed at with ``conninfo``, and initialize it using -superuser- SQL files found in the package ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string or plain database name for the SQL database template: the name of the database to connect to and use as template to create the new database """ # Use the given conninfo string, but with dbname replaced by the template dbname # for the database creation step creation_dsn = parse_dsn_or_dbname(conninfo) dbname = creation_dsn["dbname"] creation_dsn["dbname"] = template logger.debug("db_create dbname=%s (from %s)", dbname, template) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", make_dsn(**creation_dsn), "-c", f'CREATE DATABASE "{dbname}"', ] ) init_admin_extensions(modname, conninfo) def execute_sqlfiles( sqlfiles: Collection[pathlib.Path], db_or_conninfo: Union[str, pgconnection], flavor: Optional[str] = None, ): """Execute a list of SQL files on the database pointed at with ``db_or_conninfo``. Args: sqlfiles: List of SQL files to execute db_or_conninfo: A database connection, or a database connection info string flavor: the database flavor to initialize """ if isinstance(db_or_conninfo, str): conninfo = db_or_conninfo else: conninfo = db_or_conninfo.dsn psql_command = [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, ] flavor_set = False for sqlfile in sqlfiles: logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}") subprocess.check_call(psql_command + ["-f", str(sqlfile)]) if ( flavor is not None and not flavor_set and sqlfile.name.endswith("-flavor.sql") ): logger.debug("Setting database flavor %s", flavor) query = f"insert into dbflavor (flavor) values ('{flavor}')" subprocess.check_call(psql_command + ["-c", query]) flavor_set = True if flavor is not None and not flavor_set: logger.warn( - "Asked for flavor %s, but module does not support database flavors", flavor, + "Asked for flavor %s, but module does not support database flavors", + flavor, ) diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py index 9e5b2cb..e12a0f3 100644 --- a/swh/core/db/pytest_plugin.py +++ b/swh/core/db/pytest_plugin.py @@ -1,282 +1,281 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob from importlib import import_module import logging import subprocess from typing import Callable, Iterable, Iterator, List, Optional, Sequence, Set, Union from _pytest.fixtures import FixtureRequest import psycopg2 import pytest from pytest_postgresql.compat import check_for_psycopg2, connection from pytest_postgresql.executor import PostgreSQLExecutor from pytest_postgresql.executor_noop import NoopExecutor from pytest_postgresql.janitor import DatabaseJanitor from swh.core.db.db_utils import ( init_admin_extensions, populate_database_for_package, swh_set_db_version, ) from swh.core.utils import basename_sortkey # to keep mypy happy regardless pytest-postgresql version try: _pytest_pgsql_get_config_module = import_module("pytest_postgresql.config") except ImportError: # pytest_postgresql < 3.0.0 _pytest_pgsql_get_config_module = import_module("pytest_postgresql.factories") _pytest_postgresql_get_config = getattr(_pytest_pgsql_get_config_module, "get_config") logger = logging.getLogger(__name__) class SWHDatabaseJanitor(DatabaseJanitor): """SWH database janitor implementation with a a different setup/teardown policy than than the stock one. Instead of dropping, creating and initializing the database for each test, it creates and initializes the db once, then truncates the tables (and sequences) in between tests. This is needed to have acceptable test performances. """ def __init__( self, user: str, host: str, port: int, dbname: str, version: Union[str, float], password: Optional[str] = None, isolation_level: Optional[int] = None, connection_timeout: int = 60, dump_files: Optional[Union[str, Sequence[str]]] = None, no_truncate_tables: Set[str] = set(), no_db_drop: bool = False, ) -> None: super().__init__(user, host, port, dbname, version) # do no truncate the following tables self.no_truncate_tables = set(no_truncate_tables) self.no_db_drop = no_db_drop self.dump_files = dump_files def psql_exec(self, fname: str) -> None: conninfo = ( f"host={self.host} user={self.user} port={self.port} dbname={self.dbname}" ) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", fname, ] ) def db_reset(self) -> None: - """Truncate tables (all but self.no_truncate_tables set) and sequences - - """ + """Truncate tables (all but self.no_truncate_tables set) and sequences""" with psycopg2.connect( - dbname=self.dbname, user=self.user, host=self.host, port=self.port, + dbname=self.dbname, + user=self.user, + host=self.host, + port=self.port, ) as cnx: with cnx.cursor() as cur: cur.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = %s", ("public",), ) all_tables = set(table for (table,) in cur.fetchall()) tables_to_truncate = all_tables - self.no_truncate_tables for table in tables_to_truncate: cur.execute("TRUNCATE TABLE %s CASCADE" % table) cur.execute( "SELECT sequence_name FROM information_schema.sequences " "WHERE sequence_schema = %s", ("public",), ) seqs = set(seq for (seq,) in cur.fetchall()) for seq in seqs: cur.execute("ALTER SEQUENCE %s RESTART;" % seq) cnx.commit() def _db_exists(self, cur, dbname): cur.execute( "SELECT EXISTS " "(SELECT datname FROM pg_catalog.pg_database WHERE datname= %s);", (dbname,), ) row = cur.fetchone() return (row is not None) and row[0] def init(self) -> None: """Create database in postgresql out of a template it if it exists, bare creation otherwise.""" template_name = f"{self.dbname}_tmpl" logger.debug("Initialize DB %s", self.dbname) with self.cursor() as cur: tmpl_exists = self._db_exists(cur, template_name) db_exists = self._db_exists(cur, self.dbname) if not db_exists: if tmpl_exists: logger.debug( "Create %s from template %s", self.dbname, template_name ) cur.execute( f'CREATE DATABASE "{self.dbname}" TEMPLATE "{template_name}";' ) else: logger.debug("Create %s from scratch", self.dbname) cur.execute(f'CREATE DATABASE "{self.dbname}";') if self.dump_files: logger.warning( "Using dump_files on the postgresql_fact fixture " "is deprecated. See swh.core documentation for more " "details." ) for dump_file in gen_dump_files(self.dump_files): logger.info(f"Loading {dump_file}") self.psql_exec(dump_file) else: logger.debug("Reset %s", self.dbname) self.db_reset() def drop(self) -> None: """Drop database in postgresql.""" if self.no_db_drop: with self.cursor() as cur: self._terminate_connection(cur, self.dbname) else: super().drop() # the postgres_fact factory fixture below is mostly a copy of the code # from pytest-postgresql. We need a custom version here to be able to # specify our version of the DBJanitor we use. def postgresql_fact( process_fixture_name: str, dbname: Optional[str] = None, load: Optional[Sequence[Union[Callable, str]]] = None, isolation_level: Optional[int] = None, modname: Optional[str] = None, dump_files: Optional[Union[str, List[str]]] = None, no_truncate_tables: Set[str] = {"dbversion"}, no_db_drop: bool = False, ) -> Callable[[FixtureRequest], Iterator[connection]]: """ Return connection fixture factory for PostgreSQL. :param process_fixture_name: name of the process fixture :param dbname: database name :param load: SQL, function or function import paths to automatically load into our test database :param isolation_level: optional postgresql isolation level defaults to server's default :param modname: (swh) module name for which the database is created :dump_files: (deprecated, use load instead) list of sql script files to execute after the database has been created :no_truncate_tables: list of table not to truncate between tests (only used when no_db_drop is True) :no_db_drop: if True, keep the database between tests; in which case, the database is reset (see SWHDatabaseJanitor.db_reset()) by truncating most of the tables. Note that this makes de facto tests (potentially) interdependent, use with extra caution. :returns: function which makes a connection to postgresql """ @pytest.fixture def postgresql_factory(request: FixtureRequest) -> Iterator[connection]: """ Fixture factory for PostgreSQL. :param request: fixture request object :returns: postgresql client """ check_for_psycopg2() proc_fixture: Union[PostgreSQLExecutor, NoopExecutor] = request.getfixturevalue( process_fixture_name ) pg_host = proc_fixture.host pg_port = proc_fixture.port pg_user = proc_fixture.user pg_password = proc_fixture.password pg_options = proc_fixture.options pg_db = dbname or proc_fixture.dbname pg_load = load or [] assert pg_db is not None with SWHDatabaseJanitor( pg_user, pg_host, pg_port, pg_db, proc_fixture.version, pg_password, isolation_level=isolation_level, dump_files=dump_files, no_truncate_tables=no_truncate_tables, no_db_drop=no_db_drop, ) as janitor: db_connection: connection = psycopg2.connect( dbname=pg_db, user=pg_user, password=pg_password, host=pg_host, port=pg_port, options=pg_options, ) for load_element in pg_load: janitor.load(load_element) try: yield db_connection finally: db_connection.close() return postgresql_factory def initialize_database_for_module(modname, version, **kwargs): conninfo = psycopg2.connect(**kwargs).dsn init_admin_extensions(modname, conninfo) populate_database_for_package(modname, conninfo) try: swh_set_db_version(conninfo, version) except psycopg2.errors.UniqueViolation: logger.warn( "Version already set by db init scripts. " "This generally means the swh.{modname} package needs to be " "updated for swh.core>=1.2" ) def gen_dump_files(dump_files: Union[str, Iterable[str]]) -> Iterator[str]: - """Generate files potentially resolving glob patterns if any - - """ + """Generate files potentially resolving glob patterns if any""" if isinstance(dump_files, str): dump_files = [dump_files] for dump_file in dump_files: if glob.has_magic(dump_file): # if the dump_file is a glob pattern one, resolve it yield from ( fname for fname in sorted(glob.glob(dump_file), key=basename_sortkey) ) else: # otherwise, just return the filename yield dump_file diff --git a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py index 67a3fb5..ba8c1b3 100644 --- a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py +++ b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py @@ -1,185 +1,177 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import os from pytest_postgresql import factories from swh.core.db import BaseDb from swh.core.db.pytest_plugin import gen_dump_files, postgresql_fact SQL_DIR = os.path.join(os.path.dirname(__file__), "data") test_postgresql_proc = factories.postgresql_proc( dbname="fun", load=sorted(glob.glob(f"{SQL_DIR}/*.sql")), # type: ignore[arg-type] # type ignored because load is typed as Optional[List[...]] instead of an # Optional[Sequence[...]] in pytest_postgresql<4 ) # db with special policy for tables dbversion and people postgres_fun = postgresql_fact( - "test_postgresql_proc", no_db_drop=True, no_truncate_tables={"dbversion", "people"}, + "test_postgresql_proc", + no_db_drop=True, + no_truncate_tables={"dbversion", "people"}, ) postgres_fun2 = postgresql_fact( "test_postgresql_proc", dbname="fun2", load=sorted(glob.glob(f"{SQL_DIR}/*.sql")), no_truncate_tables={"dbversion", "people"}, no_db_drop=True, ) def test_smoke_test_fun_db_is_up(postgres_fun): - """This ensures the db is created and configured according to its dumps files. - - """ + """This ensures the db is created and configured according to its dumps files.""" with BaseDb.connect(postgres_fun.dsn).cursor() as cur: cur.execute("select count(*) from dbversion") nb_rows = cur.fetchone()[0] assert nb_rows == 5 cur.execute("select count(*) from fun") nb_rows = cur.fetchone()[0] assert nb_rows == 3 cur.execute("select count(*) from people") nb_rows = cur.fetchone()[0] assert nb_rows == 2 # in data, we requested a value already so it starts at 2 cur.execute("select nextval('serial')") val = cur.fetchone()[0] assert val == 2 def test_smoke_test_fun2_db_is_up(postgres_fun2): - """This ensures the db is created and configured according to its dumps files. - - """ + """This ensures the db is created and configured according to its dumps files.""" with BaseDb.connect(postgres_fun2.dsn).cursor() as cur: cur.execute("select count(*) from dbversion") nb_rows = cur.fetchone()[0] assert nb_rows == 5 cur.execute("select count(*) from fun") nb_rows = cur.fetchone()[0] assert nb_rows == 3 cur.execute("select count(*) from people") nb_rows = cur.fetchone()[0] assert nb_rows == 2 # in data, we requested a value already so it starts at 2 cur.execute("select nextval('serial')") val = cur.fetchone()[0] assert val == 2 def test_smoke_test_fun_db_is_still_up_and_got_reset(postgres_fun): """This ensures that within another tests, the 'fun' db is still up, created (and not configured again). This time, most of the data has been reset: - except for tables 'dbversion' and 'people' which were left as is - the other tables from the schema (here only "fun") got truncated - the sequences got truncated as well """ with BaseDb.connect(postgres_fun.dsn).cursor() as cur: # db version is excluded from the truncate cur.execute("select count(*) from dbversion") nb_rows = cur.fetchone()[0] assert nb_rows == 5 # people is also allowed not to be truncated cur.execute("select count(*) from people") nb_rows = cur.fetchone()[0] assert nb_rows == 2 # table and sequence are reset cur.execute("select count(*) from fun") nb_rows = cur.fetchone()[0] assert nb_rows == 0 cur.execute("select nextval('serial')") val = cur.fetchone()[0] assert val == 1 # db with no special policy for tables truncation, all tables are reset postgres_people = postgresql_fact( "postgresql_proc", dbname="people", dump_files=f"{SQL_DIR}/*.sql", no_truncate_tables=set(), no_db_drop=True, ) def test_gen_dump_files(): files = [os.path.basename(fn) for fn in gen_dump_files(f"{SQL_DIR}/*.sql")] assert files == ["0-schema.sql", "1-data.sql"] def test_smoke_test_people_db_up(postgres_people): - """'people' db is up and configured - - """ + """'people' db is up and configured""" with BaseDb.connect(postgres_people.dsn).cursor() as cur: cur.execute("select count(*) from dbversion") nb_rows = cur.fetchone()[0] assert nb_rows == 5 cur.execute("select count(*) from people") nb_rows = cur.fetchone()[0] assert nb_rows == 2 cur.execute("select count(*) from fun") nb_rows = cur.fetchone()[0] assert nb_rows == 3 cur.execute("select nextval('serial')") val = cur.fetchone()[0] assert val == 2 def test_smoke_test_people_db_up_and_reset(postgres_people): - """'people' db is up and got reset on every tables and sequences - - """ + """'people' db is up and got reset on every tables and sequences""" with BaseDb.connect(postgres_people.dsn).cursor() as cur: # tables are truncated after the first round cur.execute("select count(*) from dbversion") nb_rows = cur.fetchone()[0] assert nb_rows == 0 # tables are truncated after the first round cur.execute("select count(*) from people") nb_rows = cur.fetchone()[0] assert nb_rows == 0 # table and sequence are reset cur.execute("select count(*) from fun") nb_rows = cur.fetchone()[0] assert nb_rows == 0 cur.execute("select nextval('serial')") val = cur.fetchone()[0] assert val == 1 # db with no initialization step, an empty db postgres_no_init = postgresql_fact("postgresql_proc", dbname="something") def test_smoke_test_db_no_init(postgres_no_init): - """We can connect to the db nonetheless - - """ + """We can connect to the db nonetheless""" with BaseDb.connect(postgres_no_init.dsn).cursor() as cur: cur.execute("select now()") data = cur.fetchone()[0] assert data is not None diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py index 3164ffa..12d9927 100644 --- a/swh/core/db/tests/test_cli.py +++ b/swh/core/db/tests/test_cli.py @@ -1,350 +1,336 @@ # Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import copy import os import traceback import pytest import yaml from swh.core.cli.db import db as swhdb from swh.core.db import BaseDb from swh.core.db.db_utils import import_swhmodule, swh_db_module, swh_db_version from swh.core.tests.test_cli import assert_section_contains def test_cli_swh_help(swhmain, cli_runner): swhmain.add_command(swhdb) result = cli_runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert_section_contains( result.output, "Commands", "db Software Heritage database generic tools." ) help_db_snippets = ( ( "Usage", ( "Usage: swh db [OPTIONS] COMMAND [ARGS]...", "Software Heritage database generic tools.", ), ), ( "Commands", ( "create Create a database for the Software Heritage .", "init Initialize a database for the Software Heritage .", "init-admin Execute superuser-level initialization steps", ), ), ) def test_cli_swh_db_help(swhmain, cli_runner): swhmain.add_command(swhdb) result = cli_runner.invoke(swhmain, ["db", "-h"]) assert result.exit_code == 0 for section, snippets in help_db_snippets: for snippet in snippets: assert_section_contains(result.output, section, snippet) @pytest.fixture def swh_db_cli(cli_runner, monkeypatch, postgresql): """This initializes a cli_runner and sets the correct environment variable expected by - the cli to run appropriately (when not specifying the --dbname flag) + the cli to run appropriately (when not specifying the --dbname flag) """ db_params = postgresql.get_dsn_parameters() monkeypatch.setenv("PGHOST", db_params["host"]) monkeypatch.setenv("PGUSER", db_params["user"]) monkeypatch.setenv("PGPORT", db_params["port"]) return cli_runner, db_params def craft_conninfo(test_db, dbname=None) -> str: """Craft conninfo string out of the test_db object. This also allows to override the dbname.""" db_params = test_db.get_dsn_parameters() if dbname: params = copy.deepcopy(db_params) params["dbname"] = dbname else: params = db_params return "postgresql://{user}@{host}:{port}/{dbname}".format(**params) def test_cli_swh_db_create_and_init_db(cli_runner, postgresql, mock_import_swhmodule): - """Create a db then initializing it should be ok - - """ + """Create a db then initializing it should be ok""" module_name = "test.cli" conninfo = craft_conninfo(postgresql, "new-db") # This creates the db and installs the necessary admin extensions result = cli_runner.invoke(swhdb, ["create", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # This initializes the schema and data result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin value in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, installed during db creation step) with BaseDb.connect(conninfo).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_fail_without_creation_first( cli_runner, postgresql, mock_import_swhmodule ): - """Init command on an inexisting db cannot work - - """ + """Init command on an inexisting db cannot work""" module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(postgresql, "inexisting-db") result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) # Fails because we cannot connect to an inexisting db assert result.exit_code == 1, f"Unexpected output: {result.output}" def test_cli_swh_db_initialization_fail_without_extension( cli_runner, postgresql, mock_import_swhmodule ): """Init command cannot work without privileged extension. - In this test, the schema needs privileged extension to work. + In this test, the schema needs privileged extension to work. """ module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) # Fails as the function `public.digest` is not installed, init-admin calls is needed # first (the next tests show such behavior) assert result.exit_code == 1, f"Unexpected output: {result.output}" def test_cli_swh_db_initialization_works_with_flags( cli_runner, postgresql, mock_import_swhmodule ): - """Init commands with carefully crafted libpq conninfo works - - """ + """Init commands with carefully crafted libpq conninfo works""" module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(postgresql.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_with_env( swh_db_cli, mock_import_swhmodule, postgresql ): - """Init commands with standard environment variables works - - """ + """Init commands with standard environment variables works""" module_name = "test.cli" # it's mocked here cli_runner, db_params = swh_db_cli result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(postgresql.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_idempotent( swh_db_cli, mock_import_swhmodule, postgresql ): - """Multiple runs of the init commands are idempotent - - """ + """Multiple runs of the init commands are idempotent""" module_name = "test.cli" # mocked cli_runner, db_params = swh_db_cli result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(postgresql.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_create_and_init_db_new_api( cli_runner, postgresql, mock_import_swhmodule, mocker, tmp_path ): - """Create a db then initializing it should be ok for a "new style" datastore - - """ + """Create a db then initializing it should be ok for a "new style" datastore""" module_name = "test.cli_new" conninfo = craft_conninfo(postgresql) # This initializes the schema and data cfgfile = tmp_path / "config.yml" cfgfile.write_text(yaml.dump({module_name: {"cls": "postgresql", "db": conninfo}})) result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke(swhdb, ["-C", cfgfile, "init", module_name]) assert ( result.exit_code == 0 ), f"Unexpected output: {traceback.print_tb(result.exc_info[2])}" # the origin value in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, installed during db creation step) with BaseDb.connect(conninfo).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_upgrade_new_api(cli_runner, postgresql, datadir, mocker, tmp_path): - """Upgrade scenario for a "new style" datastore - - """ + """Upgrade scenario for a "new style" datastore""" module_name = "test.cli_new" # the `current_version` variable is the version that will be returned by # any call to `get_current_version()` in this test session, thanks to the # local mocked version of import_swhmodule() below. current_version = 1 # custom version of the mockup to make it easy to change the # current_version returned by get_current_version() # TODO: find a better solution for this... def import_swhmodule_mock(modname): if modname.startswith("test."): dirname = modname.split(".", 1)[1] def get_datastore(cls, **kw): return mocker.MagicMock(get_current_version=lambda: current_version) return mocker.MagicMock( __name__=modname, __file__=os.path.join(datadir, dirname, "__init__.py"), name=modname, get_datastore=get_datastore, ) return import_swhmodule(modname) mocker.patch("swh.core.db.db_utils.import_swhmodule", import_swhmodule_mock) conninfo = craft_conninfo(postgresql) # This initializes the schema and data cfgfile = tmp_path / "config.yml" cfgfile.write_text(yaml.dump({module_name: {"cls": "postgresql", "db": conninfo}})) result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke(swhdb, ["-C", cfgfile, "init", module_name]) assert ( result.exit_code == 0 ), f"Unexpected output: {traceback.print_tb(result.exc_info[2])}" assert swh_db_version(conninfo) == 1 # the upgrade should not do anything because the datastore does advertise # version 1 result = cli_runner.invoke(swhdb, ["-C", cfgfile, "upgrade", module_name]) assert swh_db_version(conninfo) == 1 # advertise current version as 3, a simple upgrade should get us there, but # no further current_version = 3 result = cli_runner.invoke(swhdb, ["-C", cfgfile, "upgrade", module_name]) assert swh_db_version(conninfo) == 3 # an attempt to go further should not do anything result = cli_runner.invoke( swhdb, ["-C", cfgfile, "upgrade", module_name, "--to-version", 5] ) assert swh_db_version(conninfo) == 3 # an attempt to go lower should not do anything result = cli_runner.invoke( swhdb, ["-C", cfgfile, "upgrade", module_name, "--to-version", 2] ) assert swh_db_version(conninfo) == 3 # advertise current version as 6, an upgrade with --to-version 4 should # stick to the given version 4 and no further current_version = 6 result = cli_runner.invoke( swhdb, ["-C", cfgfile, "upgrade", module_name, "--to-version", 4] ) assert swh_db_version(conninfo) == 4 assert "migration was not complete" in result.output # attempt to upgrade to a newer version than current code version fails result = cli_runner.invoke( swhdb, ["-C", cfgfile, "upgrade", module_name, "--to-version", current_version + 1], ) assert result.exit_code != 0 assert swh_db_version(conninfo) == 4 cnx = BaseDb.connect(conninfo) with cnx.transaction() as cur: cur.execute("drop table dbmodule") assert swh_db_module(conninfo) is None # db migration should recreate the missing dbmodule table result = cli_runner.invoke(swhdb, ["-C", cfgfile, "upgrade", module_name]) assert result.exit_code == 0 assert "Warning: the database does not have a dbmodule table." in result.output assert ( "Write the module information (test.cli_new) in the database? [Y/n]" in result.output ) assert swh_db_module(conninfo) == module_name diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py index 15c6dfe..a06794b 100644 --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -1,459 +1,466 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass import datetime from enum import IntEnum import inspect from string import printable from typing import Any from unittest.mock import MagicMock, Mock import uuid from hypothesis import given, settings, strategies from hypothesis.extra.pytz import timezones import psycopg2 import pytest from typing_extensions import Protocol from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator from swh.core.db.pytest_plugin import postgresql_fact from swh.core.db.tests.conftest import function_scoped_fixture_check # workaround mypy bug https://github.com/python/mypy/issues/5485 class Converter(Protocol): def __call__(self, x: Any) -> Any: ... @dataclass class Field: name: str """Column name""" pg_type: str """Type of the PostgreSQL column""" example: Any """Example value for the static tests""" strategy: strategies.SearchStrategy """Hypothesis strategy to generate these values""" in_wrapper: Converter = lambda x: x """Wrapper to convert this data type for the static tests""" out_converter: Converter = lambda x: x """Converter from the raw PostgreSQL column value to this data type""" # Limit PostgreSQL integer values pg_int = strategies.integers(-2147483648, +2147483647) pg_text = strategies.text( alphabet=strategies.characters( blacklist_categories=["Cs"], # surrogates blacklist_characters=[ "\x00", # pgsql does not support the null codepoint "\r", # pgsql normalizes those ], ), ) pg_bytea = strategies.binary() def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[]""" return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size) def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[][]. The inner lists must all have the same size.""" return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap( lambda n: strategies.lists( pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size ) ) def pg_tstz() -> strategies.SearchStrategy: """Generate values that fit in a PostgreSQL timestamptz. Notes: We're forbidding old datetimes, because until 1956, many timezones had seconds in their "UTC offsets" (see ), which is not representable by PostgreSQL. """ min_value = datetime.datetime(1960, 1, 1, 0, 0, 0) return strategies.datetimes(min_value=min_value, timezones=timezones()) def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate values representable as a PostgreSQL jsonb object (dict).""" return strategies.dictionaries( strategies.text(printable), strategies.recursive( # should use floats() instead of integers(), but PostgreSQL # coerces large integers into floats, making the tests fail. We # only store ints in our generated data anyway. strategies.none() | strategies.booleans() | strategies.integers(-2147483648, +2147483647) | strategies.text(printable), lambda children: strategies.lists(children, max_size=max_size) | strategies.dictionaries( strategies.text(printable), children, max_size=max_size ), ), min_size=min_size, max_size=max_size, ) def tuple_2d_to_list_2d(v): """Convert a 2D tuple to a 2D list""" return [list(inner) for inner in v] def list_2d_to_tuple_2d(v): """Convert a 2D list to a 2D tuple""" return tuple(tuple(inner) for inner in v) class TestIntEnum(IntEnum): foo = 1 bar = 2 def now(): return datetime.datetime.now(tz=datetime.timezone.utc) FIELDS = ( Field("i", "int", 1, pg_int), Field("txt", "text", "foo", pg_text), Field("bytes", "bytea", b"bar", strategies.binary()), Field( "bytes_array", "bytea[]", [b"baz1", b"baz2"], pg_bytea_a(min_size=0, max_size=5), ), Field( "bytes_tuple", "bytea[]", (b"baz1", b"baz2"), pg_bytea_a(min_size=0, max_size=5).map(tuple), in_wrapper=list, out_converter=tuple, ), Field( "bytes_2d", "bytea[][]", [[b"quux1"], [b"quux2"]], pg_bytea_a_a(min_size=0, max_size=5), ), Field( "bytes_2d_tuple", "bytea[][]", ((b"quux1",), (b"quux2",)), pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d), in_wrapper=tuple_2d_to_list_2d, out_converter=list_2d_to_tuple_2d, ), - Field("ts", "timestamptz", now(), pg_tstz(),), + Field( + "ts", + "timestamptz", + now(), + pg_tstz(), + ), Field( "dict", "jsonb", {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}}, pg_jsonb(min_size=0, max_size=5), in_wrapper=psycopg2.extras.Json, ), Field( "intenum", "int", TestIntEnum.foo, strategies.sampled_from(TestIntEnum), in_wrapper=int, out_converter=TestIntEnum, ), Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()), Field( "text_list", "text[]", # All the funky corner cases ["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"], strategies.lists(pg_text, min_size=0, max_size=5), ), Field( "tstz_list", "timestamptz[]", [now(), now() + datetime.timedelta(days=1)], strategies.lists(pg_tstz(), min_size=0, max_size=5), ), Field( "tstz_range", "tstzrange", psycopg2.extras.DateTimeTZRange( - lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)", + lower=now(), + upper=now() + datetime.timedelta(days=1), + bounds="[)", ), strategies.tuples( # generate two sorted timestamptzs for use as bounds strategies.tuples(pg_tstz(), pg_tstz()).map(sorted), # and a set of bounds strategies.sampled_from(["[]", "()", "[)", "(]"]), ).map( # and build the actual DateTimeTZRange object from these args lambda args: psycopg2.extras.DateTimeTZRange( - lower=args[0][0], upper=args[0][1], bounds=args[1], + lower=args[0][0], + upper=args[0][1], + bounds=args[1], ) ), ), ) INIT_SQL = "create table test_table (%s)" % ", ".join( f"{field.name} {field.pg_type}" for field in FIELDS ) COLUMNS = tuple(field.name for field in FIELDS) INSERT_SQL = "insert into test_table (%s) values (%s)" % ( ", ".join(COLUMNS), ", ".join("%s" for i in range(len(COLUMNS))), ) STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS) EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS) db_rows = strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS))) def convert_lines(cur): return [ tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur ] test_db = postgresql_fact("postgresql_proc", dbname="test-db2") @pytest.fixture def db_with_data(test_db, request): - """Fixture to initialize a db with some data out of the "INIT_SQL above - - """ + """Fixture to initialize a db with some data out of the "INIT_SQL above""" db = BaseDb.connect(test_db.dsn) with db.cursor() as cur: psycopg2.extras.register_default_jsonb(cur) cur.execute(INIT_SQL) yield db db.conn.rollback() db.conn.close() @pytest.mark.db def test_db_connect(db_with_data): with db_with_data.cursor() as cur: psycopg2.extras.register_default_jsonb(cur) cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] def test_db_initialized(db_with_data): with db_with_data.cursor() as cur: psycopg2.extras.register_default_jsonb(cur) cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] def test_db_copy_to_static(db_with_data): items = [{field.name: field.example for field in FIELDS}] db_with_data.copy_to(items, "test_table", COLUMNS) with db_with_data.cursor() as cur: cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] @settings(suppress_health_check=function_scoped_fixture_check) @given(db_rows) def test_db_copy_to(db_with_data, data): items = [dict(zip(COLUMNS, item)) for item in data] with db_with_data.cursor() as cur: cur.execute("TRUNCATE TABLE test_table CASCADE") db_with_data.copy_to(items, "test_table", COLUMNS) with db_with_data.cursor() as cur: cur.execute("select * from test_table;") converted_lines = convert_lines(cur) assert converted_lines == data def test_db_copy_to_thread_exception(db_with_data): - data = [(2 ** 65, "foo", b"bar")] + data = [(2**65, "foo", b"bar")] items = [dict(zip(COLUMNS, item)) for item in data] with pytest.raises(psycopg2.errors.NumericValueOutOfRange): db_with_data.copy_to(items, "test_table", COLUMNS) def test_db_transaction(mocker): expected_cur = object() called = False class Storage: @db_transaction() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) storage.endpoint() assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction__with_generator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction() def endpoint(self, cur=None, db=None): yield None def test_db_transaction_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass expected_sig = inspect.signature(f) @db_transaction() def g(self, foo, *, bar=None, db=None, cur=None): pass actual_sig = inspect.signature(g) assert actual_sig == expected_sig def test_db_transaction_generator(mocker): expected_cur = object() called = False class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur yield None storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) list(storage.endpoint()) assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction_generator__with_nongenerator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): pass def test_db_transaction_generator_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass expected_sig = inspect.signature(f) @db_transaction_generator() def g(self, foo, *, bar=None, db=None, cur=None): yield None actual_sig = inspect.signature(g) assert actual_sig == expected_sig @pytest.mark.parametrize( "query_options", (None, {"something": 42, "statement_timeout": 200}) ) @pytest.mark.parametrize("use_generator", (True, False)) def test_db_transaction_query_options(mocker, use_generator, query_options): class Storage: @db_transaction(statement_timeout=100) def endpoint(self, cur=None, db=None): return [None] @db_transaction_generator(statement_timeout=100) def gen_endpoint(self, cur=None, db=None): yield None storage = Storage() # mockers mocked_apply = mocker.patch("swh.core.db.common.apply_options") # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' expected_cur = object() db_mock = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) mocker.patch.object(storage, "put_db", create=True) if query_options: storage.query_options = { "endpoint": query_options, "gen_endpoint": query_options, } if use_generator: list(storage.gen_endpoint()) else: list(storage.endpoint()) mocked_apply.assert_called_once_with( expected_cur, query_options if query_options is not None else {"statement_timeout": 100}, ) diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/test_db_utils.py index 7e2719d..0a633f3 100644 --- a/swh/core/db/tests/test_db_utils.py +++ b/swh/core/db/tests/test_db_utils.py @@ -1,189 +1,185 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timedelta from os import path import pytest from swh.core.cli.db import db as swhdb from swh.core.db import BaseDb from swh.core.db.db_utils import ( get_database_info, get_sql_for_package, now, swh_db_module, swh_db_upgrade, swh_db_version, swh_db_versions, swh_set_db_module, ) from .test_cli import craft_conninfo @pytest.mark.parametrize("module", ["test.cli", "test.cli_new"]) def test_get_sql_for_package(mock_import_swhmodule, module): files = get_sql_for_package(module) assert files assert [f.name for f in files] == [ "0-superuser-init.sql", "30-schema.sql", "40-funcs.sql", "50-data.sql", ] @pytest.mark.parametrize("module", ["test.cli", "test.cli_new"]) def test_db_utils_versions(cli_runner, postgresql, mock_import_swhmodule, module): """Check get_database_info, swh_db_versions and swh_db_module work ok This test checks db versions for both a db with "new style" set of sql init scripts (i.e. the dbversion table is not created in these scripts, but by the populate_database_for_package() function directly, via the 'swh db init' command) and an "old style" set (dbversion created in the scripts)S. """ conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init-admin", module, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module, "--dbname", conninfo, "--initial-version", 10] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" # check the swh_db_module() function assert swh_db_module(conninfo) == module # the dbversion and dbmodule tables exists and are populated dbmodule, dbversion, dbflavor = get_database_info(conninfo) # check also the swh_db_versions() function versions = swh_db_versions(conninfo) assert dbmodule == module assert dbversion == 10 assert dbflavor is None # check also the swh_db_versions() function versions = swh_db_versions(conninfo) assert len(versions) == 1 assert versions[0][0] == 10 if module == "test.cli": assert versions[0][1] == datetime.fromisoformat( "2016-02-22T15:56:28.358587+00:00" ) assert versions[0][2] == "Work In Progress" else: # new scheme but with no datastore (so no version support from there) assert versions[0][2] == "DB initialization" # add a few versions in dbversion cnx = BaseDb.connect(conninfo) with cnx.transaction() as cur: cur.executemany( "insert into dbversion(version, release, description) values (%s, %s, %s)", [(i, now(), f"Upgrade to version {i}") for i in range(11, 15)], ) dbmodule, dbversion, dbflavor = get_database_info(conninfo) assert dbmodule == module assert dbversion == 14 assert dbflavor is None versions = swh_db_versions(conninfo) assert len(versions) == 5 for i, (version, ts, desc) in enumerate(versions): assert version == (14 - i) # these are in reverse order if version > 10: assert desc == f"Upgrade to version {version}" assert (now() - ts) < timedelta(seconds=1) @pytest.mark.parametrize("module", ["test.cli_new"]) def test_db_utils_upgrade( cli_runner, postgresql, mock_import_swhmodule, module, datadir ): - """Check swh_db_upgrade - - """ + """Check swh_db_upgrade""" conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init-admin", module, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke(swhdb, ["init", module, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" assert swh_db_version(conninfo) == 1 new_version = swh_db_upgrade(conninfo, module) assert new_version == 6 assert swh_db_version(conninfo) == 6 versions = swh_db_versions(conninfo) # get rid of dates to ease checking versions = [(v[0], v[2]) for v in versions] assert versions[-1] == (1, "DB initialization") sqlbasedir = path.join(datadir, module.split(".", 1)[1], "sql", "upgrades") assert versions[1:-1] == [ (i, f"Upgraded to version {i} using {sqlbasedir}/{i:03d}.sql") for i in range(5, 1, -1) ] assert versions[0] == (6, "Updated version from upgrade script") cnx = BaseDb.connect(conninfo) with cnx.transaction() as cur: cur.execute("select url from origin where url like 'version%'") result = cur.fetchall() assert result == [("version%03d" % i,) for i in range(2, 7)] cur.execute( "select url from origin where url = 'this should never be executed'" ) result = cur.fetchall() assert not result @pytest.mark.parametrize("module", ["test.cli_new"]) def test_db_utils_swh_db_upgrade_sanity_checks( cli_runner, postgresql, mock_import_swhmodule, module, datadir ): - """Check swh_db_upgrade - - """ + """Check swh_db_upgrade""" conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init-admin", module, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke(swhdb, ["init", module, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" cnx = BaseDb.connect(conninfo) with cnx.transaction() as cur: cur.execute("drop table dbmodule") # try to upgrade with a unset module with pytest.raises(ValueError): swh_db_upgrade(conninfo, module) # check the dbmodule is unset assert swh_db_module(conninfo) is None # set the stored module to something else swh_set_db_module(conninfo, f"{module}2") assert swh_db_module(conninfo) == f"{module}2" # try to upgrade with a different module with pytest.raises(ValueError): swh_db_upgrade(conninfo, module) # revert to the proper module in the db swh_set_db_module(conninfo, module, force=True) assert swh_db_module(conninfo) == module # trying again is a noop swh_set_db_module(conninfo, module) assert swh_db_module(conninfo) == module # drop the dbversion table with cnx.transaction() as cur: cur.execute("drop table dbversion") # an upgrade should fail due to missing stored version with pytest.raises(ValueError): swh_db_upgrade(conninfo, module) diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py index fbb25dd..2e9bcb2 100644 --- a/swh/core/pytest_plugin.py +++ b/swh/core/pytest_plugin.py @@ -1,369 +1,369 @@ # Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import deque from functools import partial import logging from os import path import re from typing import Dict, List, Optional from urllib.parse import unquote, urlparse import pytest import requests from requests.adapters import BaseAdapter from requests.structures import CaseInsensitiveDict from requests.utils import get_encoding_from_headers logger = logging.getLogger(__name__) # Check get_local_factory function # Maximum number of iteration checks to generate requests responses MAX_VISIT_FILES = 10 def get_response_cb( request: requests.Request, context, datadir, ignore_urls: List[str] = [], visits: Optional[Dict] = None, ): """Mount point callback to fetch on disk the request's content. The request urls provided are url decoded first to resolve the associated file on disk. This is meant to be used as 'body' argument of the requests_mock.get() method. It will look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Eg. if you use the requests_mock fixture in your test file as: requests_mock.get('https?://nowhere.com', body=get_response_cb) # or even requests_mock.get(re.compile('https?://'), body=get_response_cb) then a call requests.get like: requests.get('https://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/https_nowhere.com/path_to_resource,a=b,c=d or a call requests.get like: requests.get('http://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/http_nowhere.com/path_to_resource,a=b,c=d Args: request: Object requests context (requests.Context): Object holding response metadata information (status_code, headers, etc...) datadir: Data files path ignore_urls: urls whose status response should be 404 even if the local file exists visits: Dict of url, number of visits. If None, disable multi visit support (default) Returns: Optional[FileDescriptor] on disk file to read from the test context """ logger.debug("get_response_cb(%s, %s)", request, context) logger.debug("url: %s", request.url) logger.debug("ignore_urls: %s", ignore_urls) unquoted_url = unquote(request.url) if unquoted_url in ignore_urls: context.status_code = 404 return None url = urlparse(unquoted_url) # http://pypi.org ~> http_pypi.org # https://files.pythonhosted.org ~> https_files.pythonhosted.org dirname = "%s_%s" % (url.scheme, url.hostname) # url.path: pypi//json -> local file: pypi__json filename = url.path[1:] if filename.endswith("/"): filename = filename[:-1] filename = filename.replace("/", "_") if url.query: filename += "," + url.query.replace("&", ",") filepath = path.join(datadir, dirname, filename) if visits is not None: visit = visits.get(url, 0) visits[url] = visit + 1 if visit: filepath = filepath + "_visit%s" % visit if not path.isfile(filepath): logger.debug("not found filepath: %s", filepath) context.status_code = 404 return None fd = open(filepath, "rb") context.headers["content-length"] = str(path.getsize(filepath)) return fd @pytest.fixture def datadir(request: pytest.FixtureRequest) -> str: """By default, returns the test directory's data directory. This can be overridden on a per file tree basis. Add an override definition in the local conftest, for example:: import pytest from os import path @pytest.fixture def datadir(): return path.join(path.abspath(path.dirname(__file__)), 'resources') """ # pytest >= 7 renamed FixtureRequest fspath attribute to path path_ = request.path if hasattr(request, "path") else request.fspath # type: ignore return path.join(path.dirname(str(path_)), "data") def requests_mock_datadir_factory( ignore_urls: List[str] = [], has_multi_visit: bool = False ): """This factory generates fixtures which allow to look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the data/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Multiple implementations are possible, for example: ``requests_mock_datadir_factory([])`` This computes the file name from the query and always returns the same result. ``requests_mock_datadir_factory(has_multi_visit=True)`` This computes the file name from the query and returns the content of the filename the first time, the next call returning the content of files suffixed with _visit1 and so on and so forth. If the file is not found, returns a 404. ``requests_mock_datadir_factory(ignore_urls=['url1', 'url2'])`` This will ignore any files corresponding to url1 and url2, always returning 404. Args: ignore_urls: List of urls to always returns 404 (whether file exists or not) has_multi_visit: Activate or not the multiple visits behavior """ @pytest.fixture def requests_mock_datadir(requests_mock, datadir): if not has_multi_visit: cb = partial(get_response_cb, ignore_urls=ignore_urls, datadir=datadir) requests_mock.get(re.compile("https?://"), body=cb) else: visits = {} requests_mock.get( re.compile("https?://"), body=partial( get_response_cb, ignore_urls=ignore_urls, visits=visits, datadir=datadir, ), ) return requests_mock return requests_mock_datadir # Default `requests_mock_datadir` implementation requests_mock_datadir = requests_mock_datadir_factory() """ Instance of :py:func:`requests_mock_datadir_factory`, with the default arguments. """ # Implementation for multiple visits behavior: # - first time, it checks for a file named `filename` # - second time, it checks for a file named `filename`_visit1 # etc... requests_mock_datadir_visits = requests_mock_datadir_factory(has_multi_visit=True) """ Instance of :py:func:`requests_mock_datadir_factory`, with the default arguments, but `has_multi_visit=True`. """ @pytest.fixture def swh_rpc_client(swh_rpc_client_class, swh_rpc_adapter): """This fixture generates an RPCClient instance that uses the class generated by the rpc_client_class fixture as backend. Since it uses the swh_rpc_adapter, HTTP queries will be intercepted and routed directly to the current Flask app (as provided by the `app` fixture). So this stack of fixtures allows to test the RPCClient -> RPCServerApp communication path using a real RPCClient instance and a real Flask (RPCServerApp) app instance. To use this fixture: - ensure an `app` fixture exists and generate a Flask application, - implement an `swh_rpc_client_class` fixtures that returns the RPCClient-based class to use as client side for the tests, - implement your tests using this `swh_rpc_client` fixture. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ url = "mock://example.com" cli = swh_rpc_client_class(url=url) # we need to clear the list of existing adapters here so we ensure we # have one and only one adapter which is then used for all the requests. cli.session.adapters.clear() cli.session.mount("mock://", swh_rpc_adapter) return cli @pytest.fixture def swh_rpc_adapter(app): """Fixture that generates a requests.Adapter instance that can be used to test client/servers code based on swh.core.api classes. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ with app.test_client() as client: yield RPCTestAdapter(client) class RPCTestAdapter(BaseAdapter): def __init__(self, client): self._client = client def build_response(self, req, resp): response = requests.Response() # Fallback to None if there's no status_code, for whatever reason. response.status_code = resp.status_code # Make headers case-insensitive. response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) # Set encoding. response.encoding = get_encoding_from_headers(response.headers) response.raw = resp response.reason = response.raw.status if isinstance(req.url, bytes): response.url = req.url.decode("utf-8") else: response.url = req.url # Give the Response some context. response.request = req response.connection = self response._content = resp.data return response def send(self, request, **kw): """ Overrides ``requests.adapters.BaseAdapter.send`` """ resp = self._client.open( request.url, method=request.method, headers=request.headers.items(), data=request.body, ) return self.build_response(request, resp) @pytest.fixture def flask_app_client(app): with app.test_client() as client: yield client # stolen from pytest-flask, required to have url_for() working within tests # using flask_app_client fixture. @pytest.fixture(autouse=True) def _push_request_context(request: pytest.FixtureRequest): """During tests execution request context has been pushed, e.g. `url_for`, `session`, etc. can be used in tests as is:: def test_app(app, client): assert client.get(url_for('myview')).status_code == 200 """ if "app" not in request.fixturenames: return app = request.getfixturevalue("app") ctx = app.test_request_context() ctx.push() def teardown(): ctx.pop() request.addfinalizer(teardown) class FakeSocket(object): - """ A fake socket for testing. """ + """A fake socket for testing.""" def __init__(self): self.payloads = deque() def send(self, payload): assert type(payload) == bytes self.payloads.append(payload) def recv(self): try: return self.payloads.popleft().decode("utf-8") except IndexError: return None def close(self): pass def __repr__(self): return str(self.payloads) @pytest.fixture def statsd(): """Simple fixture giving a Statsd instance suitable for tests The Statsd instance uses a FakeSocket as `.socket` attribute in which one can get the accumulated statsd messages in a deque in `.socket.payloads`. """ from swh.core.statsd import Statsd statsd = Statsd() statsd._socket = FakeSocket() yield statsd diff --git a/swh/core/statsd.py b/swh/core/statsd.py index afe3a83..29b9240 100644 --- a/swh/core/statsd.py +++ b/swh/core/statsd.py @@ -1,495 +1,496 @@ # Copyright (C) 2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # Initially imported from https://github.com/DataDog/datadogpy/ # at revision 62b3a3e89988dc18d78c282fe3ff5d1813917436 # # Copyright (c) 2015, Datadog # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of Datadog nor the names of its contributors may be # used to endorse or promote products derived from this software without # specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # # # Vastly adapted for integration in swh.core: # # - Removed python < 3.5 compat code # - trimmed the imports down to be a single module # - adjust some options: # - drop unix socket connection option # - add environment variable support for setting the statsd host and # port (pulled the idea from the main python statsd module) # - only send timer metrics in milliseconds (that's what # prometheus-statsd-exporter expects) # - drop DataDog-specific metric types (that are unsupported in # prometheus-statsd-exporter) # - made the tags a dict instead of a list (prometheus-statsd-exporter only # supports tags with a value, mirroring prometheus) # - switch from time.time to time.monotonic # - improve unit test coverage # - documentation cleanup from asyncio import iscoroutinefunction from contextlib import contextmanager from functools import wraps import itertools import logging import os from random import random import re import socket import threading from time import monotonic from typing import Collection, Dict, Optional import warnings log = logging.getLogger("swh.core.statsd") class TimedContextManagerDecorator(object): """ A context manager and a decorator which will report the elapsed time in the context OR in a function call. Attributes: elapsed (float): the elapsed time at the point of completion """ def __init__( self, statsd, metric=None, error_metric=None, tags=None, sample_rate=1 ): self.statsd = statsd self.metric = metric self.error_metric = error_metric self.tags = tags self.sample_rate = sample_rate self.elapsed = None # this is for testing purpose def __call__(self, func): """ Decorator which returns the elapsed time of the function call. Default to the function name if metric was not provided. """ if not self.metric: self.metric = "%s.%s" % (func.__module__, func.__name__) # Coroutines if iscoroutinefunction(func): @wraps(func) async def wrapped_co(*args, **kwargs): start = monotonic() try: result = await func(*args, **kwargs) except: # noqa self._send_error() raise self._send(start) return result return wrapped_co # Others @wraps(func) def wrapped(*args, **kwargs): start = monotonic() try: result = func(*args, **kwargs) except: # noqa self._send_error() raise self._send(start) return result return wrapped def __enter__(self): if not self.metric: raise TypeError("Cannot used timed without a metric!") self._start = monotonic() return self def __exit__(self, type, value, traceback): # Report the elapsed time of the context manager if no error. if type is None: self._send(self._start) else: self._send_error() def _send(self, start): elapsed = (monotonic() - start) * 1000 self.statsd.timing( self.metric, elapsed, tags=self.tags, sample_rate=self.sample_rate ) self.elapsed = elapsed def _send_error(self): if self.error_metric is None: self.error_metric = self.metric + "_error_count" self.statsd.increment(self.error_metric, tags=self.tags) def start(self): """Start the timer""" self.__enter__() def stop(self): """Stop the timer, send the metric value""" self.__exit__(None, None, None) class Statsd(object): """Initialize a client to send metrics to a StatsD server. Arguments: host (str): the host of the StatsD server. Defaults to localhost. port (int): the port of the StatsD server. Defaults to 8125. max_buffer_size (int): Maximum number of metrics to buffer before sending to the server if sending metrics in batch namespace (str): Namespace to prefix all metric names constant_tags (Dict[str, str]): Tags to attach to all metrics Note: This class also supports the following environment variables: STATSD_HOST Override the default host of the statsd server STATSD_PORT Override the default port of the statsd server STATSD_TAGS Tags to attach to every metric reported. Example value: "label:value,other_label:other_value" """ def __init__( self, host=None, port=None, max_buffer_size=50, namespace=None, constant_tags=None, ): # Connection if host is None: host = os.environ.get("STATSD_HOST") or "localhost" self.host = host if port is None: port = os.environ.get("STATSD_PORT") or 8125 self.port = int(port) # Socket self._socket = None self.lock = threading.Lock() self.max_buffer_size = max_buffer_size self._send = self._send_to_server self.encoding = "utf-8" # Tags self.constant_tags = {} tags_envvar = os.environ.get("STATSD_TAGS", "") for tag in tags_envvar.split(","): if not tag: continue if ":" not in tag: warnings.warn( "STATSD_TAGS needs to be in key:value format, " "%s invalid" % tag, UserWarning, ) continue k, v = tag.split(":", 1) # look for a possible env var substitution, using $NAME or ${NAME} format m = re.match(r"^[$]([{])?(?P\w+)(?(1)[}]|)$", v) if m: envvar = m.group("envvar") if envvar in os.environ: v = os.environ[envvar] self.constant_tags[k] = v if constant_tags: self.constant_tags.update( {str(k): str(v) for k, v in constant_tags.items()} ) # Namespace if namespace is not None: namespace = str(namespace) self.namespace = namespace def __enter__(self): self.open_buffer(self.max_buffer_size) return self def __exit__(self, type, value, traceback): self.close_buffer() def gauge(self, metric, value, tags=None, sample_rate=1): """ Record the value of a gauge, optionally setting a list of tags and a sample rate. >>> statsd.gauge('users.online', 123) >>> statsd.gauge('active.connections', 1001, tags={"protocol": "http"}) """ return self._report(metric, "g", value, tags, sample_rate) def increment(self, metric, value=1, tags=None, sample_rate=1): """ Increment a counter, optionally setting a value, tags and a sample rate. >>> statsd.increment('page.views') >>> statsd.increment('files.transferred', 124) """ self._report(metric, "c", value, tags, sample_rate) def decrement(self, metric, value=1, tags=None, sample_rate=1): """ Decrement a counter, optionally setting a value, tags and a sample rate. >>> statsd.decrement('files.remaining') >>> statsd.decrement('active.connections', 2) """ metric_value = -value if value else value self._report(metric, "c", metric_value, tags, sample_rate) def histogram(self, metric, value, tags=None, sample_rate=1): """ Sample a histogram value, optionally setting tags and a sample rate. >>> statsd.histogram('uploaded.file.size', 1445) >>> statsd.histogram('file.count', 26, tags={"filetype": "python"}) """ self._report(metric, "h", value, tags, sample_rate) def timing(self, metric, value, tags=None, sample_rate=1): """ Record a timing, optionally setting tags and a sample rate. >>> statsd.timing("query.response.time", 1234) """ self._report(metric, "ms", value, tags, sample_rate) def timed(self, metric=None, error_metric=None, tags=None, sample_rate=1): """ A decorator or context manager that will measure the distribution of a function's/context's run time. Optionally specify a list of tags or a sample rate. If the metric is not defined as a decorator, the module name and function name will be used. The metric is required as a context manager. :: @statsd.timed('user.query.time', sample_rate=0.5) def get_user(user_id): # Do what you need to ... pass # Is equivalent to ... with statsd.timed('user.query.time', sample_rate=0.5): # Do what you need to ... pass # Is equivalent to ... start = time.monotonic() try: get_user(user_id) finally: statsd.timing('user.query.time', time.monotonic() - start) """ return TimedContextManagerDecorator( statsd=self, metric=metric, error_metric=error_metric, tags=tags, sample_rate=sample_rate, ) def set(self, metric, value, tags=None, sample_rate=1): """ Sample a set value. >>> statsd.set('visitors.uniques', 999) """ self._report(metric, "s", value, tags, sample_rate) @property def socket(self): """ Return a connected socket. Note: connect the socket before assigning it to the class instance to avoid bad thread race conditions. """ with self.lock: if not self._socket: sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.connect((self.host, self.port)) self._socket = sock return self._socket def open_buffer(self, max_buffer_size=50): """ Open a buffer to send a batch of metrics in one packet. You can also use this as a context manager. >>> with Statsd() as batch: ... batch.gauge('users.online', 123) ... batch.gauge('active.connections', 1001) """ self.max_buffer_size = max_buffer_size self.buffer = [] self._send = self._send_to_buffer def close_buffer(self): """ Flush the buffer and switch back to single metric packets. """ self._send = self._send_to_server if self.buffer: # Only send packets if there are packets to send self._flush_buffer() def close_socket(self): """ Closes connected socket if connected. """ with self.lock: if self._socket: self._socket.close() self._socket = None def _report(self, metric, metric_type, value, tags, sample_rate): """ Create a metric packet and send it. """ if value is None: return if sample_rate != 1 and random() > sample_rate: return # Resolve the full tag list tags = self._add_constant_tags(tags) # Create/format the metric packet payload = "%s%s:%s|%s%s%s" % ( (self.namespace + ".") if self.namespace else "", metric, value, metric_type, ("|@" + str(sample_rate)) if sample_rate != 1 else "", ("|#" + ",".join("%s:%s" % (k, v) for (k, v) in sorted(tags.items()))) if tags else "", ) # Send it self._send(payload) def _send_to_server(self, packet): try: # If set, use socket directly self.socket.send(packet.encode("utf-8")) except socket.timeout: return except socket.error: log.debug( "Error submitting statsd packet." " Dropping the packet and closing the socket." ) self.close_socket() def _send_to_buffer(self, packet): self.buffer.append(packet) if len(self.buffer) >= self.max_buffer_size: self._flush_buffer() def _flush_buffer(self): self._send_to_server("\n".join(self.buffer)) self.buffer = [] def _add_constant_tags(self, tags): return { str(k): str(v) for k, v in itertools.chain( - self.constant_tags.items(), (tags if tags else {}).items(), + self.constant_tags.items(), + (tags if tags else {}).items(), ) } @contextmanager def status_gauge( self, metric_name: str, statuses: Collection[str], tags: Optional[Dict[str, str]] = None, ): """Context manager to keep track of status changes as a gauge In addition to the `metric_name` and `tags` arguments, it expects a list of `statuses` to declare which statuses are possible, and returns a callable as context manager. This callable takes ones of the possible statuses as argument. Typical usage would be: >>> with statsd.status_gauge( "metric_name", ["starting", "processing", "waiting"]) as set_status: set_status("starting") # ... set_status("waiting") # ... """ if tags is None: tags = {} current_status: Optional[str] = None # reset status gauges to make sure they do not "leak" for status in statuses: self.gauge(metric_name, 0, {**tags, "status": status}) def set_status(new_status: str): nonlocal current_status assert isinstance(tags, dict) if new_status not in statuses: raise ValueError(f"{new_status} not in {statuses}") if current_status and new_status != current_status: self.gauge(metric_name, 0, {**tags, "status": current_status}) current_status = new_status self.gauge(metric_name, 1, {**tags, "status": current_status}) yield set_status # reset gauges on exit for status in statuses: self.gauge(metric_name, 0, {**tags, "status": status}) statsd = Statsd() diff --git a/swh/core/tarball.py b/swh/core/tarball.py index ad6d377..d7b01e9 100644 --- a/swh/core/tarball.py +++ b/swh/core/tarball.py @@ -1,210 +1,202 @@ # Copyright (C) 2015-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import stat from subprocess import run import tarfile import zipfile import magic from . import utils def _unpack_tar(tarpath: str, extract_dir: str) -> str: """Unpack tarballs unsupported by the standard python library. Examples include tar.Z, tar.lz, tar.x, etc.... As this implementation relies on the `tar` command, this function supports the same compression the tar command supports. This expects the `extract_dir` to exist. Raises: shutil.ReadError in case of issue uncompressing the archive (tarpath does not exist, extract_dir does not exist, etc...) Returns: full path to the uncompressed directory. """ try: run(["tar", "xf", tarpath, "-C", extract_dir], check=True) return extract_dir except Exception as e: raise shutil.ReadError( f"Unable to uncompress {tarpath} to {extract_dir}. Reason: {e}" ) def _unpack_zip(zippath: str, extract_dir: str) -> str: """Unpack zip files unsupported by the standard python library, for instance those with legacy compression type 6 (implode). This expects the `extract_dir` to exist. Raises: shutil.ReadError in case of issue uncompressing the archive (zippath does not exist, extract_dir does not exist, etc...) Returns: full path to the uncompressed directory. """ try: run(["unzip", "-q", "-d", extract_dir, zippath], check=True) return extract_dir except Exception as e: raise shutil.ReadError( f"Unable to uncompress {zippath} to {extract_dir}. Reason: {e}" ) def register_new_archive_formats(): - """Register new archive formats to uncompress - - """ + """Register new archive formats to uncompress""" registered_formats = [f[0] for f in shutil.get_unpack_formats()] for name, extensions, function in ADDITIONAL_ARCHIVE_FORMATS: if name in registered_formats: continue shutil.register_unpack_format(name, extensions, function) _mime_to_archive_format = { "application/x-compress": "tar.Z|x", "application/x-tar": "tar", "application/x-bzip2": "bztar", "application/gzip": "gztar", "application/x-lzip": "tar.lz", "application/zip": "zip", } def uncompress(tarpath: str, dest: str): """Uncompress tarpath to dest folder if tarball is supported. Note that this fixes permissions after successfully uncompressing the archive. Args: tarpath: path to tarball to uncompress dest: the destination folder where to uncompress the tarball, it will be created if it does not exist Raises: ValueError when a problem occurs during unpacking """ try: os.makedirs(dest, exist_ok=True) format = None # try to get archive format from extension for format_, exts, _ in shutil.get_unpack_formats(): if any([tarpath.lower().endswith(ext.lower()) for ext in exts]): format = format_ break # try to get archive format from file mimetype if format is None: m = magic.Magic(mime=True) mime = m.from_file(tarpath) format = _mime_to_archive_format.get(mime) shutil.unpack_archive(tarpath, extract_dir=dest, format=format) except shutil.ReadError as e: raise ValueError(f"Problem during unpacking {tarpath}. Reason: {e}") except NotImplementedError: if tarpath.lower().endswith(".zip") or format == "zip": _unpack_zip(tarpath, dest) else: raise normalize_permissions(dest) def normalize_permissions(path: str): """Normalize the permissions of all files and directories under `path`. This makes all subdirectories and files with the user executable bit set mode 0o0755, and all other files mode 0o0644. Args: path: the path under which permissions should be normalized """ for dirpath, _, fnames in os.walk(path): os.chmod(dirpath, 0o0755) for fname in fnames: fpath = os.path.join(dirpath, fname) if not os.path.islink(fpath): is_executable = os.stat(fpath).st_mode & stat.S_IXUSR forced_mode = 0o0755 if is_executable else 0o0644 os.chmod(fpath, forced_mode) def _ls(rootdir): - """Generator of filepath, filename from rootdir. - - """ + """Generator of filepath, filename from rootdir.""" for dirpath, dirnames, fnames in os.walk(rootdir): for fname in dirnames + fnames: fpath = os.path.join(dirpath, fname) fname = utils.commonname(rootdir, fpath) yield fpath, fname def _compress_zip(tarpath, files): - """Compress dirpath's content as tarpath. - - """ + """Compress dirpath's content as tarpath.""" with zipfile.ZipFile(tarpath, "w") as z: for fpath, fname in files: z.write(fpath, arcname=fname) def _compress_tar(tarpath, files): - """Compress dirpath's content as tarpath. - - """ + """Compress dirpath's content as tarpath.""" with tarfile.open(tarpath, "w:bz2") as t: for fpath, fname in files: t.add(fpath, arcname=fname, recursive=False) def compress(tarpath, nature, dirpath_or_files): """Create a tarball tarpath with nature nature. The content of the tarball is either dirpath's content (if representing a directory path) or dirpath's iterable contents. Compress the directory dirpath's content to a tarball. The tarball being dumped at tarpath. The nature of the tarball is determined by the nature argument. """ if isinstance(dirpath_or_files, str): files = _ls(dirpath_or_files) else: # iterable of 'filepath, filename' files = dirpath_or_files if nature == "zip": _compress_zip(tarpath, files) else: _compress_tar(tarpath, files) return tarpath # Additional uncompression archive format support ADDITIONAL_ARCHIVE_FORMATS = [ # name, extensions, function ("tar.Z|x", [".tar.Z", ".tar.x"], _unpack_tar), ("jar", [".jar"], _unpack_zip), ("tbz2", [".tbz", "tbz2"], _unpack_tar), # FIXME: make this optional depending on the runtime lzip package install ("tar.lz", [".tar.lz"], _unpack_tar), ("crate", [".crate"], _unpack_tar), ] register_new_archive_formats() diff --git a/swh/core/tests/fixture/test_pytest_plugin.py b/swh/core/tests/fixture/test_pytest_plugin.py index dbabbd8..ef1cb65 100644 --- a/swh/core/tests/fixture/test_pytest_plugin.py +++ b/swh/core/tests/fixture/test_pytest_plugin.py @@ -1,24 +1,22 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import requests from .conftest import DATADIR # In this arborescence, we override in the local conftest.py module the # "datadir" fixture to specify where to retrieve the data files from. def test_requests_mock_datadir_with_datadir_fixture_override(requests_mock_datadir): - """Override datadir fixture should retrieve data from elsewhere - - """ + """Override datadir fixture should retrieve data from elsewhere""" response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"welcome": "you"} def test_data_dir_override(datadir): assert datadir == DATADIR diff --git a/swh/core/tests/test_cli.py b/swh/core/tests/test_cli.py index 0eb86e4..eaad6d0 100644 --- a/swh/core/tests/test_cli.py +++ b/swh/core/tests/test_cli.py @@ -1,345 +1,386 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import textwrap from typing import List from unittest.mock import patch import click from click.testing import CliRunner import pkg_resources import pytest help_msg_snippets = ( ( "Usage", ( "swh [OPTIONS] COMMAND [ARGS]...", "Command line interface for Software Heritage.", ), ), - ("Options", ("-l, --log-level", "--log-config", "--sentry-dsn", "-h, --help",)), + ( + "Options", + ( + "-l, --log-level", + "--log-config", + "--sentry-dsn", + "-h, --help", + ), + ), ) def get_section(cli_output: str, section: str) -> List[str]: """Get the given `section` of the `cli_output`""" result = [] in_section = False for line in cli_output.splitlines(): if not line: continue if in_section: if not line.startswith(" "): break else: if line.startswith(section): in_section = True if in_section: result.append(line) return result def assert_section_contains(cli_output: str, section: str, snippet: str) -> bool: """Check that a given `section` of the `cli_output` contains the given `snippet`""" section_lines = get_section(cli_output, section) assert section_lines, "Section %s not found in output %r" % (section, cli_output) for line in section_lines: if snippet in line: return True else: assert False, "%r not found in section %r of output %r" % ( snippet, section, cli_output, ) def test_swh_help(swhmain): runner = CliRunner() result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 for section, snippets in help_msg_snippets: for snippet in snippets: assert_section_contains(result.output, section, snippet) result = runner.invoke(swhmain, ["--help"]) assert result.exit_code == 0 for section, snippets in help_msg_snippets: for snippet in snippets: assert_section_contains(result.output, section, snippet) def test_command(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: result = runner.invoke(swhmain, ["test"]) sentry_sdk_init.assert_not_called() assert result.exit_code == 0 assert result.output.strip() == "Hello SWH!" def test_loglevel_default(caplog, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 20 click.echo("Hello SWH!") runner = CliRunner() result = runner.invoke(swhmain, ["test"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" def test_loglevel_error(caplog, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 40 click.echo("Hello SWH!") runner = CliRunner() result = runner.invoke(swhmain, ["-l", "ERROR", "test"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" def test_loglevel_debug(caplog, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 10 click.echo("Hello SWH!") runner = CliRunner() result = runner.invoke(swhmain, ["-l", "DEBUG", "test"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" def test_sentry(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: result = runner.invoke(swhmain, ["--sentry-dsn", "test_dsn", "test"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn="test_dsn", debug=False, integrations=[], release=None, environment=None, + dsn="test_dsn", + debug=False, + integrations=[], + release=None, + environment=None, ) def test_sentry_debug(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: result = runner.invoke( swhmain, ["--sentry-dsn", "test_dsn", "--sentry-debug", "test"] ) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn="test_dsn", debug=True, integrations=[], release=None, environment=None, + dsn="test_dsn", + debug=True, + integrations=[], + release=None, + environment=None, ) def test_sentry_env(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: env = { "SWH_SENTRY_DSN": "test_dsn", "SWH_SENTRY_DEBUG": "1", } result = runner.invoke(swhmain, ["test"], env=env, auto_envvar_prefix="SWH") assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn="test_dsn", debug=True, integrations=[], release=None, environment=None, + dsn="test_dsn", + debug=True, + integrations=[], + release=None, + environment=None, ) def test_sentry_env_main_package(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): click.echo("Hello SWH!") runner = CliRunner() with patch("sentry_sdk.init") as sentry_sdk_init: env = { "SWH_SENTRY_DSN": "test_dsn", "SWH_MAIN_PACKAGE": "swh.core", "SWH_SENTRY_ENVIRONMENT": "tests", } result = runner.invoke(swhmain, ["test"], env=env, auto_envvar_prefix="SWH") assert result.exit_code == 0 version = pkg_resources.get_distribution("swh.core").version assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( dsn="test_dsn", debug=False, integrations=[], release="swh.core@" + version, environment="tests", ) @pytest.fixture def log_config_path(tmp_path): log_config = textwrap.dedent( """\ --- version: 1 formatters: formatter: format: 'custom format:%(name)s:%(levelname)s:%(message)s' handlers: console: class: logging.StreamHandler stream: ext://sys.stdout formatter: formatter level: DEBUG root: level: DEBUG handlers: - console loggers: dontshowdebug: level: INFO """ ) (tmp_path / "log_config.yml").write_text(log_config) yield str(tmp_path / "log_config.yml") def test_log_config(log_config_path, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): logging.debug("Root log debug") logging.info("Root log info") logging.getLogger("dontshowdebug").debug("Not shown") logging.getLogger("dontshowdebug").info("Shown") runner = CliRunner() - result = runner.invoke(swhmain, ["--log-config", log_config_path, "test",],) + result = runner.invoke( + swhmain, + [ + "--log-config", + log_config_path, + "test", + ], + ) assert result.exit_code == 0 assert result.output.strip() == "\n".join( [ "custom format:root:DEBUG:Root log debug", "custom format:root:INFO:Root log info", "custom format:dontshowdebug:INFO:Shown", ] ) def test_log_config_log_level_interaction(log_config_path, swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): logging.debug("Root log debug") logging.info("Root log info") logging.getLogger("dontshowdebug").debug("Not shown") logging.getLogger("dontshowdebug").info("Shown") runner = CliRunner() result = runner.invoke( - swhmain, ["--log-config", log_config_path, "--log-level", "INFO", "test",], + swhmain, + [ + "--log-config", + log_config_path, + "--log-level", + "INFO", + "test", + ], ) assert result.exit_code == 0 assert result.output.strip() == "\n".join( [ "custom format:root:INFO:Root log info", "custom format:dontshowdebug:INFO:Shown", ] ) def test_multiple_log_level_behavior(swhmain): @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.getLevelName(logging.root.level) == "DEBUG" assert logging.getLevelName(logging.getLogger("dontshowdebug").level) == "INFO" return 0 runner = CliRunner() result = runner.invoke( - swhmain, ["--log-level", "DEBUG", "--log-level", "dontshowdebug:INFO", "test",] + swhmain, + [ + "--log-level", + "DEBUG", + "--log-level", + "dontshowdebug:INFO", + "test", + ], ) assert result.exit_code == 0, result.output def test_invalid_log_level(swhmain): runner = CliRunner() result = runner.invoke(swhmain, ["--log-level", "broken:broken:DEBUG"]) assert result.exit_code != 0 assert "Invalid log level specification" in result.output runner = CliRunner() result = runner.invoke(swhmain, ["--log-level", "UNKNOWN"]) assert result.exit_code != 0 assert "Log level UNKNOWN unknown" in result.output def test_aliased_command(swhmain): @swhmain.command(name="canonical-test") @click.pass_context def swhtest(ctx): "A test command." click.echo("Hello SWH!") swhmain.add_alias(swhtest, "othername") runner = CliRunner() # check we have only 'canonical-test' listed in the usage help msg result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert "canonical-test A test command." in result.output assert "othername" not in result.output # check we can execute the cmd with 'canonical-test' result = runner.invoke(swhmain, ["canonical-test"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" # check we can also execute the cmd with the alias 'othername' result = runner.invoke(swhmain, ["othername"]) assert result.exit_code == 0 assert result.output.strip() == """Hello SWH!""" diff --git a/swh/core/tests/test_config.py b/swh/core/tests/test_config.py index 8d02b2c..5f5005d 100644 --- a/swh/core/tests/test_config.py +++ b/swh/core/tests/test_config.py @@ -1,364 +1,367 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import pkg_resources.extern.packaging.version import pytest import yaml from swh.core import config pytest_v = pkg_resources.get_distribution("pytest").parsed_version if pytest_v < pkg_resources.extern.packaging.version.parse("3.9"): @pytest.fixture def tmp_path(): import pathlib import tempfile with tempfile.TemporaryDirectory() as tmpdir: yield pathlib.Path(tmpdir) default_conf = { "a": ("int", 2), "b": ("string", "default-string"), "c": ("bool", True), "d": ("int", 10), "e": ("int", None), "f": ("bool", None), "g": ("string", None), "h": ("bool", True), "i": ("bool", True), "ls": ("list[str]", ["a", "b", "c"]), "li": ("list[int]", [42, 43]), } other_default_conf = { "a": ("int", 3), } full_default_conf = default_conf.copy() full_default_conf["a"] = other_default_conf["a"] parsed_default_conf = {key: value for key, (type, value) in default_conf.items()} parsed_conffile = { "a": 1, "b": "this is a string", "c": True, "d": 10, "e": None, "f": None, "g": None, "h": False, "i": True, "ls": ["list", "of", "strings"], "li": [1, 2, 3, 4], } @pytest.fixture def swh_config(tmp_path): # create a temporary folder conffile = tmp_path / "config.yml" conf_contents = """ a: 1 b: this is a string c: true h: false ls: list, of, strings li: 1, 2, 3, 4 """ conffile.open("w").write(conf_contents) return conffile @pytest.fixture def swh_config_unreadable(swh_config): # Create an unreadable, proper configuration file os.chmod(str(swh_config), 0o000) yield swh_config # Make the broken perms file readable again to be able to remove them os.chmod(str(swh_config), 0o644) @pytest.fixture def swh_config_unreadable_dir(swh_config): # Create a proper configuration file in an unreadable directory perms_broken_dir = swh_config.parent / "unreadabledir" perms_broken_dir.mkdir() shutil.move(str(swh_config), str(perms_broken_dir)) os.chmod(str(perms_broken_dir), 0o000) yield perms_broken_dir / swh_config.name # Make the broken perms items readable again to be able to remove them os.chmod(str(perms_broken_dir), 0o755) @pytest.fixture def swh_config_empty(tmp_path): # create a temporary folder conffile = tmp_path / "config.yml" conffile.touch() return conffile def test_read(swh_config): # when res = config.read(str(swh_config), default_conf) # then assert res == parsed_conffile def test_read_no_default_conf(swh_config): """If no default config if provided to read, this should directly parse the config file yaml """ config_path = str(swh_config) actual_config = config.read(config_path) with open(config_path) as f: expected_config = yaml.safe_load(f) assert actual_config == expected_config def test_read_empty_file(): # when res = config.read(None, default_conf) # then assert res == parsed_default_conf def test_support_non_existing_conffile(tmp_path): # when res = config.read(str(tmp_path / "void.yml"), default_conf) # then assert res == parsed_default_conf def test_support_empty_conffile(swh_config_empty): # when res = config.read(str(swh_config_empty), default_conf) # then assert res == parsed_default_conf def test_raise_on_broken_directory_perms(swh_config_unreadable_dir): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable_dir), default_conf) def test_raise_on_broken_file_perms(swh_config_unreadable): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable), default_conf) def test_merge_default_configs(): # when res = config.merge_default_configs(default_conf, other_default_conf) # then assert res == full_default_conf def test_priority_read_nonexist_conf(swh_config): noexist = str(swh_config.parent / "void.yml") # when res = config.priority_read([noexist, str(swh_config)], default_conf) # then assert res == parsed_conffile def test_priority_read_conf_nonexist_empty(swh_config): noexist = swh_config.parent / "void.yml" empty = swh_config.parent / "empty.yml" empty.touch() # when res = config.priority_read( [str(p) for p in (swh_config, noexist, empty)], default_conf ) # then assert res == parsed_conffile def test_priority_read_empty_conf_nonexist(swh_config): noexist = swh_config.parent / "void.yml" empty = swh_config.parent / "empty.yml" empty.touch() # when res = config.priority_read( [str(p) for p in (empty, swh_config, noexist)], default_conf ) # then assert res == parsed_default_conf def test_swh_config_paths(): res = config.swh_config_paths("foo/bar.yml") assert res == [ "~/.config/swh/foo/bar.yml", "~/.swh/foo/bar.yml", "/etc/softwareheritage/foo/bar.yml", ] def test_prepare_folder(tmp_path): # given conf = { "path1": str(tmp_path / "path1"), "path2": str(tmp_path / "path2" / "depth1"), } # the folders does not exists assert not os.path.exists(conf["path1"]), "path1 should not exist." assert not os.path.exists(conf["path2"]), "path2 should not exist." # when config.prepare_folders(conf, "path1") # path1 exists but not path2 assert os.path.exists(conf["path1"]), "path1 should now exist!" assert not os.path.exists(conf["path2"]), "path2 should not exist." # path1 already exists, skips it but creates path2 config.prepare_folders(conf, "path1", "path2") assert os.path.exists(conf["path1"]), "path1 should still exist!" assert os.path.exists(conf["path2"]), "path2 should now exist." def test_merge_config(): cfg_a = { "a": 42, "b": [1, 2, 3], "c": None, "d": {"gheez": 27}, "e": { "ea": "Mr. Bungle", "eb": None, "ec": [11, 12, 13], "ed": {"eda": "Secret Chief 3", "edb": "Faith No More"}, "ee": 451, }, "f": "Janis", } cfg_b = { "a": 43, "b": [41, 42, 43], "c": "Tom Waits", "d": None, "e": { "ea": "Igorrr", "ec": [51, 52], "ed": {"edb": "Sleepytime Gorilla Museum", "edc": "Nils Peter Molvaer"}, }, "g": "Hüsker Dü", } # merge A, B cfg_m = config.merge_configs(cfg_a, cfg_b) assert cfg_m == { "a": 43, # b takes precedence "b": [41, 42, 43], # b takes precedence "c": "Tom Waits", # b takes precedence "d": None, # b['d'] takes precedence (explicit None) "e": { "ea": "Igorrr", # a takes precedence "eb": None, # only in a "ec": [51, 52], # b takes precedence "ed": { "eda": "Secret Chief 3", # only in a "edb": "Sleepytime Gorilla Museum", # b takes precedence "edc": "Nils Peter Molvaer", }, # only defined in b "ee": 451, }, "f": "Janis", # only defined in a "g": "Hüsker Dü", # only defined in b } # merge B, A cfg_m = config.merge_configs(cfg_b, cfg_a) assert cfg_m == { "a": 42, # a takes precedence "b": [1, 2, 3], # a takes precedence "c": None, # a takes precedence "d": {"gheez": 27}, # a takes precedence "e": { "ea": "Mr. Bungle", # a takes precedence "eb": None, # only defined in a "ec": [11, 12, 13], # a takes precedence "ed": { "eda": "Secret Chief 3", # only in a "edb": "Faith No More", # a takes precedence "edc": "Nils Peter Molvaer", }, # only in b "ee": 451, }, "f": "Janis", # only in a "g": "Hüsker Dü", # only in b } def test_merge_config_type_error(): for v in (1, "str", None): with pytest.raises(TypeError): config.merge_configs(v, {}) with pytest.raises(TypeError): config.merge_configs({}, v) for v in (1, "str"): with pytest.raises(TypeError): config.merge_configs({"a": v}, {"a": {}}) with pytest.raises(TypeError): config.merge_configs({"a": {}}, {"a": v}) def test_load_from_envvar_no_environment_var_swh_config_filename_set(): """Without SWH_CONFIG_FILENAME set, load_from_envvar raises""" with pytest.raises(AssertionError, match="SWH_CONFIG_FILENAME environment"): config.load_from_envvar() def test_load_from_envvar_no_default_config(swh_config, monkeypatch): config_path = str(swh_config) monkeypatch.setenv("SWH_CONFIG_FILENAME", config_path) actual_config = config.load_from_envvar() expected_config = config.read(config_path) assert actual_config == expected_config def test_load_from_envvar_with_default_config(swh_config, monkeypatch): default_config = { "number": 666, "something-cool": ["something", "cool"], } config_path = str(swh_config) monkeypatch.setenv("SWH_CONFIG_FILENAME", config_path) actual_config = config.load_from_envvar(default_config) expected_config = config.read(config_path) expected_config.update( - {"number": 666, "something-cool": ["something", "cool"],} + { + "number": 666, + "something-cool": ["something", "cool"], + } ) assert actual_config == expected_config diff --git a/swh/core/tests/test_pytest_plugin.py b/swh/core/tests/test_pytest_plugin.py index f8d23ef..399e695 100644 --- a/swh/core/tests/test_pytest_plugin.py +++ b/swh/core/tests/test_pytest_plugin.py @@ -1,117 +1,119 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from os import path from urllib.parse import unquote import requests from swh.core.pytest_plugin import requests_mock_datadir_factory def test_get_response_cb_with_encoded_url(requests_mock_datadir): # The following urls (quoted, unquoted) will be resolved as the same file for encoded_url, expected_response in [ ("https://forge.s.o/api/diffusion?attachments%5Buris%5D=1", "something"), ( "https://www.reference.com/web?q=What+Is+an+Example+of+a+URL?&qo=contentPageRelatedSearch&o=600605&l=dir&sga=1", # noqa "something else", ), ]: for url in [encoded_url, unquote(encoded_url)]: response = requests.get(url) assert response.ok assert response.json() == expected_response def test_get_response_cb_with_visits_nominal(requests_mock_datadir_visits): response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "you"} response = requests.get("http://example.com/something.json") assert response.ok assert response.json() == "something" response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "world"} response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_get_response_cb_with_visits(requests_mock_datadir_visits): response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "you"} response = requests.get("https://example.com/other.json") assert response.ok assert response.json() == "foobar" response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "world"} response = requests.get("https://example.com/other.json") assert not response.ok assert response.status_code == 404 response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_get_response_cb_no_visit(requests_mock_datadir): response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "you"} response = requests.get("https://example.com/file.json") assert response.ok assert response.json() == {"hello": "you"} def test_get_response_cb_query_params(requests_mock_datadir): response = requests.get("https://example.com/file.json?toto=42") assert not response.ok assert response.status_code == 404 response = requests.get("https://example.com/file.json?name=doe&firstname=jane") assert response.ok assert response.json() == {"hello": "jane doe"} requests_mock_datadir_ignore = requests_mock_datadir_factory( - ignore_urls=["https://example.com/file.json"], has_multi_visit=False, + ignore_urls=["https://example.com/file.json"], + has_multi_visit=False, ) def test_get_response_cb_ignore_url(requests_mock_datadir_ignore): response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 requests_mock_datadir_ignore_and_visit = requests_mock_datadir_factory( - ignore_urls=["https://example.com/file.json"], has_multi_visit=True, + ignore_urls=["https://example.com/file.json"], + has_multi_visit=True, ) def test_get_response_cb_ignore_url_with_visit(requests_mock_datadir_ignore_and_visit): response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_data_dir(datadir): expected_datadir = path.join(path.abspath(path.dirname(__file__)), "data") assert datadir == expected_datadir diff --git a/swh/core/tests/test_tarball.py b/swh/core/tests/test_tarball.py index 1cf9c08..a774477 100644 --- a/swh/core/tests/test_tarball.py +++ b/swh/core/tests/test_tarball.py @@ -1,250 +1,240 @@ # Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import hashlib import os import shutil import pytest from swh.core import tarball @pytest.fixture def prepare_shutil_state(): - """Reset any shutil modification in its current state - - """ + """Reset any shutil modification in its current state""" import shutil registered_formats = [f[0] for f in shutil.get_unpack_formats()] for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS: name = format_id[0] if name in registered_formats: shutil.unregister_unpack_format(name) return shutil def test_compress_uncompress_zip(tmp_path): tocompress = tmp_path / "compressme" tocompress.mkdir() for i in range(10): fpath = tocompress / ("file%s.txt" % i) fpath.write_text("content of file %s" % i) zipfile = tmp_path / "archive.zip" tarball.compress(str(zipfile), "zip", str(tocompress)) destdir = tmp_path / "destdir" tarball.uncompress(str(zipfile), str(destdir)) lsdir = sorted(x.name for x in destdir.iterdir()) assert ["file%s.txt" % i for i in range(10)] == lsdir @pytest.mark.xfail( reason=( "Python's zipfile library doesn't support Info-ZIP's " "extension for file permissions." ) ) def test_compress_uncompress_zip_modes(tmp_path): tocompress = tmp_path / "compressme" tocompress.mkdir() fpath = tocompress / "text.txt" fpath.write_text("echo foo") fpath.chmod(0o644) fpath = tocompress / "executable.sh" fpath.write_text("echo foo") fpath.chmod(0o755) zipfile = tmp_path / "archive.zip" tarball.compress(str(zipfile), "zip", str(tocompress)) destdir = tmp_path / "destdir" tarball.uncompress(str(zipfile), str(destdir)) (executable_path, text_path) = sorted(destdir.iterdir()) assert text_path.stat().st_mode == 0o100644 # succeeds, it's the default assert executable_path.stat().st_mode == 0o100755 # fails def test_compress_uncompress_tar(tmp_path): tocompress = tmp_path / "compressme" tocompress.mkdir() for i in range(10): fpath = tocompress / ("file%s.txt" % i) fpath.write_text("content of file %s" % i) tarfile = tmp_path / "archive.tar" tarball.compress(str(tarfile), "tar", str(tocompress)) destdir = tmp_path / "destdir" tarball.uncompress(str(tarfile), str(destdir)) lsdir = sorted(x.name for x in destdir.iterdir()) assert ["file%s.txt" % i for i in range(10)] == lsdir def test_compress_uncompress_tar_modes(tmp_path): tocompress = tmp_path / "compressme" tocompress.mkdir() fpath = tocompress / "text.txt" fpath.write_text("echo foo") fpath.chmod(0o644) fpath = tocompress / "executable.sh" fpath.write_text("echo foo") fpath.chmod(0o755) tarfile = tmp_path / "archive.tar" tarball.compress(str(tarfile), "tar", str(tocompress)) destdir = tmp_path / "destdir" tarball.uncompress(str(tarfile), str(destdir)) (executable_path, text_path) = sorted(destdir.iterdir()) assert text_path.stat().st_mode == 0o100644 assert executable_path.stat().st_mode == 0o100755 def test_uncompress_tar_failure(tmp_path, datadir): - """Unpack inexistent tarball should fail - - """ + """Unpack inexistent tarball should fail""" tarpath = os.path.join(datadir, "archives", "inexistent-archive.tar.Z") assert not os.path.exists(tarpath) with pytest.raises(ValueError, match="Problem during unpacking"): tarball.uncompress(tarpath, tmp_path) def test_uncompress_tar(tmp_path, datadir): - """Unpack supported tarball into an existent folder should be ok - - """ + """Unpack supported tarball into an existent folder should be ok""" filename = "groff-1.02.tar.Z" tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) extract_dir = os.path.join(tmp_path, filename) tarball.uncompress(tarpath, extract_dir) assert len(os.listdir(extract_dir)) > 0 def test_register_new_archive_formats(prepare_shutil_state): - """Registering new archive formats should be fine - - """ + """Registering new archive formats should be fine""" unpack_formats_v1 = [f[0] for f in shutil.get_unpack_formats()] for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS: assert format_id[0] not in unpack_formats_v1 # when tarball.register_new_archive_formats() # then unpack_formats_v2 = [f[0] for f in shutil.get_unpack_formats()] for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS: assert format_id[0] in unpack_formats_v2 def test_uncompress_archives(tmp_path, datadir): - """High level call uncompression on supported archives - - """ + """High level call uncompression on supported archives""" archive_dir = os.path.join(datadir, "archives") archive_files = os.listdir(archive_dir) for archive_file in archive_files: archive_path = os.path.join(archive_dir, archive_file) extract_dir = os.path.join(tmp_path, archive_file) tarball.uncompress(archive_path, dest=extract_dir) assert len(os.listdir(extract_dir)) > 0 def test_normalize_permissions(tmp_path): for perms in range(0o1000): filename = str(perms) file_path = tmp_path / filename file_path.touch() file_path.chmod(perms) for file in tmp_path.iterdir(): assert file.stat().st_mode == 0o100000 | int(file.name) tarball.normalize_permissions(str(tmp_path)) for file in tmp_path.iterdir(): if int(file.name) & 0o100: # original file was executable for its owner assert file.stat().st_mode == 0o100755 else: assert file.stat().st_mode == 0o100644 def test_unpcompress_zip_imploded(tmp_path, datadir): """Unpack a zip archive with compression type 6 (implode), not supported by python zipfile module. """ filename = "msk316src.zip" zippath = os.path.join(datadir, "archives", filename) assert os.path.exists(zippath) extract_dir = os.path.join(tmp_path, filename) tarball.uncompress(zippath, extract_dir) assert len(os.listdir(extract_dir)) > 0 def test_uncompress_upper_archive_extension(tmp_path, datadir): """Copy test archives in a temporary directory but turn their names to uppercase, then check they can be successfully extracted. """ archives_path = os.path.join(datadir, "archives") archive_files = [ f for f in os.listdir(archives_path) if os.path.isfile(os.path.join(archives_path, f)) ] for archive_file in archive_files: archive_file_upper = os.path.join(tmp_path, archive_file.upper()) extract_dir = os.path.join(tmp_path, archive_file) shutil.copy(os.path.join(archives_path, archive_file), archive_file_upper) tarball.uncompress(archive_file_upper, extract_dir) assert len(os.listdir(extract_dir)) > 0 def test_uncompress_archive_no_extension(tmp_path, datadir): """Copy test archives in a temporary directory but turn their names to their md5 sums, then check they can be successfully extracted. """ archives_path = os.path.join(datadir, "archives") archive_files = [ f for f in os.listdir(archives_path) if os.path.isfile(os.path.join(archives_path, f)) ] for archive_file in archive_files: archive_file_path = os.path.join(archives_path, archive_file) with open(archive_file_path, "rb") as f: md5sum = hashlib.md5(f.read()).hexdigest() archive_file_md5sum = os.path.join(tmp_path, md5sum) extract_dir = os.path.join(tmp_path, archive_file) shutil.copy(archive_file_path, archive_file_md5sum) tarball.uncompress(archive_file_md5sum, extract_dir) assert len(os.listdir(extract_dir)) > 0 diff --git a/swh/core/utils.py b/swh/core/utils.py index e65ed16..7e44c76 100644 --- a/swh/core/utils.py +++ b/swh/core/utils.py @@ -1,191 +1,189 @@ # Copyright (C) 2016-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import codecs from contextlib import contextmanager import itertools import os import re from typing import Iterable, Tuple, TypeVar @contextmanager def cwd(path): """Contextually change the working directory to do thy bidding. Then gets back to the original location. """ prev_cwd = os.getcwd() os.chdir(path) try: yield finally: os.chdir(prev_cwd) def grouper(iterable, n): """ Collect data into fixed-length size iterables. The last block might contain less elements as it will hold only the remaining number of elements. The invariant here is that the number of elements in the input iterable and the sum of the number of elements of all iterables generated from this function should be equal. If ``iterable`` is an iterable of bytes or strings that you need to join later, then :func:`iter_chunks`` is preferable, as it avoids this join by slicing directly. Args: iterable (Iterable): an iterable n (int): size of block to slice the iterable into Yields: fixed-length blocks as iterables. As mentioned, the last iterable might be less populated. """ args = [iter(iterable)] * n stop_value = object() for _data in itertools.zip_longest(*args, fillvalue=stop_value): yield (d for d in _data if d is not stop_value) TStr = TypeVar("TStr", bytes, str) def iter_chunks( iterable: Iterable[TStr], chunk_size: int, *, remainder: bool = False ) -> Iterable[TStr]: """ Reads ``bytes`` objects (resp. ``str`` objects) from the ``iterable``, and yields them as chunks of exactly ``chunk_size`` bytes (resp. characters). ``iterable`` is typically obtained by repeatedly calling a method like :meth:`io.RawIOBase.read`; which does only guarantees an upper bound on the size; whereas this function returns chunks of exactly the size. Args: iterable: the input data chunk_size: the exact size of chunks to return remainder: if True, a last chunk with size strictly smaller than ``chunk_size`` may be returned, if the data stream from the ``iterable`` had a length that is not a multiple of ``chunk_size`` """ buf = None iterator = iter(iterable) while True: assert buf is None or len(buf) < chunk_size try: new_data = next(iterator) except StopIteration: if remainder and buf: yield buf # may be shorter than ``chunk_size`` return if buf: buf += new_data else: # spares a copy buf = new_data new_buf = None for i in range(0, len(buf), chunk_size): chunk = buf[i : i + chunk_size] if len(chunk) == chunk_size: yield chunk else: assert not new_buf new_buf = chunk buf = new_buf def backslashescape_errors(exception): if isinstance(exception, UnicodeDecodeError): bad_data = exception.object[exception.start : exception.end] escaped = "".join(r"\x%02x" % x for x in bad_data) return escaped, exception.end return codecs.backslashreplace_errors(exception) codecs.register_error("backslashescape", backslashescape_errors) def encode_with_unescape(value): """Encode an unicode string containing \\x backslash escapes""" slices = [] start = 0 odd_backslashes = False i = 0 while i < len(value): if value[i] == "\\": odd_backslashes = not odd_backslashes else: if odd_backslashes: if value[i] != "x": raise ValueError( "invalid escape for %r at position %d" % (value, i - 1) ) slices.append( value[start : i - 1].replace("\\\\", "\\").encode("utf-8") ) slices.append(bytes.fromhex(value[i + 1 : i + 3])) odd_backslashes = False start = i = i + 3 continue i += 1 slices.append(value[start:i].replace("\\\\", "\\").encode("utf-8")) return b"".join(slices) def decode_with_escape(value): """Decode a bytestring as utf-8, escaping the bytes of invalid utf-8 sequences as \\x. We also escape NUL bytes as they are invalid in JSON strings. """ # escape backslashes value = value.replace(b"\\", b"\\\\") value = value.replace(b"\x00", b"\\x00") return value.decode("utf-8", "backslashescape") def commonname(path0, path1, as_str=False): - """Compute the commonname between the path0 and path1. - - """ + """Compute the commonname between the path0 and path1.""" return path1.split(path0)[1] def numfile_sortkey(fname: str) -> Tuple[int, str]: """Simple function to sort filenames of the form: nnxxx.ext where nn is a number according to the numbers. Returns a tuple (order, remaining), where 'order' is the numeric (int) value extracted from the file name, and 'remaining' is the remaining part of the file name. Typically used to sort sql/nn-swh-xxx.sql files. Unmatched file names will return 999999 as order value. """ m = re.match(r"(\d*)(.*)", fname) assert m is not None num, rem = m.groups() return (int(num) if num else 999999, rem) def basename_sortkey(fname: str) -> Tuple[int, str]: "like numfile_sortkey but on basenames" return numfile_sortkey(os.path.basename(fname))