diff --git a/swh/storage/tests/storage_testing.py b/swh/storage/tests/storage_testing.py new file mode 100644 --- /dev/null +++ b/swh/storage/tests/storage_testing.py @@ -0,0 +1,40 @@ +# Copyright (C) 2015-2017 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import tempfile + +from swh.core.tests.db_testing import DbTestFixture +from swh.storage import get_storage + + +class StorageTestFixture(DbTestFixture): + """Test fixture exposing a 'local' storage as ``self.storage``. + + The storage is built on the database connection ``self.conn`` (expected + to be provided by DbTestFixture) and a 'pathslicing' object storage + rooted in a temporary directory that tearDown() removes. + """ + def setUp(self): + super().setUp() + self.objtmp = tempfile.TemporaryDirectory() + + storage_conf = { + 'cls': 'local', + 'args': { + 'db': self.conn, + 'objstorage': { + 'cls': 'pathslicing', + 'args': { + 'root': self.objtmp.name, + 'slicing': '0:1/1:5', + }, + }, + }, + } + self.storage = get_storage(**storage_conf) + + def tearDown(self): + self.objtmp.cleanup() + super().tearDown() diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -8,8 +8,6 @@ from operator import itemgetter import os import psycopg2 -import shutil -import tempfile import unittest from uuid import UUID @@ -18,11 +16,10 @@ from nose.tools import istest from nose.plugins.attrib import attr -from swh.core.tests.db_testing import DbTestFixture from swh.model import identifiers from swh.model.hashutil import hash_to_bytes -from swh.storage import get_storage +from swh.storage.tests.storage_testing import StorageTestFixture from swh.storage.db import cursor_to_bytes @@ -31,29 +28,12 @@ @attr('db') -class BaseTestStorage(DbTestFixture): +class BaseTestStorage(StorageTestFixture): TEST_DB_DUMP = os.path.join(TEST_DATA_DIR, 'dumps/swh.dump') def setUp(self): super().setUp() self.maxDiff = None - self.objroot = tempfile.mkdtemp() - - storage_conf = { - 'cls': 'local', - 'args': { - 'db': self.conn, -
'objstorage': { - 'cls': 'pathslicing', - 'args': { - 'root': self.objroot, - 'slicing': '0:2/2:4/4:6', - }, - }, - }, - } - - self.storage = get_storage(**storage_conf) self.cont = { 'data': b'42\n', @@ -567,8 +547,6 @@ } def tearDown(self): - shutil.rmtree(self.objroot) - self.cursor.execute("""SELECT table_name FROM information_schema.tables WHERE table_schema = %s""", ('public',))