Changeset View
Changeset View
Standalone View
Standalone View
swh/core/db/tests/test_db.py
Show First 20 Lines • Show All 164 Lines • ▼ Show 20 Lines | def test_db_transaction_signature(): | ||||
def g(self, foo, *, bar=None, db=None, cur=None): | def g(self, foo, *, bar=None, db=None, cur=None): | ||||
pass | pass | ||||
actual_sig = inspect.signature(g) | actual_sig = inspect.signature(g) | ||||
assert actual_sig == expected_sig | assert actual_sig == expected_sig | ||||
def test_db_transaction_nested(mocker): | |||||
expected_cur = object() | |||||
called = False | |||||
class Storage: | |||||
@db_transaction() | |||||
def endpoint_bad_nesting(self, cur=None, db=None): | |||||
self.inner_endpoint() | |||||
@db_transaction() | |||||
def endpoint_good_nesting(self, cur=None, db=None): | |||||
self.inner_endpoint(cur=cur, db=db) | |||||
@db_transaction() | |||||
def inner_endpoint(self, cur=None, db=None): | |||||
nonlocal called | |||||
called = True | |||||
storage = Storage() | |||||
# 'with storage.get_db().transaction() as cur:' should cause | |||||
# 'cur' to be 'expected_cur' | |||||
db_mock = Mock() | |||||
db_mock.transaction.return_value = MagicMock() | |||||
db_mock.transaction.return_value.__enter__.return_value = expected_cur | |||||
mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) | |||||
put_db_mock = mocker.patch.object(storage, "put_db", create=True) | |||||
with pytest.raises(AssertionError, match="Calling function inner_endpoint"): | |||||
storage.endpoint_bad_nesting() | |||||
assert not called | |||||
put_db_mock.assert_called_once_with(db_mock) | |||||
put_db_mock.reset_mock() | |||||
storage.endpoint_good_nesting() | |||||
assert called | |||||
put_db_mock.assert_called_once_with(db_mock) | |||||
def test_db_transaction_generator(mocker): | def test_db_transaction_generator(mocker): | ||||
expected_cur = object() | expected_cur = object() | ||||
called = False | called = False | ||||
class Storage: | class Storage: | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def endpoint(self, cur=None, db=None): | def endpoint(self, cur=None, db=None): | ||||
▲ Show 20 Lines • Show All 46 Lines • Show Last 20 Lines |