def check_no_transaction(f_name):
    """Check that no caller up the stack already holds a DB transaction.

    A frame is considered to hold a transaction when both a ``db`` and a
    ``cur`` name are bound in its local variables.  Calling a decorated
    method without forwarding them is not necessarily broken code, but it
    is very likely a mistake, so it is reported loudly.

    The walk starts two frames above this function: the immediate caller
    (expected to be the decorator wrapper that invokes this check) is
    skipped, and inspection begins with *its* caller.

    Args:
        f_name: name of the function being called without ``db``/``cur``;
            only used to build the error message.

    Raises:
        AssertionError: if any inspected frame has both ``db`` and ``cur``
            among its local variables.
    """
    # Skip our own frame and the direct caller's frame.
    frame = sys._getframe().f_back.f_back
    while frame:
        if "db" in frame.f_locals and "cur" in frame.f_locals:
            # Both literals need the ``f`` prefix: adjacent string literals
            # are concatenated, but each one is formatted independently.
            raise AssertionError(
                f'Calling function {f_name} without "db" and "cur" arguments '
                f"from a function ({frame.f_code.co_name}) with these variables."
            )
        frame = frame.f_back
as cur:' should cause + # 'cur' to be 'expected_cur' + db_mock = Mock() + db_mock.transaction.return_value = MagicMock() + db_mock.transaction.return_value.__enter__.return_value = expected_cur + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) + + put_db_mock = mocker.patch.object(storage, "put_db", create=True) + + with pytest.raises(AssertionError, match="Calling function inner_endpoint"): + storage.endpoint_bad_nesting() + assert not called + put_db_mock.assert_called_once_with(db_mock) + put_db_mock.reset_mock() + + storage.endpoint_good_nesting() + assert called + put_db_mock.assert_called_once_with(db_mock) + + def test_db_transaction_generator(mocker): expected_cur = object()