diff --git a/PKG-INFO b/PKG-INFO index 4e531ef..179fb42 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,91 +1,93 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.0.95 +Version: 0.1.0 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core +Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. 
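The same suite can also be exercised across every configured environment with [tox][3], which manages its own virtualenvs, for example:

```
(swh) user@host:~/swh-environment/swh-core$ tox
```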
[1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable +Requires-Python: >=3.7 Description-Content-Type: text/markdown Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: testing-db Provides-Extra: http Provides-Extra: testing diff --git a/requirements-test-db.txt b/requirements-test-db.txt index cfd42eb..32990ae 100644 --- a/requirements-test-db.txt +++ b/requirements-test-db.txt @@ -1 +1,2 @@ pytest-postgresql +typing-extensions diff --git a/setup.py b/setup.py index 65c2803..e5b2567 100755 --- a/setup.py +++ b/setup.py @@ -1,86 +1,88 @@ #!/usr/bin/env python3 # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os from setuptools import setup, find_packages from os import path from io import open here = path.abspath(path.dirname(__file__)) # Get the long description from the README file with open(path.join(here, "README.md"), encoding="utf-8") as f: long_description = f.read() def parse_requirements(*names): requirements = [] for name in names: if name: reqf = "requirements-%s.txt" % name else: reqf = "requirements.txt" if not os.path.exists(reqf): return requirements with open(reqf) as f: for line in f.readlines(): line = line.strip() if not line or line.startswith("#"): continue requirements.append(line) return requirements setup( name="swh.core", description="Software Heritage core utilities", long_description=long_description, long_description_content_type="text/markdown", + python_requires=">=3.7", author="Software Heritage developers", author_email="swh-devel@inria.fr", url="https://forge.softwareheritage.org/diffusion/DCORE/", packages=find_packages(), py_modules=["pytest_swh_core"], scripts=[], install_requires=parse_requirements(None, "swh"), setup_requires=["vcversioner"], extras_require={ "testing-core": parse_requirements("test"), "logging": parse_requirements("logging"), "db": parse_requirements("db"), "testing-db": parse_requirements("test-db"), "http": parse_requirements("http"), # kitchen sink, please do not use "testing": parse_requirements("test", "test-db", "db", "http", "logging"), }, vcversioner={}, include_package_data=True, entry_points=""" [console_scripts] swh=swh.core.cli:main swh-db-init=swh.core.cli.db:db_init [swh.cli.subcommands] db=swh.core.cli.db:db db-init=swh.core.cli.db:db_init [pytest11] pytest_swh_core = swh.core.pytest_plugin """, classifiers=[ "Programming Language :: Python :: 3", "Intended Audience :: Developers", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", "Development Status :: 5 - Production/Stable", ], project_urls={ "Bug Reports": "https://forge.softwareheritage.org/maniphest", "Funding": "https://www.softwareheritage.org/donate", "Source": "https://forge.softwareheritage.org/source/swh-core", + "Documentation": "https://docs.softwareheritage.org/devel/swh-core/", }, ) diff --git a/swh.core.egg-info/PKG-INFO 
b/swh.core.egg-info/PKG-INFO index 4e531ef..179fb42 100644 --- a/swh.core.egg-info/PKG-INFO +++ b/swh.core.egg-info/PKG-INFO @@ -1,91 +1,93 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.0.95 +Version: 0.1.0 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core +Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. 
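The same suite can also be exercised across every configured environment with [tox][3], which manages its own virtualenvs, for example:

```
(swh) user@host:~/swh-environment/swh-core$ tox
```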
[1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable +Requires-Python: >=3.7 Description-Content-Type: text/markdown Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: testing-db Provides-Extra: http Provides-Extra: testing diff --git a/swh.core.egg-info/requires.txt b/swh.core.egg-info/requires.txt index e25eb0c..11098a2 100644 --- a/swh.core.egg-info/requires.txt +++ b/swh.core.egg-info/requires.txt @@ -1,52 +1,54 @@ Click Deprecated PyYAML sentry-sdk [db] psycopg2 [http] aiohttp aiohttp_utils>=3.1.1 arrow decorator Flask iso8601 msgpack>0.5 requests blinker [logging] systemd-python [testing] pytest pytest-mock requests-mock hypothesis>=3.11.0 pre-commit pytz pytest-postgresql +typing-extensions psycopg2 aiohttp aiohttp_utils>=3.1.1 arrow decorator Flask iso8601 msgpack>0.5 requests blinker systemd-python [testing-core] pytest pytest-mock requests-mock hypothesis>=3.11.0 pre-commit pytz [testing-db] pytest-postgresql +typing-extensions diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index b846cf6..b4e0a3f 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,270 +1,277 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime from enum import Enum import json import traceback import types from uuid import UUID import arrow import iso8601 import msgpack from typing import Any, Dict, Union, Tuple from requests import Response +def encode_datetime(dt: datetime.datetime) -> str: + """Wrapper of datetime.datetime.isoformat() that forbids naive datetimes.""" + if dt.tzinfo is None: + raise ValueError(f"{dt} is a naive datetime.") + return dt.isoformat() + + ENCODERS = [ (arrow.Arrow, "arrow", arrow.Arrow.isoformat), - (datetime.datetime, "datetime", datetime.datetime.isoformat), + (datetime.datetime, "datetime", encode_datetime), ( datetime.timedelta, "timedelta", lambda o: { "days": o.days, "seconds": o.seconds, "microseconds": o.microseconds, }, ), (UUID, "uuid", str), # Only for JSON: (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")), ] DECODERS = { "arrow": arrow.get, "datetime": lambda d: iso8601.parse_date(d, default_timezone=None), "timedelta": lambda d: datetime.timedelta(**d), "uuid": UUID, # Only for JSON: "bytes": base64.b85decode, } class MsgpackExtTypeCodes(Enum): LONG_INT = 1 def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: raise ValueError("Limits were reached. 
Please check your input.\n" + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers["content-type"] if content_type.startswith("application/x-msgpack"): r = msgpack_loads(response.content, extra_decoders=extra_decoders) elif content_type.startswith("application/json"): r = json_loads(response.text, extra_decoders=extra_decoders) elif content_type.startswith("text/"): r = response.text else: raise ValueError("Wrong content type `%s` for API response" % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness of the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ def __init__(self, extra_encoders=None, **kwargs): super().__init__(**kwargs) self.encoders = ENCODERS if extra_encoders: self.encoders += extra_encoders def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): return { "swhtype": type_name, "d": encoder(o), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact of our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is.
""" def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) self.decoders = DECODERS if extra_decoders: self.decoders = {**self.decoders, **extra_decoders} def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {"d", "swhtype"}: if o["swhtype"] == "bytes": return base64.b85decode(o["d"]) decoder = self.decoders.get(o["swhtype"]) if decoder: return decoder(self.decode_data(o["d"])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index def json_dumps(data: Any, extra_encoders=None) -> str: return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = ENCODERS if extra_encoders: encoders += extra_encoders def encode_types(obj): if isinstance(obj, int): # integer overflowed while packing. Handle it as an extended type length, rem = divmod(obj.bit_length(), 8) if rem: length += 1 return msgpack.ExtType( MsgpackExtTypeCodes.LONG_INT.value, int.to_bytes(obj, length, "big") ) if isinstance(obj, types.GeneratorType): return list(obj) for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { b"swhtype": type_name, b"d": encoder(obj), } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream. .. Caution:: This function is used by swh.journal to decode the contents of the journal. This function **must** be kept backwards-compatible. 
""" decoders = DECODERS if extra_decoders: decoders = {**decoders, **extra_decoders} def ext_hook(code, data): if code == MsgpackExtTypeCodes.LONG_INT.value: return int.from_bytes(data, "big") raise ValueError("Unknown msgpack extended code %s" % code) def decode_types(obj): # Support for current encodings if set(obj.keys()) == {b"d", b"swhtype"}: decoder = decoders.get(obj[b"swhtype"]) if decoder: return decoder(obj[b"d"]) # Support for legacy encodings if b"__datetime__" in obj and obj[b"__datetime__"]: return iso8601.parse_date(obj[b"s"], default_timezone=None) if b"__uuid__" in obj and obj[b"__uuid__"]: return UUID(obj[b"s"]) if b"__timedelta__" in obj and obj[b"__timedelta__"]: return datetime.timedelta(**obj[b"s"]) if b"__arrow__" in obj and obj[b"__arrow__"]: return arrow.get(obj[b"s"]) # Fallthrough return obj try: try: return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook, strict_map_key=False, ) except TypeError: # msgpack < 0.6.0 return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook ) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb( data, encoding="utf-8", object_hook=decode_types, ext_hook=ext_hook ) def exception_to_dict(exception): tb = traceback.format_exception(None, exception, exception.__traceback__) return { "exception": { "type": type(exception).__name__, "args": exception.args, "message": str(exception), "traceback": tb, } } diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py index 7d24408..7b8400a 100644 --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -1,232 +1,241 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import msgpack import json import pytest from swh.core.api.asynchronous import RPCServerApp, Response from swh.core.api.asynchronous import encode_msgpack, decode_request from swh.core.api.serializers import msgpack_dumps, SWHJSONEncoder pytest_plugins = ["aiohttp.pytest_plugin", "pytester"] class TestServerException(Exception): pass class TestClientError(Exception): pass async def root(request): return Response("toor") STRUCT = { "txt": "something stupid", # 'date': datetime.date(2019, 6, 9), # not supported - "datetime": datetime.datetime(2019, 6, 9, 10, 12), + "datetime": datetime.datetime(2019, 6, 9, 10, 12, tzinfo=datetime.timezone.utc), "timedelta": datetime.timedelta(days=-2, hours=3), "int": 42, "float": 3.14, - "subdata": {"int": 42, "datetime": datetime.datetime(2019, 6, 10, 11, 12),}, - "list": [42, datetime.datetime(2019, 9, 10, 11, 12), "ok"], + "subdata": { + "int": 42, + "datetime": datetime.datetime( + 2019, 6, 10, 11, 12, tzinfo=datetime.timezone.utc + ), + }, + "list": [ + 42, + datetime.datetime(2019, 9, 10, 11, 12, tzinfo=datetime.timezone.utc), + "ok", + ], } async def struct(request): return Response(STRUCT) async def echo(request): data = await decode_request(request) return Response(data) async def server_exception(request): raise TestServerException() async def client_error(request): raise TestClientError() async def echo_no_nego(request): # let the content negotiation handle the serialization for us... 
data = await decode_request(request) ret = encode_msgpack(data) return ret def check_mimetype(src, dst): src = src.split(";")[0].strip() dst = dst.split(";")[0].strip() assert src == dst @pytest.fixture def async_app(): app = RPCServerApp() app.client_exception_classes = (TestClientError,) app.router.add_route("GET", "/", root) app.router.add_route("GET", "/struct", struct) app.router.add_route("POST", "/echo", echo) app.router.add_route("GET", "/server_exception", server_exception) app.router.add_route("GET", "/client_error", client_error) app.router.add_route("POST", "/echo-no-nego", echo_no_nego) return app async def test_get_simple(async_app, aiohttp_client) -> None: assert async_app is not None cli = await aiohttp_client(async_app) resp = await cli.get("/") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") data = await resp.read() value = msgpack.unpackb(data, raw=False) assert value == "toor" async def test_get_server_exception(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) resp = await cli.get("/server_exception") assert resp.status == 500 data = await resp.read() data = msgpack.unpackb(data, raw=False) assert data["exception"]["type"] == "TestServerException" async def test_get_client_error(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) resp = await cli.get("/client_error") assert resp.status == 400 data = await resp.read() data = msgpack.unpackb(data, raw=False) assert data["exception"]["type"] == "TestClientError" async def test_get_simple_nego(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.get("/", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == "toor" async def test_get_struct(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) resp = await cli.get("/struct") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.get("/struct", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_msgpack(async_app, aiohttp_client) -> None: """Test that msgpack encoded posted struct data is returned as is""" cli = await aiohttp_client(async_app) # simple struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps({"toto": 42}), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} # complex struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps(STRUCT), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_post_struct_json(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is""" cli = await 
aiohttp_client(async_app) resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps({"toto": 42}, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT async def test_post_struct_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is using content negotiation (accept json or msgpack). """ cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_no_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as msgpack when using non-negotiation-compatible handlers. """ cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo-no-nego", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 5211310..b710657 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,186 +1,203 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json from typing import Any, Callable, List, Tuple import unittest from uuid import UUID import arrow import requests import requests_mock from swh.core.api.serializers import ( SWHJSONDecoder, SWHJSONEncoder, msgpack_dumps, msgpack_loads, decode_response, ) class ExtraType: def __init__(self, arg1, arg2): self.arg1 = arg1 self.arg2 = arg2 def __repr__(self): return f"ExtraType({self.arg1}, {self.arg2})" def __eq__(self, other): return isinstance(other, ExtraType) and (self.arg1, self.arg2) == ( other.arg1, other.arg2, ) extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ (ExtraType, "extratype", lambda o: (o.arg1, o.arg2)) ] extra_decoders = { "extratype": lambda o: ExtraType(*o), } class Serializers(unittest.TestCase): def setUp(self): self.tz = datetime.timezone(datetime.timedelta(minutes=118)) self.data = { "bytes": b"123456789\x99\xaf\xff\x00\x12", - "datetime_naive": datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), "datetime_tz": datetime.datetime( 2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz ), "datetime_utc": datetime.datetime( 2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc ), "datetime_delta": datetime.timedelta(64), "arrow_date": arrow.get("2018-04-25T16:17:53.533672+00:00"), "swhtype": "fake", "swh_dict": 
{"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": UUID("cdd8f804-9db6-40c3-93ab-5955d3836234"), } self.encoded_data = { "bytes": {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"}, - "datetime_naive": { - "swhtype": "datetime", - "d": "2015-01-01T12:04:42.231455", - }, "datetime_tz": { "swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+01:58", }, "datetime_utc": { "swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+00:00", }, "datetime_delta": { "swhtype": "timedelta", "d": {"days": 64, "seconds": 0, "microseconds": 0}, }, "arrow_date": {"swhtype": "arrow", "d": "2018-04-25T16:17:53.533672+00:00"}, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": {"swhtype": "uuid", "d": "cdd8f804-9db6-40c3-93ab-5955d3836234"}, } self.legacy_msgpack = { "bytes": b"\xc4\x0e123456789\x99\xaf\xff\x00\x12", - "datetime_naive": ( - b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba" - b"2015-01-01T12:04:42.231455" - ), "datetime_tz": ( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " b"2015-03-04T18:25:13.001234+01:58" ), "datetime_utc": ( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " b"2015-03-04T18:25:13.001234+00:00" ), "datetime_delta": ( b"\x82\xc4\r__timedelta__\xc3\xc4\x01s\x83\xa4" b"days@\xa7seconds\x00\xacmicroseconds\x00" ), "arrow_date": ( b"\x82\xc4\t__arrow__\xc3\xc4\x01s\xd9 " b"2018-04-25T16:17:53.533672+00:00" ), "swhtype": b"\xa4fake", "swh_dict": b"\x82\xa7swhtype*\xa1d\xa4test", "random_dict": b"\x81\xa7swhtype+", "uuid": ( b"\x82\xc4\x08__uuid__\xc3\xc4\x01s\xd9$" b"cdd8f804-9db6-40c3-93ab-5955d3836234" ), } self.generator = (i for i in range(5)) self.gen_lst = list(range(5)) def test_round_trip_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) def test_round_trip_json_extra_types(self): original_data = [ExtraType("baz", self.data), "qux"] data = json.dumps( original_data, cls=SWHJSONEncoder, extra_encoders=extra_encoders ) self.assertEqual( original_data, json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders), ) def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) def test_round_trip_msgpack(self): original_data = { **self.data, "none_dict_key": {None: 42}, "long_int_is_loooong": 10000000000000000000000000000000, } data = msgpack_dumps(original_data) self.assertEqual(original_data, msgpack_loads(data)) def test_round_trip_msgpack_extra_types(self): original_data = [ExtraType("baz", self.data), "qux"] data = msgpack_dumps(original_data, extra_encoders=extra_encoders) self.assertEqual( original_data, msgpack_loads(data, extra_decoders=extra_decoders) ) def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder)) def test_generator_msgpack(self): data = msgpack_dumps(self.generator) self.assertEqual(self.gen_lst, msgpack_loads(data)) @requests_mock.Mocker() def test_decode_response_json(self, mock_requests): mock_requests.get( "https://example.org/test/data", json=self.encoded_data, headers={"content-type": "application/json"}, ) response = requests.get("https://example.org/test/data") assert decode_response(response) == self.data def test_decode_legacy_msgpack(self): for k, v in self.legacy_msgpack.items(): assert msgpack_loads(v) == self.data[k] + + def test_encode_native_datetime(self): + dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) + with 
self.assertRaisesRegex(ValueError, "naive datetime"): + msgpack_dumps(dt) + + def test_decode_naive_datetime(self): + expected_dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) + + # Current encoding + assert ( + msgpack_loads( + b"\x82\xc4\x07swhtype\xa8datetime\xc4\x01d\xba" + b"2015-01-01T12:04:42.231455" + ) + == expected_dt + ) + + # Legacy encoding + assert ( + msgpack_loads( + b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba" + b"2015-01-01T12:04:42.231455" + ) + == expected_dt + ) diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py index d018792..78d3c81 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -1,188 +1,190 @@ #!/usr/bin/env python3 -# Copyright (C) 2018 The Software Heritage developers +# Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import logging from os import path, environ import subprocess import warnings -warnings.filterwarnings("ignore") # noqa prevent psycopg from telling us sh*t - import click from swh.core.cli import CONTEXT_SETTINGS from swh.core.config import read as config_read + +warnings.filterwarnings("ignore") # noqa prevent psycopg from telling us sh*t + + logger = logging.getLogger(__name__) @click.group(name="db", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False), help="Configuration file.", ) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools. """ ctx.ensure_object(dict) if config_file is None: config_file = environ.get("SWH_CONFIG_FILENAME") cfg = config_read(config_file) ctx.obj["config"] = cfg @db.command(name="init", context_settings=CONTEXT_SETTINGS) @click.pass_context def init(ctx): """Initialize the database for every Software Heritage module found in the configuration file. For every configuration section in the config file that: 1. has the name of an existing swh package, 2. has credentials for a local db access, it will run the initialization scripts from the swh package against the given database. Example for the config file:: \b storage: cls: local args: db: postgresql:///?service=swh-storage objstorage: cls: remote args: url: http://swh-objstorage:5003/ the command: swh db -C /path/to/config.yml init will initialize the database for the `storage` section using initialization scripts from the `swh.storage` package. """ for modname, cfg in ctx.obj["config"].items(): if cfg.get("cls") == "local" and cfg.get("args"): try: sqlfiles = get_sql_for_package(modname) except click.BadParameter: logger.info( "Failed to load/find sql initialization files for %s", modname ) if sqlfiles: conninfo = cfg["args"]["db"] for sqlfile in sqlfiles: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", sqlfile, ] ) @click.command(context_settings=CONTEXT_SETTINGS) @click.argument("module", nargs=-1, required=True) @click.option( "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) @click.option( "--create-db/--no-create-db", "-C", help="Attempt to create the database.", default=False, ) def db_init(module, db_name, create_db): """Initialise a database for the given Software Heritage module(s). By default, does not attempt to create the database.
Example: swh db-init -d swh-test storage If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables. See psql(1) man page (section ENVIRONMENTS) for details. Example: PGPORT=5434 swh db-init indexer """ # put import statements here so we can keep startup time of the main swh # command as short as possible from swh.core.db.tests.db_testing import ( pg_createdb, pg_restore, DB_DUMP_TYPES, swh_db_version, ) logger.debug("db_init %s dn_name=%s", module, db_name) dump_files = [] for modname in module: dump_files.extend(get_sql_for_package(modname)) if create_db: # Create the db (or fail silently if already existing) pg_createdb(db_name, check=False) # Try to retrieve the db version if any db_version = swh_db_version(db_name) if not db_version: # Initialize the db dump_files = [(x, DB_DUMP_TYPES[path.splitext(x)[1]]) for x in dump_files] for dump, dtype in dump_files: click.secho("Loading {}".format(dump), fg="yellow") pg_restore(db_name, dump, dtype) db_version = swh_db_version(db_name) # TODO: Ideally migrate the version from db_version to the latest # db version click.secho( "DONE database is {} version {}".format(db_name, db_version), fg="green", bold=True, ) def get_sql_for_package(modname): from importlib import import_module from swh.core.utils import numfile_sortkey as sortkey if not modname.startswith("swh."): modname = "swh.{}".format(modname) try: m = import_module(modname) except ImportError: raise click.BadParameter("Unable to load module {}".format(modname)) sqldir = path.join(path.dirname(m.__file__), "sql") if not path.isdir(sqldir): raise click.BadParameter( "Module {} does not provide a db schema " "(no sql/ dir)".format(modname) ) return list(sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey)) diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py index 307f038..a98d5e9 100644 --- a/swh/core/db/__init__.py +++ b/swh/core/db/__init__.py @@ -1,216 +1,299 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import binascii import datetime import enum import json import logging import os import sys import threading +from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type, TypeVar from contextlib import contextmanager import psycopg2 import psycopg2.extras +import psycopg2.pool logger = logging.getLogger(__name__) psycopg2.extras.register_uuid() -def escape(data): +def render_array(data) -> str: + """Render the data as a postgresql array""" + # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO + # "The external text representation of an array value consists of items that are + # interpreted according to the I/O conversion rules for the array's element type, + # plus decoration that indicates the array structure. The decoration consists of + # curly braces ({ and }) around the array value plus delimiter characters between + # adjacent items. The delimiter character is usually a comma (,)" + return "{%s}" % ",".join(render_array_element(e) for e in data) + + +def render_array_element(element) -> str: + """Render an element from an array.""" + if element is None: + # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO + # "If the value written for an element is NULL (in any case variant), the + # element is taken to be NULL." 
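+        # Emit the bare, unquoted marker here; actual strings spelling "NULL" are
+        # quoted by render_array_element's else branch, so the two stay distinct.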
+ return "NULL" + elif isinstance(element, (list, tuple)): + # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-INPUT + # "Each val is either a constant of the array element type, or a subarray." + return render_array(element) + else: + # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO + # "When writing an array value you can use double quotes around any individual + # array element. [...] Empty strings and strings matching the word NULL must be + # quoted, too. To put a double quote or backslash in a quoted array element + # value, precede it with a backslash." + ret = value_as_pg_text(element) + return '"%s"' % ret.replace("\\", "\\\\").replace('"', '\\"') + + +def value_as_pg_text(data: Any) -> str: + """Render the given data in the postgresql text format. + + NULL values are handled **outside** of this function (either by + :func:`render_array_element`, or by :meth:`BaseDb.copy_to`.) + """ + if data is None: - return "" + raise ValueError("value_as_pg_text doesn't handle NULLs") + if isinstance(data, bytes): - return "\\x%s" % binascii.hexlify(data).decode("ascii") - elif isinstance(data, str): - return '"%s"' % data.replace('"', '""') + return "\\x%s" % data.hex() elif isinstance(data, datetime.datetime): - # We escape twice to make sure the string generated by - # isoformat gets escaped - return escape(data.isoformat()) + return data.isoformat() elif isinstance(data, dict): - return escape(json.dumps(data)) - elif isinstance(data, list): - return escape("{%s}" % ",".join(escape(d) for d in data)) + return json.dumps(data) + elif isinstance(data, (list, tuple)): + return render_array(data) elif isinstance(data, psycopg2.extras.Range): - # We escape twice here too, so that we make sure - # everything gets passed to copy properly - return escape( - "%s%s,%s%s" - % ( - "[" if data.lower_inc else "(", - "-infinity" if data.lower_inf else escape(data.lower), - "infinity" if data.upper_inf else escape(data.upper), - "]" if data.upper_inc else ")", - ) + return "%s%s,%s%s" % ( + "[" if data.lower_inc else "(", + "-infinity" if data.lower_inf else value_as_pg_text(data.lower), + "infinity" if data.upper_inf else value_as_pg_text(data.upper), + "]" if data.upper_inc else ")", ) elif isinstance(data, enum.IntEnum): - return escape(int(data)) + return str(int(data)) else: - # We don't escape here to make sure we pass literals properly return str(data) +def escape_copy_column(column: str) -> str: + """Escape the text representation of a column for use by COPY.""" + # From https://www.postgresql.org/docs/11/sql-copy.html + # File Formats > Text Format + # "Backslash characters (\) can be used in the COPY data to quote data characters + # that might otherwise be taken as row or column delimiters. In particular, the + # following characters must be preceded by a backslash if they appear as part of a + # column value: backslash itself, newline, carriage return, and the current + # delimiter character." + ret = ( + column.replace("\\", "\\\\") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + ) + + return ret + + def typecast_bytea(value, cur): if value is not None: data = psycopg2.BINARY(value, cur) return data.tobytes() +BaseDbType = TypeVar("BaseDbType", bound="BaseDb") + + class BaseDb: """Base class for swh.*.*Db. cf. 
swh.storage.db.Db, swh.archiver.db.ArchiverDb """ - @classmethod - def adapt_conn(cls, conn): + @staticmethod + def adapt_conn(conn: psycopg2.extensions.connection): """Makes psycopg2 use 'bytes' to decode bytea instead of 'memoryview', for this connection.""" t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea) psycopg2.extensions.register_type(t_bytes, conn) t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes) psycopg2.extensions.register_type(t_bytes_array, conn) @classmethod - def connect(cls, *args, **kwargs): + def connect(cls: Type[BaseDbType], *args, **kwargs) -> BaseDbType: """factory method to create a DB proxy Accepts all arguments of psycopg2.connect; only some specific possibilities are reported below. Args: connstring: libpq2 connection string """ conn = psycopg2.connect(*args, **kwargs) return cls(conn) @classmethod - def from_pool(cls, pool): + def from_pool( + cls: Type[BaseDbType], pool: psycopg2.pool.AbstractConnectionPool + ) -> BaseDbType: conn = pool.getconn() return cls(conn, pool=pool) - def __init__(self, conn, pool=None): + def __init__( + self, + conn: psycopg2.extensions.connection, + pool: Optional[psycopg2.pool.AbstractConnectionPool] = None, + ): """create a DB proxy Args: conn: psycopg2 connection to the SWH DB pool: psycopg2 pool of connections """ self.adapt_conn(conn) self.conn = conn self.pool = pool - def put_conn(self): + def put_conn(self) -> None: if self.pool: self.pool.putconn(self.conn) - def cursor(self, cur_arg=None): + def cursor( + self, cur_arg: Optional[psycopg2.extensions.cursor] = None + ) -> psycopg2.extensions.cursor: """get a cursor: from cur_arg if given, or a fresh one otherwise meant to avoid boilerplate if/then/else in methods that proxy stored procedures """ if cur_arg is not None: return cur_arg else: return self.conn.cursor() _cursor = cursor # for bw compat @contextmanager - def transaction(self): + def transaction(self) -> Iterator[psycopg2.extensions.cursor]: """context manager to execute within a DB transaction Yields: a psycopg2 cursor """ with self.conn.cursor() as cur: try: yield cur self.conn.commit() except Exception: if not self.conn.closed: self.conn.rollback() raise def copy_to( - self, items, tblname, columns, cur=None, item_cb=None, default_values={} - ): - """Copy items' entries to table tblname with columns information. + self, + items: Iterable[Mapping[str, Any]], + tblname: str, + columns: Iterable[str], + cur: Optional[psycopg2.extensions.cursor] = None, + item_cb: Optional[Callable[[Any], Any]] = None, + default_values: Optional[Mapping[str, Any]] = None, + ) -> None: + """Run the COPY command to insert the `columns` of each element of `items` into + `tblname`. Args: - items (List[dict]): dictionaries of data to copy over tblname. - tblname (str): destination table's name. - columns ([str]): keys to access data in items and also the - column names in the destination table. - default_values (dict): dictionary of default values to use when - inserting entried int the tblname table. + items: dictionaries of data to copy into `tblname`. + tblname: name of the destination table. + columns: columns of the destination table. Elements of `items` must have + these set as keys. + default_values: dictionary of default values to use when inserting entries + in `tblname`. cur: a db cursor; if not given, a new cursor will be created. - item_cb (fn): optional function to apply to items's entry. + item_cb: optional callback, run on each element of `items`, when it is + copied. 
+ """ + if default_values is None: + default_values = {} read_file, write_file = os.pipe() exc_info = None def writer(): nonlocal exc_info cursor = self.cursor(cur) with open(read_file, "r") as f: try: cursor.copy_expert( - "COPY %s (%s) FROM STDIN CSV" % (tblname, ", ".join(columns)), f + "COPY %s (%s) FROM STDIN" % (tblname, ", ".join(columns)), f ) except Exception: # Tell the main thread about the exception exc_info = sys.exc_info() write_thread = threading.Thread(target=writer) write_thread.start() try: with open(write_file, "w") as f: + # From https://www.postgresql.org/docs/11/sql-copy.html + # File Formats > Text Format + # "When the text format is used, the data read or written is a text file + # with one line per table row. Columns in a row are separated by the + # delimiter character." + # NULL + # "The default is \N (backslash-N) in text format." + # DELIMITER + # "The default is a tab character in text format." for d in items: if item_cb is not None: item_cb(d) line = [] for k in columns: value = d.get(k, default_values.get(k)) try: - line.append(escape(value)) + if value is None: + line.append("\\N") + else: + line.append(escape_copy_column(value_as_pg_text(value))) except Exception as e: logger.error( "Could not escape value `%r` for column `%s`:" "Received exception: `%s`", value, k, e, ) raise e from None - f.write(",".join(line)) + f.write("\t".join(line)) f.write("\n") finally: # No problem bubbling up exceptions, but we still need to make sure # we finish copying, even though we're probably going to cancel the # transaction. write_thread.join() if exc_info: # postgresql returned an error, let's raise it. raise exc_info[1].with_traceback(exc_info[2]) - def mktemp(self, tblname, cur=None): + def mktemp(self, tblname: str, cur: Optional[psycopg2.extensions.cursor] = None): self.cursor(cur).execute("SELECT swh_mktemp(%s)", (tblname,)) diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py index 2cee3fc..93385d4 100644 --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -1,226 +1,438 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from dataclasses import dataclass +import datetime +from enum import IntEnum import inspect import os.path +from string import printable import tempfile +from typing import Any +from typing_extensions import Protocol import unittest from unittest.mock import Mock, MagicMock +import uuid from hypothesis import strategies, given +from hypothesis.extra.pytz import timezones import psycopg2 import pytest from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator from .db_testing import ( SingleDbTestFixture, db_create, db_destroy, db_close, ) -INIT_SQL = """ -create table test_table -( - i int, - txt text, - bytes bytea -); -""" - -db_rows = strategies.lists( - strategies.tuples( - strategies.integers(-2147483648, +2147483647), - strategies.text( - alphabet=strategies.characters( - blacklist_categories=["Cs"], # surrogates - blacklist_characters=[ - "\x00", # pgsql does not support the null codepoint - "\r", # pgsql normalizes those - ], +# workaround mypy bug https://github.com/python/mypy/issues/5485 +class Converter(Protocol): + def __call__(self, x: Any) -> Any: + ... 
+ + +@dataclass +class Field: + name: str + """Column name""" + pg_type: str + """Type of the PostgreSQL column""" + example: Any + """Example value for the static tests""" + strategy: strategies.SearchStrategy + """Hypothesis strategy to generate these values""" + in_wrapper: Converter = lambda x: x + """Wrapper to convert this data type for the static tests""" + out_converter: Converter = lambda x: x + """Converter from the raw PostgreSQL column value to this data type""" + + +# Limit PostgreSQL integer values +pg_int = strategies.integers(-2147483648, +2147483647) + +pg_text = strategies.text( + alphabet=strategies.characters( + blacklist_categories=["Cs"], # surrogates + blacklist_characters=[ + "\x00", # pgsql does not support the null codepoint + "\r", # pgsql normalizes those + ], + ), +) + +pg_bytea = strategies.binary() + + +def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy: + """Generate a PostgreSQL bytea[]""" + return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size) + + +def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy: + """Generate a PostgreSQL bytea[][]. The inner lists must all have the same size.""" + return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap( + lambda n: strategies.lists( + pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size + ) + ) + + +def pg_tstz() -> strategies.SearchStrategy: + """Generate values that fit in a PostgreSQL timestamptz. + + Notes: + We're forbidding old datetimes, because until 1956, many timezones had + seconds in their "UTC offsets" (see + ), which is + not representable by PostgreSQL. + + """ + min_value = datetime.datetime(1960, 1, 1, 0, 0, 0) + return strategies.datetimes(min_value=min_value, timezones=timezones()) + + +def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy: + """Generate values representable as a PostgreSQL jsonb object (dict).""" + return strategies.dictionaries( + strategies.text(printable), + strategies.recursive( + # should use floats() instead of integers(), but PostgreSQL + # coerces large integers into floats, making the tests fail. We + # only store ints in our generated data anyway. 
+ strategies.none() + | strategies.booleans() + | strategies.integers(-2147483648, +2147483647) + | strategies.text(printable), + lambda children: strategies.lists(children, max_size=max_size) + | strategies.dictionaries( + strategies.text(printable), children, max_size=max_size ), ), - strategies.binary(), + min_size=min_size, + max_size=max_size, ) + + +def tuple_2d_to_list_2d(v): + """Convert a 2D tuple to a 2D list""" + return [list(inner) for inner in v] + + +def list_2d_to_tuple_2d(v): + """Convert a 2D list to a 2D tuple""" + return tuple(tuple(inner) for inner in v) + + +class TestIntEnum(IntEnum): + foo = 1 + bar = 2 + + +def now(): + return datetime.datetime.now(tz=datetime.timezone.utc) + + +FIELDS = ( + Field("i", "int", 1, pg_int), + Field("txt", "text", "foo", pg_text), + Field("bytes", "bytea", b"bar", strategies.binary()), + Field( + "bytes_array", + "bytea[]", + [b"baz1", b"baz2"], + pg_bytea_a(min_size=0, max_size=5), + ), + Field( + "bytes_tuple", + "bytea[]", + (b"baz1", b"baz2"), + pg_bytea_a(min_size=0, max_size=5).map(tuple), + in_wrapper=list, + out_converter=tuple, + ), + Field( + "bytes_2d", + "bytea[][]", + [[b"quux1"], [b"quux2"]], + pg_bytea_a_a(min_size=0, max_size=5), + ), + Field( + "bytes_2d_tuple", + "bytea[][]", + ((b"quux1",), (b"quux2",)), + pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d), + in_wrapper=tuple_2d_to_list_2d, + out_converter=list_2d_to_tuple_2d, + ), + Field("ts", "timestamptz", now(), pg_tstz(),), + Field( + "dict", + "jsonb", + {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}}, + pg_jsonb(min_size=0, max_size=5), + in_wrapper=psycopg2.extras.Json, + ), + Field( + "intenum", + "int", + TestIntEnum.foo, + strategies.sampled_from(TestIntEnum), + in_wrapper=int, + out_converter=TestIntEnum, + ), + Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()), + Field( + "text_list", + "text[]", + # All the funky corner cases + ["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"], + strategies.lists(pg_text, min_size=0, max_size=5), + ), + Field( + "tstz_list", + "timestamptz[]", + [now(), now() + datetime.timedelta(days=1)], + strategies.lists(pg_tstz(), min_size=0, max_size=5), + ), + Field( + "tstz_range", + "tstzrange", + psycopg2.extras.DateTimeTZRange( + lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)", + ), + strategies.tuples( + # generate two sorted timestamptzs for use as bounds + strategies.tuples(pg_tstz(), pg_tstz()).map(sorted), + # and a set of bounds + strategies.sampled_from(["[]", "()", "[)", "(]"]), + ).map( + # and build the actual DateTimeTZRange object from these args + lambda args: psycopg2.extras.DateTimeTZRange( + lower=args[0][0], upper=args[0][1], bounds=args[1], + ) + ), + ), ) +INIT_SQL = "create table test_table (%s)" % ", ".join( + f"{field.name} {field.pg_type}" for field in FIELDS +) + +COLUMNS = tuple(field.name for field in FIELDS) +INSERT_SQL = "insert into test_table (%s) values (%s)" % ( + ", ".join(COLUMNS), + ", ".join("%s" for i in range(len(COLUMNS))), +) + +STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS) +EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS) + +db_rows = strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS))) + + +def convert_lines(cur): + return [ + tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur + ] + @pytest.mark.db def test_connect(): db_name = db_create("test-db2", dumps=[]) try: db = BaseDb.connect("dbname=%s" % db_name) with 
db.cursor() as cur: + psycopg2.extras.register_default_jsonb(cur) cur.execute(INIT_SQL) - cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) + cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") - assert list(cur) == [(1, "foo", b"bar")] + output = convert_lines(cur) + assert len(output) == 1 + assert EXPECTED_ROW_OUT == output[0] finally: db_close(db.conn) db_destroy(db_name) @pytest.mark.db class TestDb(SingleDbTestFixture, unittest.TestCase): TEST_DB_NAME = "test-db" @classmethod def setUpClass(cls): with tempfile.TemporaryDirectory() as td: with open(os.path.join(td, "init.sql"), "a") as fd: fd.write(INIT_SQL) cls.TEST_DB_DUMP = os.path.join(td, "*.sql") super().setUpClass() def setUp(self): super().setUp() self.db = BaseDb(self.conn) def test_initialized(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) + psycopg2.extras.register_default_jsonb(cur) + cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") - self.assertEqual(list(cur), [(1, "foo", b"bar")]) + output = convert_lines(cur) + assert len(output) == 1 + assert EXPECTED_ROW_OUT == output[0] def test_reset_tables(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) + cur.execute(INSERT_SQL, STATIC_ROW_IN) self.reset_db_tables("test-db") cur.execute("select * from test_table;") - self.assertEqual(list(cur), []) + assert convert_lines(cur) == [] + + def test_copy_to_static(self): + items = [{field.name: field.example for field in FIELDS}] + self.db.copy_to(items, "test_table", COLUMNS) + + cur = self.db.cursor() + cur.execute("select * from test_table;") + output = convert_lines(cur) + assert len(output) == 1 + assert EXPECTED_ROW_OUT == output[0] @given(db_rows) def test_copy_to(self, data): - # the table is not reset between runs by hypothesis - self.reset_db_tables("test-db") + try: + # the table is not reset between runs by hypothesis + self.reset_db_tables("test-db") - items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] - self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) + items = [dict(zip(COLUMNS, item)) for item in data] + self.db.copy_to(items, "test_table", COLUMNS) - cur = self.db.cursor() - cur.execute("select * from test_table;") - self.assertCountEqual(list(cur), data) + cur = self.db.cursor() + cur.execute("select * from test_table;") + assert convert_lines(cur) == data + finally: + self.db.conn.rollback() def test_copy_to_thread_exception(self): data = [(2 ** 65, "foo", b"bar")] - items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] + items = [dict(zip(COLUMNS, item)) for item in data] with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): - self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) + self.db.copy_to(items, "test_table", COLUMNS) def test_db_transaction(mocker): expected_cur = object() called = False class Storage: @db_transaction() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) storage.endpoint() assert called 
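    # the wrapped call must hand the connection back to the pool exactly once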
put_db_mock.assert_called_once_with(db_mock) def test_db_transaction__with_generator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction() def endpoint(self, cur=None, db=None): yield None def test_db_transaction_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass expected_sig = inspect.signature(f) @db_transaction() def g(self, foo, *, bar=None, db=None, cur=None): pass actual_sig = inspect.signature(g) assert actual_sig == expected_sig def test_db_transaction_generator(mocker): expected_cur = object() called = False class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur yield None storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) list(storage.endpoint()) assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction_generator__with_nongenerator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): pass def test_db_transaction_generator_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass expected_sig = inspect.signature(f) @db_transaction_generator() def g(self, foo, *, bar=None, db=None, cur=None): yield None actual_sig = inspect.signature(g) assert actual_sig == expected_sig diff --git a/version.txt b/version.txt index 6e9dfe7..86ab1f6 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -v0.0.95-0-gdca9c5f \ No newline at end of file +v0.1.0-0-gce1e452 \ No newline at end of file
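Illustration of the serializer change above — a minimal sketch of the new naive-datetime contract, assuming swh.core 0.1.0 is installed (the datetime values are illustrative):

```python
import datetime

from swh.core.api.serializers import msgpack_dumps, msgpack_loads

# Timezone-aware datetimes round-trip exactly as before.
aware = datetime.datetime(2015, 3, 4, 18, 25, 13, tzinfo=datetime.timezone.utc)
assert msgpack_loads(msgpack_dumps(aware)) == aware

# Naive datetimes are now rejected at encoding time by encode_datetime().
naive = datetime.datetime(2015, 3, 4, 18, 25, 13)
try:
    msgpack_dumps(naive)
except ValueError as err:
    print(err)  # "2015-03-04 18:25:13 is a naive datetime."
```

Decoding, on the other hand, still accepts naive datetimes produced by older clients, as covered by test_decode_naive_datetime above.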