Changeset View
Changeset View
Standalone View
Standalone View
swh/core/db/tests/test_db.py
Show All 9 Lines | |||||
from unittest.mock import Mock, MagicMock | from unittest.mock import Mock, MagicMock | ||||
from hypothesis import strategies, given | from hypothesis import strategies, given | ||||
import psycopg2 | import psycopg2 | ||||
import pytest | import pytest | ||||
from swh.core.db import BaseDb | from swh.core.db import BaseDb | ||||
from swh.core.db.common import db_transaction, db_transaction_generator | from swh.core.db.common import db_transaction, db_transaction_generator | ||||
from .db_testing import ( | from .db_testing import SingleDbTestFixture, db_create, db_destroy, db_close | ||||
SingleDbTestFixture, db_create, db_destroy, db_close, | |||||
) | |||||
INIT_SQL = ''' | INIT_SQL = """ | ||||
create table test_table | create table test_table | ||||
( | ( | ||||
i int, | i int, | ||||
txt text, | txt text, | ||||
bytes bytea | bytes bytea | ||||
); | ); | ||||
''' | """ | ||||
db_rows = strategies.lists(strategies.tuples( | db_rows = strategies.lists( | ||||
strategies.tuples( | |||||
strategies.integers(-2147483648, +2147483647), | strategies.integers(-2147483648, +2147483647), | ||||
strategies.text( | strategies.text( | ||||
alphabet=strategies.characters( | alphabet=strategies.characters( | ||||
blacklist_categories=['Cs'], # surrogates | blacklist_categories=["Cs"], # surrogates | ||||
blacklist_characters=[ | blacklist_characters=[ | ||||
'\x00', # pgsql does not support the null codepoint | "\x00", # pgsql does not support the null codepoint | ||||
'\r', # pgsql normalizes those | "\r", # pgsql normalizes those | ||||
] | ], | ||||
), | ) | ||||
), | ), | ||||
strategies.binary(), | strategies.binary(), | ||||
)) | ) | ||||
) | |||||
@pytest.mark.db | @pytest.mark.db | ||||
def test_connect(): | def test_connect(): | ||||
db_name = db_create('test-db2', dumps=[]) | db_name = db_create("test-db2", dumps=[]) | ||||
try: | try: | ||||
db = BaseDb.connect('dbname=%s' % db_name) | db = BaseDb.connect("dbname=%s" % db_name) | ||||
with db.cursor() as cur: | with db.cursor() as cur: | ||||
cur.execute(INIT_SQL) | cur.execute(INIT_SQL) | ||||
cur.execute("insert into test_table values (1, %s, %s);", | cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) | ||||
('foo', b'bar')) | |||||
cur.execute("select * from test_table;") | cur.execute("select * from test_table;") | ||||
assert list(cur) == [(1, 'foo', b'bar')] | assert list(cur) == [(1, "foo", b"bar")] | ||||
finally: | finally: | ||||
db_close(db.conn) | db_close(db.conn) | ||||
db_destroy(db_name) | db_destroy(db_name) | ||||
@pytest.mark.db | @pytest.mark.db | ||||
class TestDb(SingleDbTestFixture, unittest.TestCase): | class TestDb(SingleDbTestFixture, unittest.TestCase): | ||||
TEST_DB_NAME = 'test-db' | TEST_DB_NAME = "test-db" | ||||
@classmethod | @classmethod | ||||
def setUpClass(cls): | def setUpClass(cls): | ||||
with tempfile.TemporaryDirectory() as td: | with tempfile.TemporaryDirectory() as td: | ||||
with open(os.path.join(td, 'init.sql'), 'a') as fd: | with open(os.path.join(td, "init.sql"), "a") as fd: | ||||
fd.write(INIT_SQL) | fd.write(INIT_SQL) | ||||
cls.TEST_DB_DUMP = os.path.join(td, '*.sql') | cls.TEST_DB_DUMP = os.path.join(td, "*.sql") | ||||
super().setUpClass() | super().setUpClass() | ||||
def setUp(self): | def setUp(self): | ||||
super().setUp() | super().setUp() | ||||
self.db = BaseDb(self.conn) | self.db = BaseDb(self.conn) | ||||
def test_initialized(self): | def test_initialized(self): | ||||
cur = self.db.cursor() | cur = self.db.cursor() | ||||
cur.execute("insert into test_table values (1, %s, %s);", | cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) | ||||
('foo', b'bar')) | |||||
cur.execute("select * from test_table;") | cur.execute("select * from test_table;") | ||||
self.assertEqual(list(cur), [(1, 'foo', b'bar')]) | self.assertEqual(list(cur), [(1, "foo", b"bar")]) | ||||
def test_reset_tables(self): | def test_reset_tables(self): | ||||
cur = self.db.cursor() | cur = self.db.cursor() | ||||
cur.execute("insert into test_table values (1, %s, %s);", | cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) | ||||
('foo', b'bar')) | self.reset_db_tables("test-db") | ||||
self.reset_db_tables('test-db') | |||||
cur.execute("select * from test_table;") | cur.execute("select * from test_table;") | ||||
self.assertEqual(list(cur), []) | self.assertEqual(list(cur), []) | ||||
@given(db_rows) | @given(db_rows) | ||||
def test_copy_to(self, data): | def test_copy_to(self, data): | ||||
# the table is not reset between runs by hypothesis | # the table is not reset between runs by hypothesis | ||||
self.reset_db_tables('test-db') | self.reset_db_tables("test-db") | ||||
items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data] | items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] | ||||
self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes']) | self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) | ||||
cur = self.db.cursor() | cur = self.db.cursor() | ||||
cur.execute('select * from test_table;') | cur.execute("select * from test_table;") | ||||
self.assertCountEqual(list(cur), data) | self.assertCountEqual(list(cur), data) | ||||
def test_copy_to_thread_exception(self): | def test_copy_to_thread_exception(self): | ||||
data = [(2**65, 'foo', b'bar')] | data = [(2 ** 65, "foo", b"bar")] | ||||
items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data] | items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] | ||||
with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): | with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): | ||||
self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes']) | self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) | ||||
def test_db_transaction(mocker): | def test_db_transaction(mocker): | ||||
expected_cur = object() | expected_cur = object() | ||||
called = False | called = False | ||||
class Storage: | class Storage: | ||||
@db_transaction() | @db_transaction() | ||||
def endpoint(self, cur=None, db=None): | def endpoint(self, cur=None, db=None): | ||||
nonlocal called | nonlocal called | ||||
called = True | called = True | ||||
assert cur is expected_cur | assert cur is expected_cur | ||||
storage = Storage() | storage = Storage() | ||||
# 'with storage.get_db().transaction() as cur:' should cause | # 'with storage.get_db().transaction() as cur:' should cause | ||||
# 'cur' to be 'expected_cur' | # 'cur' to be 'expected_cur' | ||||
db_mock = Mock() | db_mock = Mock() | ||||
db_mock.transaction.return_value = MagicMock() | db_mock.transaction.return_value = MagicMock() | ||||
db_mock.transaction.return_value.__enter__.return_value = expected_cur | db_mock.transaction.return_value.__enter__.return_value = expected_cur | ||||
mocker.patch.object( | mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) | ||||
storage, 'get_db', return_value=db_mock, create=True) | |||||
put_db_mock = mocker.patch.object( | put_db_mock = mocker.patch.object(storage, "put_db", create=True) | ||||
storage, 'put_db', create=True) | |||||
storage.endpoint() | storage.endpoint() | ||||
assert called | assert called | ||||
put_db_mock.assert_called_once_with(db_mock) | put_db_mock.assert_called_once_with(db_mock) | ||||
def test_db_transaction__with_generator(): | def test_db_transaction__with_generator(): | ||||
with pytest.raises(ValueError, match='generator'): | with pytest.raises(ValueError, match="generator"): | ||||
class Storage: | class Storage: | ||||
@db_transaction() | @db_transaction() | ||||
def endpoint(self, cur=None, db=None): | def endpoint(self, cur=None, db=None): | ||||
yield None | yield None | ||||
def test_db_transaction_signature(): | def test_db_transaction_signature(): | ||||
"""Checks db_transaction removes the 'cur' and 'db' arguments.""" | """Checks db_transaction removes the 'cur' and 'db' arguments.""" | ||||
def f(self, foo, *, bar=None): | def f(self, foo, *, bar=None): | ||||
pass | pass | ||||
expected_sig = inspect.signature(f) | expected_sig = inspect.signature(f) | ||||
@db_transaction() | @db_transaction() | ||||
def g(self, foo, *, bar=None, db=None, cur=None): | def g(self, foo, *, bar=None, db=None, cur=None): | ||||
pass | pass | ||||
actual_sig = inspect.signature(g) | actual_sig = inspect.signature(g) | ||||
Show All 15 Lines | def test_db_transaction_generator(mocker): | ||||
storage = Storage() | storage = Storage() | ||||
# 'with storage.get_db().transaction() as cur:' should cause | # 'with storage.get_db().transaction() as cur:' should cause | ||||
# 'cur' to be 'expected_cur' | # 'cur' to be 'expected_cur' | ||||
db_mock = Mock() | db_mock = Mock() | ||||
db_mock.transaction.return_value = MagicMock() | db_mock.transaction.return_value = MagicMock() | ||||
db_mock.transaction.return_value.__enter__.return_value = expected_cur | db_mock.transaction.return_value.__enter__.return_value = expected_cur | ||||
mocker.patch.object( | mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) | ||||
storage, 'get_db', return_value=db_mock, create=True) | |||||
put_db_mock = mocker.patch.object( | put_db_mock = mocker.patch.object(storage, "put_db", create=True) | ||||
storage, 'put_db', create=True) | |||||
list(storage.endpoint()) | list(storage.endpoint()) | ||||
assert called | assert called | ||||
put_db_mock.assert_called_once_with(db_mock) | put_db_mock.assert_called_once_with(db_mock) | ||||
def test_db_transaction_generator__with_nongenerator(): | def test_db_transaction_generator__with_nongenerator(): | ||||
with pytest.raises(ValueError, match='generator'): | with pytest.raises(ValueError, match="generator"): | ||||
class Storage: | class Storage: | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def endpoint(self, cur=None, db=None): | def endpoint(self, cur=None, db=None): | ||||
pass | pass | ||||
def test_db_transaction_generator_signature(): | def test_db_transaction_generator_signature(): | ||||
"""Checks db_transaction removes the 'cur' and 'db' arguments.""" | """Checks db_transaction removes the 'cur' and 'db' arguments.""" | ||||
def f(self, foo, *, bar=None): | def f(self, foo, *, bar=None): | ||||
pass | pass | ||||
expected_sig = inspect.signature(f) | expected_sig = inspect.signature(f) | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def g(self, foo, *, bar=None, db=None, cur=None): | def g(self, foo, *, bar=None, db=None, cur=None): | ||||
yield None | yield None | ||||
actual_sig = inspect.signature(g) | actual_sig = inspect.signature(g) | ||||
assert actual_sig == expected_sig | assert actual_sig == expected_sig |