diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from contextlib import contextmanager
 import shutil
 import tempfile
 import unittest
@@ -16,6 +17,7 @@
 from swh.storage.api.server import app
 from swh.storage.in_memory import Storage as InMemoryStorage
 import swh.storage.storage
+from swh.storage.db import Db
 from swh.storage.tests.test_storage import \
     CommonTestStorage, CommonPropTestStorage, StorageTestDbFixture
 
@@ -73,6 +75,11 @@
         self.reset_db_tables(self.TEST_DB_NAME, excluded=excluded)
         self.journal_writer.objects[:] = []
 
+    @contextmanager
+    def get_db(self):
+        """Yield a ``Db`` wrapping this fixture's own connection ``self.conn``."""
+        yield Db(self.conn)
+
 
 class RemoteMemStorageFixture(ServerTestFixture, unittest.TestCase):
     def setUp(self):
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -4,6 +4,7 @@
 # See top-level LICENSE file for more information
 
 import copy
+from contextlib import contextmanager
 import datetime
 import itertools
 import queue
@@ -46,6 +47,17 @@
         self.storage._pool.closeall()
         super().tearDown()
 
+    def get_db(self):
+        """Return ``self.storage.db()``; callers use it as a context manager."""
+        return self.storage.db()
+
+    @contextmanager
+    def db_transaction(self):
+        """Yield ``(db, cur)``: a db handle from get_db() and its transaction cursor."""
+        with self.get_db() as db:
+            with db.transaction() as cur:
+                yield db, cur
+
 
 class TestStorageData:
     def setUp(self, *args, **kwargs):