diff --git a/swh/core/db/common.py b/swh/core/db/common.py
--- a/swh/core/db/common.py
+++ b/swh/core/db/common.py
@@ -7,6 +7,24 @@
 import functools
 
 
+def remove_kwargs(names):
+    """Decorator factory hiding the parameters listed in `names` from
+    the decorated function's introspected signature.
+
+    The function object is returned unchanged apart from its
+    ``__signature__`` attribute, so ``inspect.signature`` stops
+    reporting those parameters while calls passing them keep working.
+    """
+    def decorator(f):
+        sig = inspect.signature(f)
+        params = [param for param in sig.parameters.values()
+                  if param.name not in names]
+        f.__signature__ = sig.replace(parameters=params)
+        return f
+
+    return decorator
+
+
 def apply_options(cursor, options):
     """Applies the given postgresql client options to the given cursor.
 
@@ -33,7 +51,8 @@
             raise ValueError(
                 'Use db_transaction_generator for generator functions.')
 
+        @remove_kwargs(['cur', 'db'])
         @functools.wraps(meth)
         def _meth(self, *args, **kwargs):
             if 'cur' in kwargs and kwargs['cur']:
                 cur = kwargs['cur']
@@ -67,6 +86,7 @@
             raise ValueError(
                 'Use db_transaction for non-generator functions.')
 
+        @remove_kwargs(['cur', 'db'])
         @functools.wraps(meth)
         def _meth(self, *args, **kwargs):
             if 'cur' in kwargs and kwargs['cur']:
diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import inspect
 import os.path
 import tempfile
 import unittest
@@ -151,6 +152,22 @@
                 yield None
 
 
+def test_db_transaction_signature():
+    """Checks db_transaction removes the 'cur' and 'db' arguments."""
+    def f(self, foo, *, bar=None):
+        pass
+    expected_sig = inspect.signature(f)
+
+    @db_transaction()
+    def g(self, foo, *, bar=None, db=None, cur=None):
+        pass
+
+    actual_sig = inspect.signature(g)
+
+    assert actual_sig == expected_sig
+    assert g.__name__ == 'g'
+
+
 def test_db_transaction_generator(mocker):
     expected_cur = object()
 
@@ -189,3 +206,19 @@
             @db_transaction_generator()
             def endpoint(self, cur=None, db=None):
                 pass
+
+
+def test_db_transaction_generator_signature():
+    """Checks db_transaction_generator removes 'cur' and 'db' arguments."""
+    def f(self, foo, *, bar=None):
+        pass
+    expected_sig = inspect.signature(f)
+
+    @db_transaction_generator()
+    def g(self, foo, *, bar=None, db=None, cur=None):
+        yield None
+
+    actual_sig = inspect.signature(g)
+
+    assert actual_sig == expected_sig
+    assert g.__name__ == 'g'