diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py index 8f5280e..1d6b1f7 100644 --- a/swh/core/db/pytest_plugin.py +++ b/swh/core/db/pytest_plugin.py @@ -1,175 +1,177 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import logging import subprocess -from typing import Optional, Set, Union +from typing import List, Optional, Set, Union from _pytest.fixtures import FixtureRequest import psycopg2 import pytest from pytest_postgresql import factories from pytest_postgresql.janitor import DatabaseJanitor, Version from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) class SWHDatabaseJanitor(DatabaseJanitor): """SWH database janitor implementation with a a different setup/teardown policy than than the stock one. Instead of dropping, creating and initializing the database for each test, it creates and initializes the db once, then truncates the tables (and sequences) in between tests. This is needed to have acceptable test performances. """ def __init__( self, user: str, host: str, port: str, db_name: str, version: Union[str, float, Version], - dump_files: Optional[str] = None, + dump_files: Union[None, str, List[str]] = None, no_truncate_tables: Set[str] = set(), ) -> None: super().__init__(user, host, port, db_name, version) - if dump_files: + if dump_files is None: + self.dump_files = [] + elif isinstance(dump_files, str): self.dump_files = sorted(glob.glob(dump_files), key=sortkey) else: - self.dump_files = [] + self.dump_files = dump_files # do no truncate the following tables self.no_truncate_tables = set(no_truncate_tables) def db_setup(self): conninfo = ( f"host={self.host} user={self.user} port={self.port} dbname={self.db_name}" ) for fname in self.dump_files: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", fname, ] ) def db_reset(self): """Truncate tables (all but self.no_truncate_tables set) and sequences """ with psycopg2.connect( dbname=self.db_name, user=self.user, host=self.host, port=self.port, ) as cnx: with cnx.cursor() as cur: cur.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = %s", ("public",), ) all_tables = set(table for (table,) in cur.fetchall()) tables_to_truncate = all_tables - self.no_truncate_tables for table in tables_to_truncate: cur.execute("TRUNCATE TABLE %s CASCADE" % table) cur.execute( "SELECT sequence_name FROM information_schema.sequences " "WHERE sequence_schema = %s", ("public",), ) seqs = set(seq for (seq,) in cur.fetchall()) for seq in seqs: cur.execute("ALTER SEQUENCE %s RESTART;" % seq) cnx.commit() def init(self): """Initialize db. Create the db if it does not exist. Reset it if it exists.""" with self.cursor() as cur: cur.execute( "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,) ) db_exists = cur.fetchone()[0] == 1 if db_exists: cur.execute( "UPDATE pg_database SET datallowconn=true WHERE datname = %s;", (self.db_name,), ) self.db_reset() return # initialize the inexistent db with self.cursor() as cur: cur.execute('CREATE DATABASE "{}";'.format(self.db_name)) self.db_setup() def drop(self): """The original DatabaseJanitor implementation prevents new connections from happening, destroys current opened connections and finally drops the database. We actually do not want to drop the db so we instead do nothing and resets (truncate most tables and sequences) the db instead, in order to have some acceptable performance. """ pass # the postgres_fact factory fixture below is mostly a copy of the code # from pytest-postgresql. We need a custom version here to be able to # specify our version of the DBJanitor we use. def postgresql_fact( process_fixture_name: str, db_name: Optional[str] = None, - dump_files: str = "", + dump_files: Union[str, List[str]] = "", no_truncate_tables: Set[str] = {"dbversion"}, ): @pytest.fixture def postgresql_factory(request: FixtureRequest): """Fixture factory for PostgreSQL. :param FixtureRequest request: fixture request object :rtype: psycopg2.connection :returns: postgresql client """ config = factories.get_config(request) proc_fixture = request.getfixturevalue(process_fixture_name) pg_host = proc_fixture.host pg_port = proc_fixture.port pg_user = proc_fixture.user pg_options = proc_fixture.options pg_db = db_name or config["dbname"] with SWHDatabaseJanitor( pg_user, pg_host, pg_port, pg_db, proc_fixture.version, dump_files=dump_files, no_truncate_tables=no_truncate_tables, ): connection = psycopg2.connect( dbname=pg_db, user=pg_user, host=pg_host, port=pg_port, options=pg_options, ) yield connection connection.close() return postgresql_factory diff --git a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py index d4cb0fa..87990a2 100644 --- a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py +++ b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py @@ -1,142 +1,173 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import glob import os from swh.core.db import BaseDb from swh.core.db.pytest_plugin import postgresql_fact SQL_DIR = os.path.join(os.path.dirname(__file__), "data") # db with special policy for tables dbversion and people postgres_fun = postgresql_fact( "postgresql_proc", db_name="fun", dump_files=f"{SQL_DIR}/*.sql", no_truncate_tables={"dbversion", "people"}, ) +postgres_fun2 = postgresql_fact( + "postgresql_proc", + db_name="fun2", + dump_files=sorted(glob.glob(f"{SQL_DIR}/*.sql")), + no_truncate_tables={"dbversion", "people"}, +) + def test_smoke_test_fun_db_is_up(postgres_fun): """This ensures the db is created and configured according to its dumps files. """ with BaseDb.connect(postgres_fun.dsn).cursor() as cur: cur.execute("select count(*) from dbversion") nb_rows = cur.fetchone()[0] assert nb_rows == 5 cur.execute("select count(*) from fun") nb_rows = cur.fetchone()[0] assert nb_rows == 3 cur.execute("select count(*) from people") nb_rows = cur.fetchone()[0] assert nb_rows == 2 # in data, we requested a value already so it starts at 2 cur.execute("select nextval('serial')") val = cur.fetchone()[0] assert val == 2 +def test_smoke_test_fun2_db_is_up(postgres_fun2): + """This ensures the db is created and configured according to its dumps files. + + """ + with BaseDb.connect(postgres_fun2.dsn).cursor() as cur: + cur.execute("select count(*) from dbversion") + nb_rows = cur.fetchone()[0] + assert nb_rows == 5 + + cur.execute("select count(*) from fun") + nb_rows = cur.fetchone()[0] + assert nb_rows == 3 + + cur.execute("select count(*) from people") + nb_rows = cur.fetchone()[0] + assert nb_rows == 2 + + # in data, we requested a value already so it starts at 2 + cur.execute("select nextval('serial')") + val = cur.fetchone()[0] + assert val == 2 + + def test_smoke_test_fun_db_is_still_up_and_got_reset(postgres_fun): """This ensures that within another tests, the 'fun' db is still up, created (and not configured again). This time, most of the data has been reset: - except for tables 'dbversion' and 'people' which were left as is - the other tables from the schema (here only "fun") got truncated - the sequences got truncated as well """ with BaseDb.connect(postgres_fun.dsn).cursor() as cur: # db version is excluded from the truncate cur.execute("select count(*) from dbversion") nb_rows = cur.fetchone()[0] assert nb_rows == 5 # people is also allowed not to be truncated cur.execute("select count(*) from people") nb_rows = cur.fetchone()[0] assert nb_rows == 2 # table and sequence are reset cur.execute("select count(*) from fun") nb_rows = cur.fetchone()[0] assert nb_rows == 0 cur.execute("select nextval('serial')") val = cur.fetchone()[0] assert val == 1 # db with no special policy for tables truncation, all tables are reset postgres_people = postgresql_fact( "postgresql_proc", db_name="people", dump_files=f"{SQL_DIR}/*.sql", no_truncate_tables=set(), ) def test_smoke_test_people_db_up(postgres_people): """'people' db is up and configured """ with BaseDb.connect(postgres_people.dsn).cursor() as cur: cur.execute("select count(*) from dbversion") nb_rows = cur.fetchone()[0] assert nb_rows == 5 cur.execute("select count(*) from people") nb_rows = cur.fetchone()[0] assert nb_rows == 2 cur.execute("select count(*) from fun") nb_rows = cur.fetchone()[0] assert nb_rows == 3 cur.execute("select nextval('serial')") val = cur.fetchone()[0] assert val == 2 def test_smoke_test_people_db_up_and_reset(postgres_people): """'people' db is up and got reset on every tables and sequences """ with BaseDb.connect(postgres_people.dsn).cursor() as cur: # tables are truncated after the first round cur.execute("select count(*) from dbversion") nb_rows = cur.fetchone()[0] assert nb_rows == 0 # tables are truncated after the first round cur.execute("select count(*) from people") nb_rows = cur.fetchone()[0] assert nb_rows == 0 # table and sequence are reset cur.execute("select count(*) from fun") nb_rows = cur.fetchone()[0] assert nb_rows == 0 cur.execute("select nextval('serial')") val = cur.fetchone()[0] assert val == 1 # db with no initialization step, an empty db postgres_no_init = postgresql_fact("postgresql_proc", db_name="something") def test_smoke_test_db_no_init(postgres_no_init): """We can connect to the db nonetheless """ with BaseDb.connect(postgres_no_init.dsn).cursor() as cur: cur.execute("select now()") data = cur.fetchone()[0] assert data is not None