test_db.py

# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from dataclasses import dataclass
import datetime
from enum import IntEnum
import inspect
from string import printable
from typing import Any
from unittest.mock import MagicMock, Mock
import uuid
from hypothesis import given, settings, strategies
from hypothesis.extra.pytz import timezones
import psycopg2
import psycopg2.extras
import pytest
from typing_extensions import Protocol
from swh.core.db import BaseDb
from swh.core.db.common import db_transaction, db_transaction_generator
from swh.core.db.pytest_plugin import postgresql_fact
from swh.core.db.tests.conftest import function_scoped_fixture_check
# workaround mypy bug https://github.com/python/mypy/issues/5485
class Converter(Protocol):
def __call__(self, x: Any) -> Any:
...
@dataclass
class Field:
name: str
"""Column name"""
pg_type: str
"""Type of the PostgreSQL column"""
example: Any
"""Example value for the static tests"""
strategy: strategies.SearchStrategy
"""Hypothesis strategy to generate these values"""
in_wrapper: Converter = lambda x: x
"""Wrapper to convert this data type for the static tests"""
out_converter: Converter = lambda x: x
"""Converter from the raw PostgreSQL column value to this data type"""
# Limit PostgreSQL integer values
pg_int = strategies.integers(-2147483648, +2147483647)
pg_text = strategies.text(
alphabet=strategies.characters(
blacklist_categories=["Cs"], # surrogates
blacklist_characters=[
"\x00", # pgsql does not support the null codepoint
"\r", # pgsql normalizes those
],
),
)
pg_bytea = strategies.binary()
def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy:
"""Generate a PostgreSQL bytea[]"""
return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size)
def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy:
"""Generate a PostgreSQL bytea[][]. The inner lists must all have the same size."""
return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap(
lambda n: strategies.lists(
pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size
)
)
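# Illustrative note (an assumed typical draw, not used by the tests themselves):
# pg_bytea_a_a(min_size=1, max_size=3) could yield e.g. [[b"\x00", b"ab"], [b"cd", b""]],
# where every inner list has the same length, since a PostgreSQL bytea[][] array
# must be rectangular.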
def pg_tstz() -> strategies.SearchStrategy:
"""Generate values that fit in a PostgreSQL timestamptz.
Notes:
We're forbidding old datetimes, because until 1956, many timezones had
seconds in their "UTC offsets" (see
<https://en.wikipedia.org/wiki/Time_zone#Worldwide_time_zones>), which is
not representable by PostgreSQL.
"""
min_value = datetime.datetime(1960, 1, 1, 0, 0, 0)
return strategies.datetimes(min_value=min_value, timezones=timezones())
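# One concrete illustration of the note above (Europe/Amsterdam is merely an
# example): until 1937 that zone used an offset of roughly +00:19:32, i.e. one
# with a seconds component, which is exactly the kind of value the 1960
# min_value keeps out of the generated datetimes.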
def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy:
"""Generate values representable as a PostgreSQL jsonb object (dict)."""
return strategies.dictionaries(
strategies.text(printable),
strategies.recursive(
# should use floats() instead of integers(), but PostgreSQL
# coerces large integers into floats, making the tests fail. We
# only store ints in our generated data anyway.
strategies.none()
| strategies.booleans()
| strategies.integers(-2147483648, +2147483647)
| strategies.text(printable),
lambda children: strategies.lists(children, max_size=max_size)
| strategies.dictionaries(
strategies.text(printable), children, max_size=max_size
),
),
min_size=min_size,
max_size=max_size,
)
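# Illustrative note: a draw from pg_jsonb(min_size=0, max_size=5) is a dict such
# as {"a": None, "b": [True, 3], "c": {"d": "x"}} -- printable-text keys mapping
# to nested None/bool/int/text values, lists, and dicts.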
def tuple_2d_to_list_2d(v):
"""Convert a 2D tuple to a 2D list"""
return [list(inner) for inner in v]
def list_2d_to_tuple_2d(v):
"""Convert a 2D list to a 2D tuple"""
return tuple(tuple(inner) for inner in v)
class TestIntEnum(IntEnum):
foo = 1
bar = 2
def now():
return datetime.datetime.now(tz=datetime.timezone.utc)
FIELDS = (
Field("i", "int", 1, pg_int),
Field("txt", "text", "foo", pg_text),
Field("bytes", "bytea", b"bar", strategies.binary()),
Field(
"bytes_array",
"bytea[]",
[b"baz1", b"baz2"],
pg_bytea_a(min_size=0, max_size=5),
),
Field(
"bytes_tuple",
"bytea[]",
(b"baz1", b"baz2"),
pg_bytea_a(min_size=0, max_size=5).map(tuple),
in_wrapper=list,
out_converter=tuple,
),
Field(
"bytes_2d",
"bytea[][]",
[[b"quux1"], [b"quux2"]],
pg_bytea_a_a(min_size=0, max_size=5),
),
Field(
"bytes_2d_tuple",
"bytea[][]",
((b"quux1",), (b"quux2",)),
pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d),
in_wrapper=tuple_2d_to_list_2d,
out_converter=list_2d_to_tuple_2d,
),
Field("ts", "timestamptz", now(), pg_tstz(),),
Field(
"dict",
"jsonb",
{"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}},
pg_jsonb(min_size=0, max_size=5),
in_wrapper=psycopg2.extras.Json,
),
Field(
"intenum",
"int",
TestIntEnum.foo,
strategies.sampled_from(TestIntEnum),
in_wrapper=int,
out_converter=TestIntEnum,
),
Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()),
Field(
"text_list",
"text[]",
# All the funky corner cases
["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"],
strategies.lists(pg_text, min_size=0, max_size=5),
),
Field(
"tstz_list",
"timestamptz[]",
[now(), now() + datetime.timedelta(days=1)],
strategies.lists(pg_tstz(), min_size=0, max_size=5),
),
Field(
"tstz_range",
"tstzrange",
psycopg2.extras.DateTimeTZRange(
lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)",
),
strategies.tuples(
# generate two sorted timestamptzs for use as bounds
strategies.tuples(pg_tstz(), pg_tstz()).map(sorted),
# and a set of bounds
strategies.sampled_from(["[]", "()", "[)", "(]"]),
).map(
# and build the actual DateTimeTZRange object from these args
lambda args: psycopg2.extras.DateTimeTZRange(
lower=args[0][0], upper=args[0][1], bounds=args[1],
)
),
),
)
INIT_SQL = "create table test_table (%s)" % ", ".join(
f"{field.name} {field.pg_type}" for field in FIELDS
)
COLUMNS = tuple(field.name for field in FIELDS)
INSERT_SQL = "insert into test_table (%s) values (%s)" % (
", ".join(COLUMNS),
", ".join("%s" for i in range(len(COLUMNS))),
)
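# For reference, the generated statements expand roughly to (abridged):
#   create table test_table (i int, txt text, bytes bytea, ...)
#   insert into test_table (i, txt, bytes, ...) values (%s, %s, %s, ...)
# with one "%s" placeholder per column for psycopg2 to interpolate.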
STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS)
EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS)
db_rows = strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS)))
def convert_lines(cur):
return [
tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur
]
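# Illustrative note: with FIELDS as defined above, convert_lines applies e.g.
# TestIntEnum to the raw "intenum" column and tuple() to "bytes_tuple", so rows
# read back from the cursor compare equal to the example/generated values.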
test_db = postgresql_fact("postgresql_proc", dbname="test-db2")
@pytest.fixture
def db_with_data(test_db, request):
"""Fixture to initialize a db with some data out of the "INIT_SQL above
"""
db = BaseDb.connect(test_db.dsn)
with db.cursor() as cur:
psycopg2.extras.register_default_jsonb(cur)
cur.execute(INIT_SQL)
yield db
db.conn.rollback()
db.conn.close()
@pytest.mark.db
def test_db_connect(db_with_data):
with db_with_data.cursor() as cur:
psycopg2.extras.register_default_jsonb(cur)
cur.execute(INSERT_SQL, STATIC_ROW_IN)
cur.execute("select * from test_table;")
output = convert_lines(cur)
assert len(output) == 1
assert EXPECTED_ROW_OUT == output[0]
def test_db_initialized(db_with_data):
with db_with_data.cursor() as cur:
psycopg2.extras.register_default_jsonb(cur)
cur.execute(INSERT_SQL, STATIC_ROW_IN)
cur.execute("select * from test_table;")
output = convert_lines(cur)
assert len(output) == 1
assert EXPECTED_ROW_OUT == output[0]
def test_db_copy_to_static(db_with_data):
items = [{field.name: field.example for field in FIELDS}]
db_with_data.copy_to(items, "test_table", COLUMNS)
with db_with_data.cursor() as cur:
cur.execute("select * from test_table;")
output = convert_lines(cur)
assert len(output) == 1
assert EXPECTED_ROW_OUT == output[0]
@settings(suppress_health_check=function_scoped_fixture_check)
@given(db_rows)
def test_db_copy_to(db_with_data, data):
items = [dict(zip(COLUMNS, item)) for item in data]
with db_with_data.cursor() as cur:
cur.execute("TRUNCATE TABLE test_table CASCADE")
db_with_data.copy_to(items, "test_table", COLUMNS)
with db_with_data.cursor() as cur:
cur.execute("select * from test_table;")
converted_lines = convert_lines(cur)
assert converted_lines == data
def test_db_copy_to_thread_exception(db_with_data):
data = [(2 ** 65, "foo", b"bar")]
items = [dict(zip(COLUMNS, item)) for item in data]
with pytest.raises(psycopg2.errors.NumericValueOutOfRange):
db_with_data.copy_to(items, "test_table", COLUMNS)
def test_db_transaction(mocker):
expected_cur = object()
called = False
class Storage:
@db_transaction()
def endpoint(self, cur=None, db=None):
nonlocal called
called = True
assert cur is expected_cur
storage = Storage()
# 'with storage.get_db().transaction() as cur:' should cause
# 'cur' to be 'expected_cur'
db_mock = Mock()
db_mock.transaction.return_value = MagicMock()
db_mock.transaction.return_value.__enter__.return_value = expected_cur
mocker.patch.object(storage, "get_db", return_value=db_mock, create=True)
put_db_mock = mocker.patch.object(storage, "put_db", create=True)
storage.endpoint()
assert called
put_db_mock.assert_called_once_with(db_mock)
def test_db_transaction__with_generator():
with pytest.raises(ValueError, match="generator"):
class Storage:
@db_transaction()
def endpoint(self, cur=None, db=None):
yield None
def test_db_transaction_signature():
"""Checks db_transaction removes the 'cur' and 'db' arguments."""
def f(self, foo, *, bar=None):
pass
expected_sig = inspect.signature(f)
@db_transaction()
def g(self, foo, *, bar=None, db=None, cur=None):
pass
actual_sig = inspect.signature(g)
assert actual_sig == expected_sig
def test_db_transaction_generator(mocker):
expected_cur = object()
called = False
class Storage:
@db_transaction_generator()
def endpoint(self, cur=None, db=None):
nonlocal called
called = True
assert cur is expected_cur
yield None
storage = Storage()
# 'with storage.get_db().transaction() as cur:' should cause
# 'cur' to be 'expected_cur'
db_mock = Mock()
db_mock.transaction.return_value = MagicMock()
db_mock.transaction.return_value.__enter__.return_value = expected_cur
mocker.patch.object(storage, "get_db", return_value=db_mock, create=True)
put_db_mock = mocker.patch.object(storage, "put_db", create=True)
list(storage.endpoint())
assert called
put_db_mock.assert_called_once_with(db_mock)
def test_db_transaction_generator__with_nongenerator():
with pytest.raises(ValueError, match="generator"):
class Storage:
@db_transaction_generator()
def endpoint(self, cur=None, db=None):
pass
def test_db_transaction_generator_signature():
"""Checks db_transaction removes the 'cur' and 'db' arguments."""
def f(self, foo, *, bar=None):
pass
expected_sig = inspect.signature(f)
@db_transaction_generator()
def g(self, foo, *, bar=None, db=None, cur=None):
yield None
actual_sig = inspect.signature(g)
assert actual_sig == expected_sig
