diff --git a/requirements-test.txt b/requirements-test.txt
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,4 +1,5 @@
 pytest
+pytest-mock
 requests-mock
 hypothesis >= 3.11.0
 pre-commit
diff --git a/swh/core/db/common.py b/swh/core/db/common.py
--- a/swh/core/db/common.py
+++ b/swh/core/db/common.py
@@ -7,6 +7,26 @@
 import functools
 
 
+def remove_kwargs(names):
+    """Returns a decorator hiding the parameters listed in `names` from the
+    signature reported by `inspect.signature`.
+
+    The decorator sets `f.__signature__` to `f`'s current signature minus
+    those parameters, then returns `f` itself. Only introspection changes;
+    the way `f` binds arguments when called is unaffected. This lets
+    `db_transaction` and `db_transaction_generator` advertise the wrapped
+    method's signature without its `cur` and `db` parameters.
+    """
+    def decorator(f):
+        sig = inspect.signature(f)
+        kept = [param for param in sig.parameters.values()
+                if param.name not in names]
+        f.__signature__ = sig.replace(parameters=kept)
+        return f
+
+    return decorator
+
+
 def apply_options(cursor, options):
     """Applies the given postgresql client options to the given cursor.
 
@@ -33,6 +53,7 @@
             raise ValueError(
                 'Use db_transaction_generator for generator functions.')
 
+        @remove_kwargs(['cur', 'db'])
         @functools.wraps(meth)
         def _meth(self, *args, **kwargs):
             if 'cur' in kwargs and kwargs['cur']:
@@ -67,6 +88,7 @@
             raise ValueError(
                 'Use db_transaction for non-generator functions.')
 
+        @remove_kwargs(['cur', 'db'])
         @functools.wraps(meth)
         def _meth(self, *args, **kwargs):
             if 'cur' in kwargs and kwargs['cur']:
diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -3,15 +3,18 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import inspect
 import os.path
 import tempfile
 import unittest
+from unittest.mock import Mock, MagicMock
 
 from hypothesis import strategies, given
 import psycopg2
 import pytest
 
 from swh.core.db import BaseDb
+from swh.core.db.common import db_transaction, db_transaction_generator
 from .db_testing import (
     SingleDbTestFixture, db_create, db_destroy, db_close,
 )
@@ -108,3 +111,112 @@
         items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data]
         with self.assertRaises(psycopg2.errors.NumericValueOutOfRange):
             self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes'])
+
+
+def test_db_transaction(mocker):
+    expected_cur = object()
+
+    called = False
+
+    class Storage:
+        @db_transaction()
+        def endpoint(self, cur=None, db=None):
+            nonlocal called
+            called = True
+            assert cur is expected_cur
+
+    storage = Storage()
+
+    # 'with storage.get_db().transaction() as cur:' should cause
+    # 'cur' to be 'expected_cur'
+    db_mock = Mock()
+    db_mock.transaction.return_value = MagicMock()
+    db_mock.transaction.return_value.__enter__.return_value = expected_cur
+    mocker.patch.object(
+        storage, 'get_db', return_value=db_mock, create=True)
+
+    put_db_mock = mocker.patch.object(
+        storage, 'put_db', create=True)
+
+    storage.endpoint()
+
+    assert called
+    put_db_mock.assert_called_once_with(db_mock)
+
+
+def test_db_transaction__with_generator():
+    with pytest.raises(ValueError, match='generator'):
+        class Storage:
+            @db_transaction()
+            def endpoint(self, cur=None, db=None):
+                yield None
+
+
+def test_db_transaction_signature():
+    """Checks db_transaction removes the 'cur' and 'db' arguments."""
+    def f(self, foo, *, bar=None):
+        pass
+    expected_sig = inspect.signature(f)
+
+    @db_transaction()
+    def g(self, foo, *, bar=None, db=None, cur=None):
+        pass
+
+    actual_sig = inspect.signature(g)
+
+    assert actual_sig == expected_sig
+
+
+def test_db_transaction_generator(mocker):
+    expected_cur = object()
+
+    called = False
+
+    class Storage:
+        @db_transaction_generator()
+        def endpoint(self, cur=None, db=None):
+            nonlocal called
+            called = True
+            assert cur is expected_cur
+            yield None
+
+    storage = Storage()
+
+    # 'with storage.get_db().transaction() as cur:' should cause
+    # 'cur' to be 'expected_cur'
+    db_mock = Mock()
+    db_mock.transaction.return_value = MagicMock()
+    db_mock.transaction.return_value.__enter__.return_value = expected_cur
+    mocker.patch.object(
+        storage, 'get_db', return_value=db_mock, create=True)
+
+    put_db_mock = mocker.patch.object(
+        storage, 'put_db', create=True)
+
+    list(storage.endpoint())
+
+    assert called
+    put_db_mock.assert_called_once_with(db_mock)
+
+
+def test_db_transaction_generator__with_nongenerator():
+    with pytest.raises(ValueError, match='generator'):
+        class Storage:
+            @db_transaction_generator()
+            def endpoint(self, cur=None, db=None):
+                pass
+
+
+def test_db_transaction_generator_signature():
+    """Checks db_transaction_generator removes the 'cur' and 'db' arguments."""
+    def f(self, foo, *, bar=None):
+        pass
+    expected_sig = inspect.signature(f)
+
+    @db_transaction_generator()
+    def g(self, foo, *, bar=None, db=None, cur=None):
+        yield None
+
+    actual_sig = inspect.signature(g)
+
+    assert actual_sig == expected_sig