Page MenuHomeSoftware Heritage

No OneTemporary

diff --git a/PKG-INFO b/PKG-INFO
index 0487602..31b8657 100644
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,42 +1,42 @@
Metadata-Version: 2.1
Name: swh.core
-Version: 2.2.0
+Version: 2.2.1
Summary: Software Heritage core utilities
Home-page: https://forge.softwareheritage.org/diffusion/DCORE/
Author: Software Heritage developers
Author-email: swh-devel@inria.fr
License: UNKNOWN
Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest
Project-URL: Funding, https://www.softwareheritage.org/donate
Project-URL: Source, https://forge.softwareheritage.org/source/swh-core
Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
Classifier: Operating System :: OS Independent
Classifier: Development Status :: 5 - Production/Stable
Requires-Python: >=3.7
Description-Content-Type: text/x-rst
Provides-Extra: testing-core
Provides-Extra: logging
Provides-Extra: db
Provides-Extra: http
Provides-Extra: testing
License-File: LICENSE
License-File: AUTHORS
Software Heritage - Core foundations
====================================
Low-level utilities and helpers used by almost all other modules in the stack.
core library for swh's modules:
- config parser
- serialization
- logging mechanism
- database connection
- http-based RPC client/server
diff --git a/mypy.ini b/mypy.ini
index 79e3c0d..4718366 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,51 +1,48 @@
[mypy]
namespace_packages = True
warn_unused_ignores = True
# 3rd party libraries without stubs (yet)
[mypy-aiohttp_utils.*]
ignore_missing_imports = True
[mypy-arrow.*]
ignore_missing_imports = True
[mypy-celery.*]
ignore_missing_imports = True
[mypy-decorator.*]
ignore_missing_imports = True
[mypy-deprecated.*]
ignore_missing_imports = True
[mypy-django.*] # false positive, only used by hypothesis' extras
ignore_missing_imports = True
[mypy-iso8601.*]
ignore_missing_imports = True
[mypy-magic.*]
ignore_missing_imports = True
[mypy-msgpack.*]
ignore_missing_imports = True
[mypy-pkg_resources.*]
ignore_missing_imports = True
-[mypy-psycopg2.*]
-ignore_missing_imports = True
-
[mypy-pytest.*]
ignore_missing_imports = True
[mypy-pytest_postgresql.*]
ignore_missing_imports = True
[mypy-requests_mock.*]
ignore_missing_imports = True
[mypy-systemd.*]
ignore_missing_imports = True
diff --git a/requirements-test.txt b/requirements-test.txt
index 681761e..a524054 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,11 +1,12 @@
hypothesis >= 3.11.0
pytest < 7.0.0 # v7.0.0 removed _pytest.tmpdir.TempdirFactory, which is used by some of the pytest plugins we use
pytest-mock
pytz
requests-mock
types-click
types-flask
+types-psycopg2
types-pytz
types-pyyaml
types-requests
diff --git a/swh.core.egg-info/PKG-INFO b/swh.core.egg-info/PKG-INFO
index 0487602..31b8657 100644
--- a/swh.core.egg-info/PKG-INFO
+++ b/swh.core.egg-info/PKG-INFO
@@ -1,42 +1,42 @@
Metadata-Version: 2.1
Name: swh.core
-Version: 2.2.0
+Version: 2.2.1
Summary: Software Heritage core utilities
Home-page: https://forge.softwareheritage.org/diffusion/DCORE/
Author: Software Heritage developers
Author-email: swh-devel@inria.fr
License: UNKNOWN
Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest
Project-URL: Funding, https://www.softwareheritage.org/donate
Project-URL: Source, https://forge.softwareheritage.org/source/swh-core
Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
Classifier: Operating System :: OS Independent
Classifier: Development Status :: 5 - Production/Stable
Requires-Python: >=3.7
Description-Content-Type: text/x-rst
Provides-Extra: testing-core
Provides-Extra: logging
Provides-Extra: db
Provides-Extra: http
Provides-Extra: testing
License-File: LICENSE
License-File: AUTHORS
Software Heritage - Core foundations
====================================
Low-level utilities and helpers used by almost all other modules in the stack.
core library for swh's modules:
- config parser
- serialization
- logging mechanism
- database connection
- http-based RPC client/server
diff --git a/swh.core.egg-info/requires.txt b/swh.core.egg-info/requires.txt
index 34f80b6..45c1ee1 100644
--- a/swh.core.egg-info/requires.txt
+++ b/swh.core.egg-info/requires.txt
@@ -1,57 +1,59 @@
click
deprecated
python-magic
pyyaml
sentry-sdk
[db]
psycopg2
typing-extensions
pytest-postgresql<4.0.0,>=3
[http]
aiohttp
aiohttp_utils>=3.1.1
blinker
flask
iso8601
msgpack>=1.0.0
requests
[logging]
systemd-python
[testing]
hypothesis>=3.11.0
pytest<7.0.0
pytest-mock
pytz
requests-mock
types-click
types-flask
+types-psycopg2
types-pytz
types-pyyaml
types-requests
psycopg2
typing-extensions
pytest-postgresql<4.0.0,>=3
aiohttp
aiohttp_utils>=3.1.1
blinker
flask
iso8601
msgpack>=1.0.0
requests
systemd-python
[testing-core]
hypothesis>=3.11.0
pytest<7.0.0
pytest-mock
pytz
requests-mock
types-click
types-flask
+types-psycopg2
types-pytz
types-pyyaml
types-requests
diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
index b96aed6..5e416cc 100644
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -1,500 +1,500 @@
# Copyright (C) 2015-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from collections import abc
import functools
import inspect
import logging
import pickle
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from deprecated import deprecated
from flask import Flask, Request, Response, abort, request
import requests
from werkzeug.exceptions import HTTPException
from .negotiation import Formatter as FormatterBase
from .negotiation import Negotiator as NegotiatorBase
from .negotiation import negotiate as _negotiate
from .serializers import (
exception_to_dict,
json_dumps,
json_loads,
msgpack_dumps,
msgpack_loads,
)
from .serializers import decode_response
from .serializers import encode_data_client as encode_data
logger = logging.getLogger(__name__)
# support for content negotiation
class Negotiator(NegotiatorBase):
    """Content negotiator picking a response format from the request's
    Accept header."""

    def best_mimetype(self):
        # Fall back to JSON when nothing in the Accept header matches.
        return request.accept_mimetypes.best_match(
            self.accept_mimetypes, "application/json"
        )

    def _abort(self, status_code, err=None):
        # Delegate negotiation failures to Flask's abort().
        return abort(status_code, err)
def negotiate(formatter_cls, *args, **kwargs):
    # Convenience wrapper binding our Negotiator class to the generic
    # negotiation decorator.
    return _negotiate(Negotiator, formatter_cls, *args, **kwargs)
class Formatter(FormatterBase):
    """Base response formatter turning a rendered body into a Flask Response."""

    def _make_response(self, body, content_type):
        return Response(body, content_type=content_type)

    def configure(self, extra_encoders=None):
        # Extra (type, tag, encoder) triplets forwarded to the serializers.
        self.extra_encoders = extra_encoders
class JSONFormatter(Formatter):
    """Formatter rendering response payloads as JSON."""

    format = "json"
    mimetypes = ["application/json"]

    def render(self, obj):
        return json_dumps(obj, extra_encoders=self.extra_encoders)
class MsgpackFormatter(Formatter):
    """Formatter rendering response payloads as msgpack."""

    format = "msgpack"
    mimetypes = ["application/x-msgpack"]

    def render(self, obj):
        return msgpack_dumps(obj, extra_encoders=self.extra_encoders)
# base API classes
class RemoteException(Exception):
    """raised when remote returned an out-of-band failure notification, e.g., as a
    HTTP status code or serialized exception

    Attributes:
        response: HTTP response corresponding to the failure

    """

    def __init__(
        self,
        payload: Optional[Any] = None,
        response: Optional[requests.Response] = None,
    ):
        # Only forward the payload to Exception when one was given, so
        # that ``self.args`` stays empty otherwise.
        if payload is None:
            super().__init__()
        else:
            super().__init__(payload)
        self.response = response

    def __str__(self):
        payload = self.args[0] if self.args else None
        # Serialized server-side exceptions arrive as a dict with "type"
        # and "args" keys; render those in a compact, readable form.
        if isinstance(payload, dict) and "type" in payload and "args" in payload:
            return (
                f"<RemoteException {self.response.status_code} "
                f'{payload["type"]}: {payload["args"]}>'
            )
        return super().__str__()
# Type variable preserving the decorated callable's exact type.
F = TypeVar("F", bound=Callable)


def remote_api_endpoint(path: str, method: str = "POST") -> Callable[[F], F]:
    """Mark a backend method as a remote API endpoint.

    Attaches the routing metadata (`path`, HTTP `method`) that
    MetaRPCClient and RPCServerApp later use to wire the method to HTTP.
    """

    def decorator(func: F) -> F:
        func._endpoint_path = path  # type: ignore
        func._method = method  # type: ignore
        return func

    return decorator
class APIError(Exception):
    """API Error"""

    def __str__(self):
        # Render the constructor arguments to ease debugging of backend
        # failures.
        return f"An unexpected error occurred in the backend: {self.args}"
class MetaRPCClient(type):
    """Metaclass for RPCClient, which adds a method for each endpoint
    of the database it is designed to access.

    See for example :class:`swh.indexer.storage.api.client.RemoteStorage`"""

    def __new__(cls, name, bases, attributes):
        # For each method wrapped with @remote_api_endpoint in an API backend
        # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new
        # method in RemoteStorage, with the same documentation.
        #
        # Note that, despite the usage of decorator magic (eg. functools.wrap),
        # this never actually calls an IndexerStorage method.
        #
        # Look for the backend class on the class being created first, then
        # on its bases; the first one found wins.
        backend_class = attributes.get("backend_class", None)
        if backend_class is None:
            for base in bases:
                backend_class = getattr(base, "backend_class", None)
                if backend_class is not None:
                    break
        if backend_class:
            for attr_name, attr in backend_class.__dict__.items():
                if hasattr(attr, "_endpoint_path"):
                    cls.__add_endpoint(attr_name, attr, attributes)
        return super().__new__(cls, name, bases, attributes)

    @staticmethod
    def __add_endpoint(meth_name, meth, attributes):
        unwrapped = inspect.unwrap(meth)

        @functools.wraps(meth)  # Copy signature and doc
        def rpc_method(*args, **kwargs):
            # Bind the call's arguments to the backend method's parameters.
            call_args = inspect.getcallargs(unwrapped, *args, **kwargs)
            # Strip parameters that must not travel over HTTP.
            client = call_args.pop("self")
            call_args.pop("cur", None)
            call_args.pop("db", None)
            # Send the request.
            return client._post(meth._endpoint_path, call_args)

        # Never shadow a method explicitly defined on the client class.
        if meth_name not in attributes:
            attributes[meth_name] = rpc_method
class RPCClient(metaclass=MetaRPCClient):
    """Proxy to an internal SWH RPC

    """

    backend_class = None  # type: ClassVar[Optional[type]]
    """For each method of `backend_class` decorated with
    :func:`remote_api_endpoint`, a method with the same prototype and
    docstring will be added to this class. Calls to this new method will
    be translated into HTTP requests to a remote server.

    This backend class will never be instantiated, it only serves as
    a template."""

    api_exception = APIError  # type: ClassVar[Type[Exception]]
    """The exception class to raise in case of communication error with
    the server."""

    reraise_exceptions: ClassVar[List[Type[Exception]]] = []
    """On server errors, if any of the exception classes in this list
    has the same name as the error name, then the exception will
    be instantiated and raised instead of a generic RemoteException."""

    extra_type_encoders: List[Tuple[type, str, Callable]] = []
    """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps`
    to be able to serialize more object types."""

    extra_type_decoders: Dict[str, Callable] = {}
    """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads`
    to be able to deserialize more object types."""

    def __init__(
        self,
        url,
        api_exception=None,
        timeout=None,
        chunk_size=4096,
        reraise_exceptions=None,
        **kwargs,
    ):
        # Per-instance overrides of the class-level defaults.
        if api_exception:
            self.api_exception = api_exception
        if reraise_exceptions:
            self.reraise_exceptions = reraise_exceptions
        # Normalize the base URL so endpoint paths can simply be appended.
        base_url = url if url.endswith("/") else url + "/"
        self.url = base_url
        self.session = requests.Session()
        # Retry/connection-pool settings are tunable through kwargs.
        adapter = requests.adapters.HTTPAdapter(
            max_retries=kwargs.get("max_retries", 3),
            pool_connections=kwargs.get("pool_connections", 20),
            pool_maxsize=kwargs.get("pool_maxsize", 100),
        )
        self.session.mount(self.url, adapter)

        self.timeout = timeout
        self.chunk_size = chunk_size

    def _url(self, endpoint):
        # Full URL for an endpoint path (self.url always ends with "/").
        return "%s%s" % (self.url, endpoint)

    def raw_verb(self, verb, endpoint, **opts):
        """Perform a raw HTTP request (`verb`) on `endpoint`, wrapping
        connection failures in `self.api_exception`."""
        if "chunk_size" in opts:
            # if the chunk_size argument has been passed, consider the user
            # also wants stream=True, otherwise, what's the point.
            opts["stream"] = True
        if self.timeout and "timeout" not in opts:
            opts["timeout"] = self.timeout
        try:
            return getattr(self.session, verb)(self._url(endpoint), **opts)
        except requests.exceptions.ConnectionError as e:
            raise self.api_exception(e)

    def _post(self, endpoint, data, **opts):
        """POST msgpack-encoded `data` to `endpoint` and return the decoded
        response (or an iterator over chunks when streaming)."""
        if isinstance(data, (abc.Iterator, abc.Generator)):
            # Encode lazily so large payloads are streamed chunk by chunk.
            data = (self._encode_data(x) for x in data)
        else:
            data = self._encode_data(data)
        chunk_size = opts.pop("chunk_size", self.chunk_size)
        response = self.raw_verb(
            "post",
            endpoint,
            data=data,
            headers={
                "content-type": "application/x-msgpack",
                "accept": "application/x-msgpack",
            },
            **opts,
        )
        if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked":
            self.raise_for_status(response)
            return response.iter_content(chunk_size)
        else:
            return self._decode_response(response)

    def _encode_data(self, data):
        return encode_data(data, extra_encoders=self.extra_type_encoders)

    # Streaming POST shares the implementation with _post.
    _post_stream = _post

    @deprecated(version="2.1.0", reason="Use _post instead")
    def post(self, *args, **kwargs):
        return self._post(*args, **kwargs)

    @deprecated(version="2.1.0", reason="Use _post_stream instead")
    def post_stream(self, *args, **kwargs):
        return self._post_stream(*args, **kwargs)

    def _get(self, endpoint, **opts):
        """GET `endpoint` and return the decoded response (or an iterator
        over chunks when streaming)."""
        chunk_size = opts.pop("chunk_size", self.chunk_size)
        response = self.raw_verb(
            "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts
        )
        if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked":
            self.raise_for_status(response)
            return response.iter_content(chunk_size)
        else:
            return self._decode_response(response)

    def _get_stream(self, endpoint, **opts):
        return self._get(endpoint, stream=True, **opts)

    @deprecated(version="2.1.0", reason="Use _get instead")
    def get(self, *args, **kwargs):
        return self._get(*args, **kwargs)

    @deprecated(version="2.1.0", reason="Use _get_stream instead")
    def get_stream(self, *args, **kwargs):
        return self._get_stream(*args, **kwargs)

    def raise_for_status(self, response) -> None:
        """check response HTTP status code and raise an exception if it denotes an
        error; do nothing otherwise

        """
        status_code = response.status_code
        status_class = response.status_code // 100

        if status_code == 404:
            raise RemoteException(payload="404 not found", response=response)

        exception = None

        # TODO: only old servers send pickled error; stop trying to unpickle
        # after they are all upgraded
        # NOTE(review): pickle.loads on server-provided bytes assumes the
        # server is trusted — confirm this is only used against SWH servers.
        try:
            if status_class == 4:
                data = self._decode_response(response, check_status=False)
                if isinstance(data, dict):
                    # TODO: remove "exception" key check once all servers
                    # are using new schema
                    exc_data = data["exception"] if "exception" in data else data
                    # Re-raise whitelisted exception types by name.
                    for exc_type in self.reraise_exceptions:
                        if exc_type.__name__ == exc_data["type"]:
                            exception = exc_type(*exc_data["args"])
                            break
                    else:
                        exception = RemoteException(payload=exc_data, response=response)
                else:
                    exception = pickle.loads(data)

            elif status_class == 5:
                data = self._decode_response(response, check_status=False)
                if "exception_pickled" in data:
                    exception = pickle.loads(data["exception_pickled"])
                else:
                    # TODO: remove "exception" key check once all servers
                    # are using new schema
                    exc_data = data["exception"] if "exception" in data else data
                    exception = RemoteException(payload=exc_data, response=response)

        except (TypeError, pickle.UnpicklingError):
            # NOTE(review): if _decode_response itself raises TypeError,
            # `data` may be unbound here — confirm decode errors cannot
            # raise TypeError before the assignment.
            raise RemoteException(payload=data, response=response)

        if exception:
            raise exception from None

        if status_class != 2:
            raise RemoteException(
                payload=f"API HTTP error: {status_code} {response.content}",
                response=response,
            )

    def _decode_response(self, response, check_status=True):
        if check_status:
            self.raise_for_status(response)
        return decode_response(response, extra_decoders=self.extra_type_decoders)

    def __repr__(self):
        return "<{} url={}>".format(self.__class__.__name__, self.url)
class BytesRequest(Request):
    """Request with proper escaping of arbitrary byte sequences."""

    # Decode request data as UTF-8, mapping undecodable bytes to lone
    # surrogates so arbitrary binary payloads round-trip losslessly.
    encoding = "utf-8"
    encoding_errors = "surrogateescape"
# Maps a response content type to the serializer used by encode_data_server().
ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = {
    "application/x-msgpack": msgpack_dumps,
    "application/json": json_dumps,
}
def encode_data_server(
    data, content_type="application/x-msgpack", extra_type_encoders=None
):
    """Serialize `data` with the encoder registered for `content_type` and
    wrap it in a Flask Response with that mimetype.

    Raises KeyError for a content type with no registered encoder.
    """
    serializer = ENCODERS[content_type]
    body = serializer(data, extra_encoders=extra_type_encoders)
    return Response(body, mimetype=content_type)
def decode_request(request, extra_decoders=None):
    """Deserialize the body of a Flask `request` according to its mimetype.

    Returns an empty dict for an empty body; raises ValueError for an
    unsupported content type.
    """
    mimetype = request.mimetype
    raw = request.get_data()
    if not raw:
        return {}
    if mimetype == "application/x-msgpack":
        return msgpack_loads(raw, extra_decoders=extra_decoders)
    if mimetype == "application/json":
        # XXX this .decode() is needed for py35.
        # Should not be needed any more with py37
        return json_loads(raw.decode("utf-8"), extra_decoders=extra_decoders)
    raise ValueError("Wrong content type `%s` for API request" % mimetype)
def error_handler(exception, encoder, status_code=500):
    """Turn an exception raised while serving a request into an error
    response serialized with `encoder`.

    HTTPExceptions keep their own status code; anything else is reported
    with `status_code` (500 by default).
    """
    logging.exception(exception)
    serialized = encoder(exception_to_dict(exception))
    # TODO: differentiate between server errors and client errors
    is_http = isinstance(exception, HTTPException)
    serialized.status_code = exception.code if is_http else status_code
    return serialized
class RPCServerApp(Flask):
    """For each endpoint of the given `backend_class`, tells app.route to call
    a function that decodes the request and sends it to the backend object
    provided by the factory.

    :param Any backend_class:
        The class of the backend, which will be analyzed to look
        for API endpoints.
    :param Optional[Callable[[], backend_class]] backend_factory:
        A function with no argument that returns an instance of
        `backend_class`. If unset, defaults to calling `backend_class`
        constructor directly.

    For each method 'do_x()' of the ``backend_factory``, subclasses may implement
    two methods: ``pre_do_x(self, kw)`` and ``post_do_x(self, ret, kw)`` that will
    be called respectively before and after ``do_x(**kw)``. ``kw`` is the dict
    of request parameters, and ``ret`` is the return value of ``do_x(**kw)``.
    """

    # Use BytesRequest so arbitrary binary payloads survive decoding.
    request_class = BytesRequest

    extra_type_encoders: List[Tuple[type, str, Callable]] = []
    """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps`
    to be able to serialize more object types."""

    extra_type_decoders: Dict[str, Callable] = {}
    """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads`
    to be able to deserialize more object types."""

    method_decorators: List[Callable[[Callable], Callable]] = []
    """List of decorators to all methods generated from the ``backend_class``."""

    def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_backend_class(backend_class, backend_factory)

    def add_backend_class(self, backend_class=None, backend_factory=None):
        # Register a route for every method of backend_class that was marked
        # with @remote_api_endpoint.
        if backend_class is None and backend_factory is not None:
            raise ValueError(
                "backend_factory should only be provided if backend_class is"
            )
        if backend_class is not None:
            backend_factory = backend_factory or backend_class
            for (meth_name, meth) in backend_class.__dict__.items():
                if hasattr(meth, "_endpoint_path"):
                    self.__add_endpoint(meth_name, meth, backend_factory)

    def __add_endpoint(self, meth_name, meth, backend_factory):
        from flask import request

        @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders)
        @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders)
        @functools.wraps(meth)  # Copy signature and doc
        def f():
            # Call the actual code
            pre_hook = getattr(self, f"pre_{meth_name}", None)
            post_hook = getattr(self, f"post_{meth_name}", None)
            # A fresh backend instance is obtained for every request.
            obj_meth = getattr(backend_factory(), meth_name)
            kw = decode_request(request, extra_decoders=self.extra_type_decoders)

            if pre_hook is not None:
                pre_hook(kw)

            ret = obj_meth(**kw)

            if post_hook is not None:
                post_hook(ret, kw)

            return ret

        for decorator in self.method_decorators:
            f = decorator(f)

        # NOTE(review): routes are always POST; the `_method` metadata set by
        # remote_api_endpoint is not consulted here.
        self.route("/" + meth._endpoint_path, methods=["POST"])(f)
diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py
index 865bd1f..747ee67 100644
--- a/swh/core/api/tests/test_rpc_client.py
+++ b/swh/core/api/tests/test_rpc_client.py
@@ -1,177 +1,177 @@
# Copyright (C) 2018-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import re
import pytest
from requests.exceptions import ConnectionError
from swh.core.api import APIError, RemoteException, RPCClient, remote_api_endpoint
from swh.core.api.serializers import exception_to_dict, msgpack_dumps
from .test_serializers import ExtraType, extra_decoders, extra_encoders
class ReraiseException(Exception):
    """Exception type whitelisted in the test client's `reraise_exceptions`."""

    pass
@pytest.fixture
def rpc_client(requests_mock):
    """Build a Testclient whose generated endpoint methods POST to a mocked
    server returning canned msgpack payloads."""

    class TestStorage:
        @remote_api_endpoint("test_endpoint_url")
        def test_endpoint(self, test_data, db=None, cur=None):
            ...

        @remote_api_endpoint("path/to/endpoint")
        def something(self, data, db=None, cur=None):
            ...

        @remote_api_endpoint("serializer_test")
        def serializer_test(self, data, db=None, cur=None):
            ...

        @remote_api_endpoint("overridden/endpoint")
        def overridden_method(self, data):
            return "foo"

    class Testclient(RPCClient):
        backend_class = TestStorage
        extra_type_encoders = extra_encoders
        extra_type_decoders = extra_decoders
        reraise_exceptions = [ReraiseException]

        # Explicit definition wins over the auto-generated endpoint method.
        def overridden_method(self, data):
            return "bar"

    def callback(request, context):
        assert request.headers["Content-Type"] == "application/x-msgpack"
        context.headers["Content-Type"] = "application/x-msgpack"
        # Responses are raw msgpack: b"\xa3egg" is the string "egg", etc.
        if request.path == "/test_endpoint_url":
            context.content = b"\xa3egg"
        elif request.path == "/path/to/endpoint":
            context.content = b"\xa4spam"
        elif request.path == "/serializer_test":
            context.content = (
                b"\x82\xc4\x07swhtype\xa9extratype"
                b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux"
            )
        else:
            assert False
        return context.content

    requests_mock.post(re.compile("mock://example.com/"), content=callback)

    return Testclient(url="mock://example.com")
def test_client(rpc_client):
    # Endpoint methods are auto-generated from the backend class, and accept
    # both positional and keyword arguments.
    assert hasattr(rpc_client, "test_endpoint")
    assert hasattr(rpc_client, "something")

    res = rpc_client.test_endpoint("spam")
    assert res == "egg"
    res = rpc_client.test_endpoint(test_data="spam")
    assert res == "egg"

    res = rpc_client.something("whatever")
    assert res == "spam"
    res = rpc_client.something(data="whatever")
    assert res == "spam"
def test_client_extra_serializers(rpc_client):
    # Custom types round-trip through the client's extra encoders/decoders.
    res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")])
    assert res == ExtraType({"spam": "egg"}, "qux")
def test_client_overridden_method(rpc_client):
    # A method defined directly on the client class shadows the one that the
    # metaclass would have generated from the backend.
    res = rpc_client.overridden_method("foo")
    assert res == "bar"
def test_client_connexion_error(rpc_client, requests_mock):
    """
    ConnectionError should be wrapped and raised as an APIError.
    """
    error_message = "unreachable host"
    requests_mock.post(
        re.compile("mock://example.com/connection_error"),
        exc=ConnectionError(error_message),
    )
    with pytest.raises(APIError) as exc_info:
        rpc_client._post("connection_error", data={})
    # The original ConnectionError is kept as the APIError's first argument.
    assert type(exc_info.value.args[0]) == ConnectionError
    assert str(exc_info.value.args[0]) == error_message
def _exception_response(exception, status_code, old_exception_schema=False):
    """Build a requests_mock callback returning `exception` serialized as a
    msgpack error response with the given HTTP status code."""

    def callback(request, context):
        assert request.headers["Content-Type"] == "application/x-msgpack"
        context.headers["Content-Type"] = "application/x-msgpack"
        exc_dict = exception_to_dict(exception)
        if old_exception_schema:
            # Old servers wrapped the error dict under an "exception" key.
            exc_dict = {"exception": exc_dict}
        context.content = msgpack_dumps(exc_dict)
        context.status_code = status_code
        return context.content

    return callback
@pytest.mark.parametrize("old_exception_schema", [False, True])
def test_client_reraise_exception(rpc_client, requests_mock, old_exception_schema):
    """
    Exception caught server-side and whitelisted will be raised again client-side.
    """
    error_message = "something went wrong"
    endpoint = "reraise_exception"

    requests_mock.post(
        re.compile(f"mock://example.com/{endpoint}"),
        content=_exception_response(
            exception=ReraiseException(error_message),
            status_code=400,
            old_exception_schema=old_exception_schema,
        ),
    )

    with pytest.raises(ReraiseException) as exc_info:
        rpc_client._post(endpoint, data={})
    assert str(exc_info.value) == error_message
@pytest.mark.parametrize(
    "status_code, old_exception_schema",
    [(400, False), (500, False), (400, True), (500, True),],
)
def test_client_raise_remote_exception(
    rpc_client, requests_mock, status_code, old_exception_schema
):
    """
    Exception caught server-side and not whitelisted will be wrapped and raised
    as a RemoteException client-side.
    """
    error_message = "something went wrong"
    endpoint = "raise_remote_exception"

    requests_mock.post(
        re.compile(f"mock://example.com/{endpoint}"),
        content=_exception_response(
            exception=Exception(error_message),
            status_code=status_code,
            old_exception_schema=old_exception_schema,
        ),
    )

    with pytest.raises(RemoteException) as exc_info:
        rpc_client._post(endpoint, data={})
    # The payload carries the serialized server-side exception.
    assert str(exc_info.value.args[0]["type"]) == "Exception"
    assert str(exc_info.value.args[0]["message"]) == error_message
diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py
index 92491ac..dbacece 100644
--- a/swh/core/db/db_utils.py
+++ b/swh/core/db/db_utils.py
@@ -1,664 +1,672 @@
# Copyright (C) 2015-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import datetime, timezone
import functools
from importlib import import_module
import logging
from os import path
import pathlib
import re
import subprocess
from typing import Collection, Dict, List, Optional, Tuple, Union
import psycopg2
+import psycopg2.errors
import psycopg2.extensions
from psycopg2.extensions import connection as pgconnection
from psycopg2.extensions import encodings as pgencodings
from psycopg2.extensions import make_dsn
from psycopg2.extensions import parse_dsn as _parse_dsn
from swh.core.utils import numfile_sortkey as sortkey
logger = logging.getLogger(__name__)
def now():
    """Return the current time as a timezone-aware datetime in UTC."""
    current = datetime.now(tz=timezone.utc)
    return current
def stored_procedure(stored_proc):
    """decorator to execute remote stored procedure, specified as argument

    Generally, the body of the decorated function should be empty. If it is
    not, the stored procedure will be executed first; the function body then.
    """

    def wrap(meth):
        @functools.wraps(meth)
        def _meth(self, *args, **kwargs):
            # Execute the stored procedure on the cursor (an explicit one may
            # be passed via the `cur` keyword), then run the decorated body.
            cursor = kwargs.get("cur", None)
            self._cursor(cursor).execute("SELECT %s()" % stored_proc)
            meth(self, *args, **kwargs)

        return _meth

    return wrap
def jsonize(value):
    """Convert a value to a psycopg2 JSON object if necessary

    Dicts are wrapped in :class:`psycopg2.extras.Json` so they adapt to the
    SQL ``json`` type; any other value is returned unchanged.
    """
    if isinstance(value, dict):
        # BUG FIX: psycopg2.extras is never imported at module level (only
        # psycopg2, psycopg2.errors and psycopg2.extensions are), so the
        # previous `psycopg2.extras.Json` attribute lookup could fail with
        # AttributeError unless another module had imported it first.
        from psycopg2.extras import Json

        return Json(value)
    return value
def connect_to_conninfo(db_or_conninfo: Union[str, pgconnection]) -> pgconnection:
    """Connect to the database passed in argument

    Args:
      db_or_conninfo: A database connection, or a database connection info string

    Returns:
      a connected database handle

    Raises:
      psycopg2.Error if the database doesn't exist
    """
    # Already-open connections are passed through untouched.
    if isinstance(db_or_conninfo, pgconnection):
        return db_or_conninfo

    if "=" not in db_or_conninfo and "//" not in db_or_conninfo:
        # Neither a key=value DSN nor a URI: treat it as a bare database name.
        db_or_conninfo = f"dbname={db_or_conninfo}"

    db = psycopg2.connect(db_or_conninfo)
    return db
def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]:
    """Retrieve the swh version of the database.

    If the database is not initialized, this logs a warning and returns None.

    Args:
      db_or_conninfo: A database connection, or a database connection info string

    Returns:
      Either the version of the database, or None if it couldn't be detected
    """
    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

    try:
        with db.cursor() as c:
            query = "select version from dbversion order by dbversion desc limit 1"
            try:
                c.execute(query)
                result = c.fetchone()
                if result:
                    return result[0]
            except psycopg2.errors.UndefinedTable:
                # No dbversion table: database exists but is not initialized.
                return None
    except Exception:
        logger.exception("Could not get version from `%s`", db_or_conninfo)

    return None
def swh_db_versions(
    db_or_conninfo: Union[str, pgconnection]
) -> Optional[List[Tuple[int, datetime, str]]]:
    """Retrieve the swh version history of the database.

    If the database is not initialized, this logs a warning and returns None.

    Args:
      db_or_conninfo: A database connection, or a database connection info string

    Returns:
      Either the (version, release, description) rows of the dbversion table,
      newest first, or None if they couldn't be retrieved
    """
    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

    try:
        with db.cursor() as c:
            query = (
                "select version, release, description "
                "from dbversion order by dbversion desc"
            )
            try:
                c.execute(query)
                return c.fetchall()
            except psycopg2.errors.UndefinedTable:
                # No dbversion table: database exists but is not initialized.
                return None
    except Exception:
        logger.exception("Could not get versions from `%s`", db_or_conninfo)

    return None
def swh_db_upgrade(
    conninfo: str, modname: str, to_version: Optional[int] = None
) -> int:
    """Upgrade the database at `conninfo` for module `modname`

    This will run migration scripts found in the `sql/upgrades` subdirectory of
    the module `modname`. By default, this will upgrade to the latest declared version.

    Args:
      conninfo: A database connection, or a database connection info string
      modname: datastore module the database stores content for
      to_version: if given, update the database to this version rather than the latest

    Returns:
      the version the database ends up at (the current version when no
      migration script needed to be applied)

    Raises:
      ValueError: the current version or module of the database could not be
        retrieved, or the stored module does not match `modname`
    """
    if to_version is None:
        to_version = 99999999

    db_module, db_version, db_flavor = get_database_info(conninfo)
    if db_version is None:
        raise ValueError("Unable to retrieve the current version of the database")
    if db_module is None:
        raise ValueError("Unable to retrieve the module of the database")
    if db_module != modname:
        raise ValueError(
            "The stored module of the database is different than the given one"
        )

    # Migration scripts are named after the version they upgrade to; keep only
    # those strictly above the current version and up to the target one.
    sqlfiles = [
        fname
        for fname in get_sql_for_package(modname, upgrade=True)
        if db_version < int(fname.stem) <= to_version
    ]
    if not sqlfiles:
        # BUG FIX: nothing to apply — previously this fell through to
        # `return new_version` with `new_version` unbound (NameError).
        return db_version

    for sqlfile in sqlfiles:
        new_version = int(path.splitext(path.basename(sqlfile))[0])
        # BUG FIX: this message was missing its f-prefix, so the literal
        # "{sqlfile}" placeholder was logged; use lazy %-style formatting.
        logger.info("Executing migration script %s", sqlfile)
        if db_version is not None and (new_version - db_version) > 1:
            logger.error(
                f"There are missing migration steps between {db_version} and "
                f"{new_version}. It might be expected but most likely it is not. "
                "Will stop here."
            )
            return db_version

        execute_sqlfiles([sqlfile], conninfo, db_flavor)

        # check if the db version has been updated by the upgrade script
        db_version = swh_db_version(conninfo)
        assert db_version is not None
        if db_version == new_version:
            # nothing to do, upgrade script did the job
            pass
        elif db_version == new_version - 1:
            # it has not (new style), so do it
            swh_set_db_version(
                conninfo,
                new_version,
                desc=f"Upgraded to version {new_version} using {sqlfile}",
            )
            db_version = swh_db_version(conninfo)
        else:
            # upgrade script did it wrong
            logger.error(
                f"The upgrade script {sqlfile} did not update the dbversion table "
                f"consistently ({db_version} vs. expected {new_version}). "
                "Will stop migration here. Please check your migration scripts."
            )
            return db_version
    return new_version
def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
    """Retrieve the swh module used to create the database.

    If the database is not initialized, this logs a warning and returns None.

    Args:
      db_or_conninfo: A database connection, or a database connection info string

    Returns:
      Either the module of the database, or None if it couldn't be detected
    """
    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

    try:
        with db.cursor() as c:
            query = "select dbmodule from dbmodule limit 1"
            try:
                c.execute(query)
                resp = c.fetchone()
                if resp:
                    return resp[0]
            except psycopg2.errors.UndefinedTable:
                # No dbmodule table: module was never registered.
                return None
    except Exception:
        logger.exception("Could not get module from `%s`", db_or_conninfo)

    return None
def swh_set_db_module(
    db_or_conninfo: Union[str, pgconnection], module: str, force=False
) -> None:
    """Set the swh module used to create the database.

    Fails if the dbmodule is already set to a different value (unless ``force``
    is True) or the dbmodule table does not exist.

    Args:
        db_or_conninfo: A database connection, or a database connection info string
        module: the swh module to register (without the leading 'swh.')
        force: if True, overwrite an already-set, different dbmodule value
    """
    update = False
    # normalize: only store the short module name, without the 'swh.' prefix
    if module.startswith("swh."):
        module = module[4:]
    current_module = swh_db_module(db_or_conninfo)
    if current_module is not None:
        if current_module == module:
            logger.warning("The database module is already set to %s", module)
            return
        if not force:
            # BUG FIX: this previously passed logging-style %s args to
            # ValueError, which does not interpolate them, so the message was
            # never formatted; build the message explicitly instead.
            raise ValueError(
                f"The database module is already set to a value {current_module} "
                f"different than given {module}"
            )
        # force is True: update the existing row instead of inserting
        update = True

    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

    # make sure the dbmodule table exists before writing to it
    sqlfiles = [
        fname
        for fname in get_sql_for_package("swh.core.db")
        if "dbmodule" in fname.stem
    ]
    execute_sqlfiles(sqlfiles, db_or_conninfo)

    with db.cursor() as c:
        if update:
            query = "update dbmodule set dbmodule = %s"
        else:
            query = "insert into dbmodule(dbmodule) values (%s)"
        c.execute(query, (module,))
    db.commit()
def swh_set_db_version(
    db_or_conninfo: Union[str, pgconnection],
    version: int,
    ts: Optional[datetime] = None,
    desc: str = "Work in progress",
) -> None:
    """Record *version* as the current version of the database.

    Fails if the dbversion table does not exist.

    Args:
        db_or_conninfo: A database connection, or a database connection info string
        version: the version to add
        ts: the timestamp to record for the version (defaults to now)
        desc: a free-form description of the version
    """
    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

    timestamp = now() if ts is None else ts
    with db.cursor() as cur:
        cur.execute(
            "insert into dbversion(version, release, description) values (%s, %s, %s)",
            (version, timestamp, desc),
        )
    db.commit()
def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
    """Retrieve the swh flavor of the database.

    If the database is not initialized, or the database doesn't support
    flavors, this returns None.

    Args:
        db_or_conninfo: A database connection, or a database connection info string

    Returns:
        The flavor of the database, or None if it could not be detected.
    """
    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

    try:
        with db.cursor() as cur:
            try:
                cur.execute("select swh_get_dbflavor()")
            except psycopg2.errors.UndefinedFunction:
                # function not found: no flavor
                return None
            return cur.fetchone()[0]
    except Exception:
        logger.exception("Could not get flavor from `%s`", db_or_conninfo)
        return None
# The following code has been imported from psycopg2, version 2.7.4,
# https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd
# and modified by Software Heritage.
#
# Original file: lib/extras.py
#
# psycopg2 is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
def _paginate(seq, page_size):
"""Consume an iterable and return it in chunks.
Every chunk is at most `page_size`. Never return an empty chunk.
"""
page = []
it = iter(seq)
while 1:
try:
for i in range(page_size):
page.append(next(it))
yield page
page = []
except StopIteration:
if page:
yield page
return
def _split_sql(sql):
"""Split *sql* on a single ``%s`` placeholder.
Split on the %s, perform %% replacement and return pre, post lists of
snippets.
"""
curr = pre = []
post = []
tokens = re.split(br"(%.)", sql)
for token in tokens:
if len(token) != 2 or token[:1] != b"%":
curr.append(token)
continue
if token[1:] == b"s":
if curr is pre:
curr = post
else:
raise ValueError("the query contains more than one '%s' placeholder")
elif token[1:] == b"%":
curr.append(b"%")
else:
raise ValueError(
"unsupported format character: '%s'"
% token[1:].decode("ascii", "replace")
)
if curr is pre:
raise ValueError("the query doesn't contain any '%s' placeholder")
return pre, post
def execute_values_generator(cur, sql, argslist, template=None, page_size=100):
    """Execute a statement using SQL ``VALUES`` with a sequence of parameters.

    Rows returned by the query are returned through a generator.
    You need to consume the generator for the queries to be executed!

    :param cur: the cursor to use to execute the query.
    :param sql: the query to execute. It must contain a single ``%s``
        placeholder, which will be replaced by a `VALUES list`__.
        Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.
    :param argslist: sequence of sequences or dictionaries with the arguments
        to send to the query. The type and content must be consistent with
        *template*.
    :param template: the snippet to merge to every item in *argslist* to
        compose the query.

        - If the *argslist* items are sequences it should contain positional
          placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there
          are constants value...).
        - If the *argslist* items are mappings it should contain named
          placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``).

        If not specified, assume the arguments are sequence and use a simple
        positional template (i.e. ``(%s, %s, ...)``), with the number of
        placeholders sniffed by the first element in *argslist*.
    :param page_size: maximum number of *argslist* items to include in every
        statement. If there are more items the function will execute more than
        one statement.

    .. __: https://www.postgresql.org/docs/current/static/queries-values.html

    After the execution of the function the `cursor.rowcount` property will
    **not** contain a total result.
    """
    # we can't just use sql % vals because vals is bytes: if sql is bytes
    # there will be some decoding error because of stupid codec used, and Py3
    # doesn't implement % on bytes.
    if not isinstance(sql, bytes):
        # encode with the connection's client encoding so the query text can
        # be concatenated with cur.mogrify's bytes output below
        sql = sql.encode(pgencodings[cur.connection.encoding])
    pre, post = _split_sql(sql)

    for page in _paginate(argslist, page_size=page_size):
        if template is None:
            # sniff the placeholder count from the first row of the first
            # page; the same template is then reused for all pages
            template = b"(" + b",".join([b"%s"] * len(page[0])) + b")"
        parts = pre[:]
        for args in page:
            parts.append(cur.mogrify(template, args))
            parts.append(b",")
        # replace the trailing comma with everything after the %s placeholder
        parts[-1:] = post
        cur.execute(b"".join(parts))
        yield from cur
def import_swhmodule(modname):
    """Import and return the swh module *modname* (the 'swh.' prefix is added
    if missing); logs an error and returns None when the import fails."""
    qualified = modname if modname.startswith("swh.") else f"swh.{modname}"
    try:
        return import_module(qualified)
    except ImportError as exc:
        logger.error(f"Could not load the {qualified} module: {exc}")
        return None
def get_sql_for_package(modname: str, upgrade: bool = False) -> List[pathlib.Path]:
    """Return the (sorted) list of sql script files for the given swh module

    If upgrade is True, return the list of available migration scripts,
    otherwise, return the list of initialization scripts.
    """
    module = import_swhmodule(modname)
    if module is None:
        raise ValueError(f"Module {modname} cannot be loaded")
    # SQL scripts live next to the module's code, under sql/ (or sql/upgrades/
    # for migration scripts)
    sqldir = pathlib.Path(module.__file__).parent / "sql"
    if upgrade:
        sqldir = sqldir / "upgrades"
    if not sqldir.is_dir():
        raise ValueError(
            "Module {} does not provide a db schema (no sql/ dir)".format(modname)
        )
    return sorted(sqldir.glob("*.sql"), key=lambda p: sortkey(p.name))
def populate_database_for_package(
    modname: str, conninfo: str, flavor: Optional[str] = None
) -> Tuple[bool, Optional[int], Optional[str]]:
    """Populate the database, pointed at with ``conninfo``,
    using the SQL files found in the package ``modname``.
    Also fill the 'dbmodule' table with the given ``modname``.

    Args:
        modname: Name of the module of which we're loading the files
        conninfo: connection info string for the SQL database
        flavor: the module-specific flavor which we want to initialize the database under

    Returns:
        Tuple with three elements: whether the database has been initialized; the current
        version of the database; if it exists, the flavor of the database.
    """
    current_version = swh_db_version(conninfo)
    if current_version is not None:
        # database already initialized: just report its current state
        dbflavor = swh_db_flavor(conninfo)
        return False, current_version, dbflavor

    # NOTE: removed a dead inner function (globalsortkey) that was defined but
    # never used; sorting below keys on the files' stems directly.
    # Interleave the module's own init scripts with the common swh.core ones,
    # in a single sorted sequence.
    sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db")
    sqlfiles = sorted(sqlfiles, key=lambda x: sortkey(x.stem))
    # superuser scripts are run separately (see init_admin_extensions)
    sqlfiles = [fpath for fpath in sqlfiles if "-superuser-" not in fpath.stem]
    execute_sqlfiles(sqlfiles, conninfo, flavor)

    # populate the dbmodule table
    swh_set_db_module(conninfo, modname)

    current_db_version = swh_db_version(conninfo)
    dbflavor = swh_db_flavor(conninfo)
    return True, current_db_version, dbflavor
def get_database_info(
    conninfo: str,
) -> Tuple[Optional[str], Optional[int], Optional[str]]:
    """Get version, flavor and module of the db"""
    module = swh_db_module(conninfo)
    version = swh_db_version(conninfo)
    # flavor only makes sense on an initialized database
    flavor = swh_db_flavor(conninfo) if version is not None else None
    return (module, version, flavor)
def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]:
    """Parse a psycopg2 dsn, falling back to supporting plain database names as well"""
    try:
        parsed = _parse_dsn(dsn_or_dbname)
    except psycopg2.ProgrammingError:
        # psycopg2 failed to parse the DSN; it's probably a database name,
        # handle it as such
        parsed = _parse_dsn(f"dbname={dsn_or_dbname}")
    return parsed
def init_admin_extensions(modname: str, conninfo: str) -> None:
    """The remaining initialization process -- running -superuser- SQL files -- is done
    using the given conninfo, thus connecting to the newly created database
    """
    superuser_scripts = [
        fname for fname in get_sql_for_package(modname) if "-superuser-" in fname.stem
    ]
    execute_sqlfiles(superuser_scripts, conninfo)
def create_database_for_package(
    modname: str, conninfo: str, template: str = "template1"
):
    """Create the database pointed at with ``conninfo``, and initialize it using
    -superuser- SQL files found in the package ``modname``.

    Args:
        modname: Name of the module of which we're loading the files
        conninfo: connection info string or plain database name for the SQL database
        template: the name of the database to connect to and use as template to create
            the new database
    """
    # Use the given conninfo string, but with dbname replaced by the template dbname
    # for the database creation step
    creation_dsn = parse_dsn_or_dbname(conninfo)
    dbname = creation_dsn["dbname"]
    creation_dsn["dbname"] = template
    logger.debug("db_create dbname=%s (from %s)", dbname, template)
    command = [
        "psql",
        "--quiet",
        "--no-psqlrc",
        "-v",
        "ON_ERROR_STOP=1",
        "-d",
        make_dsn(**creation_dsn),
        "-c",
        f'CREATE DATABASE "{dbname}"',
    ]
    subprocess.check_call(command)
    init_admin_extensions(modname, conninfo)
def execute_sqlfiles(
    sqlfiles: Collection[pathlib.Path],
    db_or_conninfo: Union[str, pgconnection],
    flavor: Optional[str] = None,
):
    """Execute a list of SQL files on the database pointed at with ``db_or_conninfo``.

    Args:
        sqlfiles: List of SQL files to execute
        db_or_conninfo: A database connection, or a database connection info string
        flavor: the database flavor to initialize
    """
    # psql needs a connection string, so derive one from a live connection
    if isinstance(db_or_conninfo, str):
        conninfo = db_or_conninfo
    else:
        conninfo = db_or_conninfo.dsn

    psql_command = [
        "psql",
        "--quiet",
        "--no-psqlrc",
        "-v",
        "ON_ERROR_STOP=1",
        "-d",
        conninfo,
    ]

    flavor_set = False
    for sqlfile in sqlfiles:
        logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}")
        subprocess.check_call(psql_command + ["-f", str(sqlfile)])
        # set the flavor right after the script that creates the dbflavor table
        if (
            flavor is not None
            and not flavor_set
            and sqlfile.name.endswith("-flavor.sql")
        ):
            logger.debug("Setting database flavor %s", flavor)
            # NOTE(review): flavor is interpolated directly into the SQL;
            # assumed to come from trusted module config — confirm it is never
            # user-supplied.
            query = f"insert into dbflavor (flavor) values ('{flavor}')"
            subprocess.check_call(psql_command + ["-c", query])
            flavor_set = True

    if flavor is not None and not flavor_set:
        # BUG FIX: logger.warn is a deprecated alias of logger.warning
        logger.warning(
            "Asked for flavor %s, but module does not support database flavors",
            flavor,
        )

File Metadata

Mime Type
text/x-diff
Expires
Fri, Jul 4, 2:29 PM (2 d, 6 h ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3312380

Event Timeline