diff --git a/swh/core/db/common.py b/swh/core/db/common.py --- a/swh/core/db/common.py +++ b/swh/core/db/common.py @@ -38,7 +38,13 @@ The decorated method must accept a `cur` and `db` keyword argument - Client options are passed as `set` options to the postgresql server + Client options are passed as `set` options to the postgresql server. If + available, `self.query_options` can be defined as a dict whose keys are + (decorated) method names and whose values are dicts. These latter dicts + are merged into the given `client_options`, taking precedence over them. + So it's possible to define default client_options as decorator arguments + and override them, e.g. from a configuration file (by filling the + `self.query_options` attribute from that file). """ def decorator(meth, __client_options=client_options): @@ -48,9 +54,14 @@ @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): + options = getattr(self, "query_options", None) or {} + if meth.__name__ in options: + client_options = {**__client_options, **options[meth.__name__]} + else: + client_options = __client_options if "cur" in kwargs and kwargs["cur"]: cur = kwargs["cur"] - old_options = apply_options(cur, __client_options) + old_options = apply_options(cur, client_options) ret = meth(self, *args, **kwargs) apply_options(cur, old_options) return ret @@ -58,7 +69,7 @@ db = self.get_db() try: with db.transaction() as cur: - apply_options(cur, __client_options) + apply_options(cur, client_options) return meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) @@ -74,7 +85,13 @@ The decorated method must accept a `cur` and `db` keyword argument - Client options are passed as `set` options to the postgresql server + Client options are passed as `set` options to the postgresql server. If + available, `self.query_options` can be defined as a dict whose keys are + (decorated) method names and whose values are dicts. 
These latter dicts + are merged into the given `client_options`, taking precedence over them. + So it's possible to define default client_options as decorator arguments + and override them, e.g. from a configuration file (by filling the + `self.query_options` attribute from that file). """ def decorator(meth, __client_options=client_options): @@ -84,16 +101,21 @@ @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): + options = getattr(self, "query_options", None) or {} + if meth.__name__ in options: + client_options = {**__client_options, **options[meth.__name__]} + else: + client_options = __client_options if "cur" in kwargs and kwargs["cur"]: cur = kwargs["cur"] - old_options = apply_options(cur, __client_options) + old_options = apply_options(cur, client_options) yield from meth(self, *args, **kwargs) apply_options(cur, old_options) else: db = self.get_db() try: with db.transaction() as cur: - apply_options(cur, __client_options) + apply_options(cur, client_options) yield from meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -361,6 +361,57 @@ assert actual_sig == expected_sig +def test_db_transaction_apply_options(mocker): + expected_cur = object() + + class Storage: + @db_transaction(statement_timeout=100) + def endpoint(self, cur=None, db=None): + return None + + mocked_apply = mocker.patch("swh.core.db.common.apply_options") + + storage = Storage() + + # 'with storage.get_db().transaction() as cur:' should cause + # 'cur' to be 'expected_cur' + db_mock = Mock() + db_mock.transaction.return_value = MagicMock() + db_mock.transaction.return_value.__enter__.return_value = expected_cur + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) + mocker.patch.object(storage, "put_db", create=True) + storage.endpoint() + 
mocked_apply.assert_called_once_with(expected_cur, {"statement_timeout": 100}) + + +def test_db_transaction_query_options(mocker): + expected_cur = object() + + class Storage: + @db_transaction(statement_timeout=100) + def endpoint(self, cur=None, db=None): + return None + + mocked_apply = mocker.patch("swh.core.db.common.apply_options") + + storage = Storage() + + # 'with storage.get_db().transaction() as cur:' should cause + # 'cur' to be 'expected_cur' + db_mock = Mock() + db_mock.transaction.return_value = MagicMock() + db_mock.transaction.return_value.__enter__.return_value = expected_cur + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) + mocker.patch.object(storage, "put_db", create=True) + storage.query_options = {"endpoint": {"something": 42, "statement_timeout": 200}} + storage.endpoint() + + mocked_apply.assert_called_once_with( + expected_cur, {"something": 42, "statement_timeout": 200} + ) + + def test_db_transaction_generator(mocker): expected_cur = object() @@ -415,3 +466,54 @@ actual_sig = inspect.signature(g) assert actual_sig == expected_sig + + +def test_db_transaction_generator_apply_options(mocker): + expected_cur = object() + + class Storage: + @db_transaction_generator(statement_timeout=100) + def endpoint(self, cur=None, db=None): + yield None + + mocked_apply = mocker.patch("swh.core.db.common.apply_options") + + storage = Storage() + + # 'with storage.get_db().transaction() as cur:' should cause + # 'cur' to be 'expected_cur' + db_mock = Mock() + db_mock.transaction.return_value = MagicMock() + db_mock.transaction.return_value.__enter__.return_value = expected_cur + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) + mocker.patch.object(storage, "put_db", create=True) + list(storage.endpoint()) + + mocked_apply.assert_called_once_with(expected_cur, {"statement_timeout": 100}) + + +def test_db_transaction_generator_query_options(mocker): + expected_cur = object() + + class Storage: + 
@db_transaction_generator(statement_timeout=100) + def endpoint(self, cur=None, db=None): + yield None + + mocked_apply = mocker.patch("swh.core.db.common.apply_options") + + storage = Storage() + + # 'with storage.get_db().transaction() as cur:' should cause + # 'cur' to be 'expected_cur' + db_mock = Mock() + db_mock.transaction.return_value = MagicMock() + db_mock.transaction.return_value.__enter__.return_value = expected_cur + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) + mocker.patch.object(storage, "put_db", create=True) + storage.query_options = {"endpoint": {"something": 42, "statement_timeout": 200}} + list(storage.endpoint()) + + mocked_apply.assert_called_once_with( + expected_cur, {"something": 42, "statement_timeout": 200} + )