Page Menu · Home · Software Heritage
Paste P875

(An Untitled Masterwork)
Active · Public

Authored by anlambert on Nov 23 2020, 6:07 PM.
diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
index bb7f80b..6b7529c 100644
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -7,11 +7,8 @@ from dataclasses import dataclass
import datetime
from enum import IntEnum
import inspect
-import os.path
from string import printable
-import tempfile
from typing import Any
-import unittest
from unittest.mock import MagicMock, Mock
import uuid
@@ -23,8 +20,7 @@ from typing_extensions import Protocol
from swh.core.db import BaseDb
from swh.core.db.common import db_transaction, db_transaction_generator
-
-from .db_testing import SingleDbTestFixture, db_close, db_create, db_destroy
+from swh.core.db.pytest_plugin import postgresql_fact
# workaround mypy bug https://github.com/python/mypy/issues/5485
@@ -238,89 +234,57 @@ def convert_lines(cur):
]
-@pytest.mark.db
-def test_connect():
- db_name = db_create("test-db2", dumps=[])
- try:
- db = BaseDb.connect("dbname=%s" % db_name)
- with db.cursor() as cur:
- psycopg2.extras.register_default_jsonb(cur)
- cur.execute(INIT_SQL)
- cur.execute(INSERT_SQL, STATIC_ROW_IN)
- cur.execute("select * from test_table;")
- output = convert_lines(cur)
- assert len(output) == 1
- assert EXPECTED_ROW_OUT == output[0]
- finally:
- db_close(db.conn)
- db_destroy(db_name)
-
+postgres_test_db = postgresql_fact("postgresql_proc", db_name="test-db")
-@pytest.mark.db
-class TestDb(SingleDbTestFixture, unittest.TestCase):
- TEST_DB_NAME = "test-db"
- @classmethod
- def setUpClass(cls):
- with tempfile.TemporaryDirectory() as td:
- with open(os.path.join(td, "init.sql"), "a") as fd:
- fd.write(INIT_SQL)
-
- cls.TEST_DB_DUMP = os.path.join(td, "*.sql")
-
- super().setUpClass()
+@pytest.fixture
+def db(postgres_test_db):
+ db = BaseDb.connect(postgres_test_db.dsn)
+ with db.cursor() as cur:
+ psycopg2.extras.register_default_jsonb(cur)
+ cur.execute(INIT_SQL)
+ yield db
+ db.conn.rollback()
+ db.conn.close()
- def setUp(self):
- super().setUp()
- self.db = BaseDb(self.conn)
- def test_initialized(self):
- cur = self.db.cursor()
- psycopg2.extras.register_default_jsonb(cur)
+@pytest.mark.db
+def test_db_connect(db):
+ with db.cursor() as cur:
cur.execute(INSERT_SQL, STATIC_ROW_IN)
cur.execute("select * from test_table;")
output = convert_lines(cur)
assert len(output) == 1
assert EXPECTED_ROW_OUT == output[0]
- def test_reset_tables(self):
- cur = self.db.cursor()
- cur.execute(INSERT_SQL, STATIC_ROW_IN)
- self.reset_db_tables("test-db")
- cur.execute("select * from test_table;")
- assert convert_lines(cur) == []
-
- def test_copy_to_static(self):
- items = [{field.name: field.example for field in FIELDS}]
- self.db.copy_to(items, "test_table", COLUMNS)
- cur = self.db.cursor()
+def test_db_copy_to_static(db):
+ items = [{field.name: field.example for field in FIELDS}]
+ with db.cursor() as cur:
+ db.copy_to(items, "test_table", COLUMNS)
cur.execute("select * from test_table;")
output = convert_lines(cur)
assert len(output) == 1
assert EXPECTED_ROW_OUT == output[0]
+ cur.execute("drop table test_table")
- @given(db_rows)
- def test_copy_to(self, data):
- try:
- # the table is not reset between runs by hypothesis
- self.reset_db_tables("test-db")
- items = [dict(zip(COLUMNS, item)) for item in data]
- self.db.copy_to(items, "test_table", COLUMNS)
+@given(db_rows)
+def test_db_copy_to(db, data):
+ items = [dict(zip(COLUMNS, item)) for item in data]
+ with db.cursor() as cur:
+ cur.execute("truncate table test_table")
+ db.copy_to(items, "test_table", COLUMNS)
+ cur.execute("select * from test_table;")
+ assert convert_lines(cur) == data
- cur = self.db.cursor()
- cur.execute("select * from test_table;")
- assert convert_lines(cur) == data
- finally:
- self.db.conn.rollback()
- def test_copy_to_thread_exception(self):
- data = [(2 ** 65, "foo", b"bar")]
+def test_copy_to_thread_exception(db):
+ data = [(2 ** 65, "foo", b"bar")]
- items = [dict(zip(COLUMNS, item)) for item in data]
- with self.assertRaises(psycopg2.errors.NumericValueOutOfRange):
- self.db.copy_to(items, "test_table", COLUMNS)
+ items = [dict(zip(COLUMNS, item)) for item in data]
+ with pytest.raises(psycopg2.errors.NumericValueOutOfRange):
+ db.copy_to(items, "test_table", COLUMNS)
def test_db_transaction(mocker):