# swh-core
diff --git a/swh/core/db/common.py b/swh/core/db/common.py
index 65e6d75..6f89058 100644
--- a/swh/core/db/common.py
+++ b/swh/core/db/common.py
@@ -48,9 +48,16 @@ def db_transaction(**client_options):
         @remove_kwargs(["cur", "db"])
         @functools.wraps(meth)
         def _meth(self, *args, **kwargs):
+            # Options found in self.query_options[<method name>] (if any)
+            # override the ones passed to the decorator, for this call only.
+            options = getattr(self, "query_options", None) or {}
+            if meth.__name__ in options:
+                client_options = {**__client_options, **options[meth.__name__]}
+            else:
+                client_options = __client_options
             if "cur" in kwargs and kwargs["cur"]:
                 cur = kwargs["cur"]
-                old_options = apply_options(cur, __client_options)
+                old_options = apply_options(cur, client_options)
                 ret = meth(self, *args, **kwargs)
                 apply_options(cur, old_options)
                 return ret
@@ -58,7 +65,7 @@ def db_transaction(**client_options):
                 db = self.get_db()
                 try:
                     with db.transaction() as cur:
-                        apply_options(cur, __client_options)
+                        apply_options(cur, client_options)
                         return meth(self, *args, db=db, cur=cur, **kwargs)
                 finally:
                     self.put_db(db)
# swh-storage
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py
index c33551e1..57e5ce8f 100644
--- a/swh/storage/postgresql/storage.py
+++ b/swh/storage/postgresql/storage.py
@@ -105,7 +105,13 @@ class Storage:
     """
 
     def __init__(
-        self, db, objstorage, min_pool_conns=1, max_pool_conns=10, journal_writer=None
+        self,
+        db,
+        objstorage,
+        min_pool_conns=1,
+        max_pool_conns=10,
+        journal_writer=None,
+        query_options=None,
     ):
         """
         Args:
@@ -130,6 +136,7 @@ class Storage:
         self.journal_writer = JournalWriter(journal_writer)
 
         self.objstorage = ObjStorage(objstorage)
+        self.query_options = query_options
 
     def get_db(self):
         if self._db:
diff --git a/swh/storage/tests/test_postgresql.py b/swh/storage/tests/test_postgresql.py
index 4509ff9e..eaecb363 100644
--- a/swh/storage/tests/test_postgresql.py
+++ b/swh/storage/tests/test_postgresql.py
@@ -52,6 +52,19 @@ class TestLocalStorage:
         missing = list(swh_storage.content_missing([content.hashes()]))
         assert missing == [content.sha1]
 
+    def test_custom_timeout(self, swh_storage):
+        from random import choices
+
+        from psycopg2.errors import QueryCanceled
+
+        assert swh_storage.query_options is None
+        # A 1 ms statement_timeout should be far too short to look up 100k
+        # random sha1s, so the content_get override must cancel the query.
+        swh_storage.query_options = {"content_get": {"statement_timeout": 1}}
+        hashes = [bytes(choices(range(256), k=20)) for _ in range(100000)]
+        with pytest.raises(QueryCanceled):
+            swh_storage.content_get(hashes)
+
 
 @pytest.mark.db
 class TestStorageRaceConditions: