diff --git a/swh/core/db/common.py b/swh/core/db/common.py --- a/swh/core/db/common.py +++ b/swh/core/db/common.py @@ -36,9 +36,15 @@ def db_transaction(**client_options): """decorator to execute Backend methods within DB transactions - The decorated method must accept a `cur` and `db` keyword argument - - Client options are passed as `set` options to the postgresql server + The decorated method must accept a ``cur`` and ``db`` keyword argument + + Client options are passed as ``set`` options to the postgresql server. If + available, decorated ``self.query_options`` can be defined as a dict which + keys are (decorated) method names and values are dicts. These later dicts + are merged with the given ``client_options``. So it's possible to define + default client_options as decorator arguments and overload them from e.g. a + configuration file (e.g. making is the ``self.query_options`` attribute filled + from a config file). """ def decorator(meth, __client_options=client_options): @@ -48,9 +54,14 @@ @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): + options = getattr(self, "query_options", None) or {} + if meth.__name__ in options: + client_options = {**__client_options, **options[meth.__name__]} + else: + client_options = __client_options if "cur" in kwargs and kwargs["cur"]: cur = kwargs["cur"] - old_options = apply_options(cur, __client_options) + old_options = apply_options(cur, client_options) ret = meth(self, *args, **kwargs) apply_options(cur, old_options) return ret @@ -58,7 +69,7 @@ db = self.get_db() try: with db.transaction() as cur: - apply_options(cur, __client_options) + apply_options(cur, client_options) return meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) @@ -72,9 +83,15 @@ """decorator to execute Backend methods within DB transactions, while returning a generator - The decorated method must accept a `cur` and `db` keyword argument + The decorated method must accept a ``cur`` and ``db`` keyword argument - Client options are passed as `set` options to the postgresql server + Client options are passed as ``set`` options to the postgresql server. If + available, decorated ``self.query_options`` can be defined as a dict which + keys are (decorated) method names and values are dicts. These later dicts + are merged with the given ``client_options``. So it's possible to define + default client_options as decorator arguments and overload them from e.g. a + configuration file (e.g. making is the ``self.query_options`` attribute filled + from a config file). """ def decorator(meth, __client_options=client_options): @@ -84,16 +101,21 @@ @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): + options = getattr(self, "query_options", None) or {} + if meth.__name__ in options: + client_options = {**__client_options, **options[meth.__name__]} + else: + client_options = __client_options if "cur" in kwargs and kwargs["cur"]: cur = kwargs["cur"] - old_options = apply_options(cur, __client_options) + old_options = apply_options(cur, client_options) yield from meth(self, *args, **kwargs) apply_options(cur, old_options) else: db = self.get_db() try: with db.transaction() as cur: - apply_options(cur, __client_options) + apply_options(cur, client_options) yield from meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -415,3 +415,45 @@ actual_sig = inspect.signature(g) assert actual_sig == expected_sig + + +@pytest.mark.parametrize( + "query_options", (None, {"something": 42, "statement_timeout": 200}) +) +@pytest.mark.parametrize("use_generator", (True, False)) +def test_db_transaction_query_options(mocker, use_generator, query_options): + class Storage: + @db_transaction(statement_timeout=100) + def endpoint(self, cur=None, db=None): + return [None] + + @db_transaction_generator(statement_timeout=100) + def gen_endpoint(self, cur=None, db=None): + yield None + + storage = Storage() + + # mockers + mocked_apply = mocker.patch("swh.core.db.common.apply_options") + # 'with storage.get_db().transaction() as cur:' should cause + # 'cur' to be 'expected_cur' + expected_cur = object() + db_mock = MagicMock() + db_mock.transaction.return_value.__enter__.return_value = expected_cur + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) + mocker.patch.object(storage, "put_db", create=True) + + if query_options: + storage.query_options = { + "endpoint": query_options, + "gen_endpoint": query_options, + } + if use_generator: + list(storage.gen_endpoint()) + else: + list(storage.endpoint()) + + mocked_apply.assert_called_once_with( + expected_cur, + query_options if query_options is not None else {"statement_timeout": 100}, + )