Review notes for this patch. Free text placed before the first diff header
is ignored by git apply and by GNU patch.

Purpose: port swh/core/db/tests/test_db.py from the unittest-based
SingleDbTestFixture helpers (db_create / db_close / db_destroy imported
from .db_testing) to pytest fixtures built on
swh.core.db.pytest_plugin.postgresql_fact.

What it changes:
  * drops the now-unused imports os.path, tempfile and unittest;
  * defines postgres_test_db = postgresql_fact("postgresql_proc",
    db_name="test-db") and a `db` fixture that connects with
    BaseDb.connect(postgres_test_db.dsn), calls
    psycopg2.extras.register_default_jsonb on a cursor, runs INIT_SQL,
    and on teardown calls db.conn.rollback() then db.conn.close();
  * replaces test_connect and TestDb.test_initialized with
    test_db_connect, renames test_copy_to_static / test_copy_to to
    test_db_copy_to_static / test_db_copy_to, keeps
    test_copy_to_thread_exception, and removes test_reset_tables
    (reset_db_tables belonged to the removed fixture class).

Open review points (not fixed here; please confirm):
  1. Only test_db_connect keeps @pytest.mark.db. Previously the marker sat
     on the TestDb class and so applied to all of its methods. The three
     copy_to tests (test_db_copy_to_static, test_db_copy_to,
     test_copy_to_thread_exception) lose it. If CI selects tests with
     "-m db" or "-m 'not db'", they will now be selected differently.
     Consider marking them as well, e.g. with a module-level pytestmark.
  2. test_db_copy_to combines @given with the function-scoped `db`
     fixture. The fixture is therefore set up once for all Hypothesis
     examples, hence the "truncate table test_table" at the start of each
     example. Recent Hypothesis releases fail such tests with the
     function_scoped_fixture health check. Confirm against the pinned
     Hypothesis version and add
     @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
     if needed.
  3. The old test_copy_to rolled back in a `finally` after every example;
     the new test_db_copy_to does not. If one example raises a database
     error, later examples (including Hypothesis shrinking runs) may hit an
     aborted transaction and report that error instead of the original
     failure. This assumes the connection is not in autocommit mode, which
     this patch does not show.
  4. test_db_copy_to_static ends with "drop table test_table", although
     the `db` fixture already rolls back on teardown. The drop looks
     redundant, but whether it has any effect depends on the connection's
     transaction mode, which is not visible here.
  5. test_copy_to_thread_exception is the only converted test without the
     test_db_ prefix used by its siblings.
  6. This copy of the patch has had its newlines and indentation collapsed
     into single spaces, so it will not apply as-is: the hunk bodies no
     longer match their declared line counts. Regenerate it from the
     source tree rather than hand-editing it.

diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py index bb7f80b..6b7529c 100644 --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -7,11 +7,8 @@ from dataclasses import dataclass import datetime from enum import IntEnum import inspect -import os.path from string import printable -import tempfile from typing import Any -import unittest from unittest.mock import MagicMock, Mock import uuid @@ -23,8 +20,7 @@ from typing_extensions import Protocol from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator - -from .db_testing import SingleDbTestFixture, db_close, db_create, db_destroy +from swh.core.db.pytest_plugin import postgresql_fact # workaround mypy bug https://github.com/python/mypy/issues/5485 @@ -238,89 +234,57 @@ def convert_lines(cur): ] -@pytest.mark.db -def test_connect(): - db_name = db_create("test-db2", dumps=[]) - try: - db = BaseDb.connect("dbname=%s" % db_name) - with db.cursor() as cur: - psycopg2.extras.register_default_jsonb(cur) - cur.execute(INIT_SQL) - cur.execute(INSERT_SQL, STATIC_ROW_IN) - cur.execute("select * from test_table;") - output = convert_lines(cur) - assert len(output) == 1 - assert EXPECTED_ROW_OUT == output[0] - finally: - db_close(db.conn) - db_destroy(db_name) - +postgres_test_db = postgresql_fact("postgresql_proc", db_name="test-db") -@pytest.mark.db -class TestDb(SingleDbTestFixture, unittest.TestCase): - TEST_DB_NAME = "test-db" - @classmethod - def setUpClass(cls): - with tempfile.TemporaryDirectory() as td: - with open(os.path.join(td, "init.sql"), "a") as fd: - fd.write(INIT_SQL) - - cls.TEST_DB_DUMP = os.path.join(td, "*.sql") - - super().setUpClass() +@pytest.fixture +def db(postgres_test_db): + db = BaseDb.connect(postgres_test_db.dsn) + with db.cursor() as cur: + psycopg2.extras.register_default_jsonb(cur) + cur.execute(INIT_SQL) + yield db + db.conn.rollback() + db.conn.close() - def setUp(self): - super().setUp() - self.db = 
BaseDb(self.conn) - def test_initialized(self): - cur = self.db.cursor() - psycopg2.extras.register_default_jsonb(cur) +@pytest.mark.db +def test_db_connect(db): + with db.cursor() as cur: cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] - def test_reset_tables(self): - cur = self.db.cursor() - cur.execute(INSERT_SQL, STATIC_ROW_IN) - self.reset_db_tables("test-db") - cur.execute("select * from test_table;") - assert convert_lines(cur) == [] - - def test_copy_to_static(self): - items = [{field.name: field.example for field in FIELDS}] - self.db.copy_to(items, "test_table", COLUMNS) - cur = self.db.cursor() +def test_db_copy_to_static(db): + items = [{field.name: field.example for field in FIELDS}] + with db.cursor() as cur: + db.copy_to(items, "test_table", COLUMNS) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] + cur.execute("drop table test_table") - @given(db_rows) - def test_copy_to(self, data): - try: - # the table is not reset between runs by hypothesis - self.reset_db_tables("test-db") - items = [dict(zip(COLUMNS, item)) for item in data] - self.db.copy_to(items, "test_table", COLUMNS) +@given(db_rows) +def test_db_copy_to(db, data): + items = [dict(zip(COLUMNS, item)) for item in data] + with db.cursor() as cur: + cur.execute("truncate table test_table") + db.copy_to(items, "test_table", COLUMNS) + cur.execute("select * from test_table;") + assert convert_lines(cur) == data - cur = self.db.cursor() - cur.execute("select * from test_table;") - assert convert_lines(cur) == data - finally: - self.db.conn.rollback() - def test_copy_to_thread_exception(self): - data = [(2 ** 65, "foo", b"bar")] +def test_copy_to_thread_exception(db): + data = [(2 ** 65, "foo", b"bar")] - items = [dict(zip(COLUMNS, item)) for item in data] - with 
self.assertRaises(psycopg2.errors.NumericValueOutOfRange): - self.db.copy_to(items, "test_table", COLUMNS) + items = [dict(zip(COLUMNS, item)) for item in data] + with pytest.raises(psycopg2.errors.NumericValueOutOfRange): + db.copy_to(items, "test_table", COLUMNS) def test_db_transaction(mocker):