diff --git a/mypy.ini b/mypy.ini
--- a/mypy.ini
+++ b/mypy.ini
@@ -38,6 +38,9 @@
 [mypy-pytest.*]
 ignore_missing_imports = True
 
+[mypy-pytest_postgresql.*]
+ignore_missing_imports = True
+
 [mypy-requests_mock.*]
 ignore_missing_imports = True
 
diff --git a/requirements-db.txt b/requirements-db.txt
--- a/requirements-db.txt
+++ b/requirements-db.txt
@@ -1,3 +1,4 @@
 # requirements for swh.core.db
 psycopg2
 typing-extensions
+pytest-postgresql
diff --git a/requirements-test-db.txt b/requirements-test-db.txt
--- a/requirements-test-db.txt
+++ b/requirements-test-db.txt
@@ -1 +1 @@
-pytest-postgresql
+
diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py
--- a/swh/core/db/db_utils.py
+++ b/swh/core/db/db_utils.py
@@ -1,15 +1,21 @@
-# Copyright (C) 2015-2019 The Software Heritage developers
+# Copyright (C) 2015-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import functools
+import glob
 import logging
 import re
+import subprocess
 from typing import Optional, Union
 
 import psycopg2
 import psycopg2.extensions
+import psycopg2.sql
+from pytest_postgresql.janitor import DatabaseJanitor, Version
+
+from swh.core.utils import numfile_sortkey as sortkey
 
 logger = logging.getLogger(__name__)
 
@@ -250,3 +256,156 @@
         parts[-1:] = post
         cur.execute(b"".join(parts))
         yield from cur
+
+
+class SWHDatabaseJanitor(DatabaseJanitor):
+    """SWH database janitor with a custom setup/teardown policy.
+
+    Instead of dropping, creating and initializing the database for each test (as
+    the stock DatabaseJanitor does), it creates and initializes the database only
+    once, then resets it between tests: the tables are truncated and the sequences
+    restarted (see :meth:`db_reset`).
+
+    This is needed to get acceptable test performance.
+
+    """
+
+    def __init__(
+        self,
+        user: str,
+        host: str,
+        port: str,
+        db_name: str,
+        version: Union[str, float, Version],
+        dump_files: str = "*.sql",
+    ) -> None:
+        """Configure the janitor.
+
+        Args:
+            user, host, port, db_name, version: forwarded to ``DatabaseJanitor``
+            dump_files: glob pattern of the SQL files loaded into the database when
+              it gets created; the matching files are sorted with
+              ``swh.core.utils.numfile_sortkey``. A relative pattern is resolved
+              against the current working directory.
+
+        """
+        super().__init__(user, host, port, db_name, version)
+        self.dump_files = sorted(glob.glob(dump_files), key=sortkey)
+
+    def db_setup(self):
+        """Load the dump files, in order, into the database using ``psql``.
+
+        Raises:
+            subprocess.CalledProcessError: if ``psql`` fails on one of the files
+              (``ON_ERROR_STOP=1`` turns the first SQL error into a failure)
+
+        """
+        # make_dsn quotes the values when needed (e.g. a socket directory
+        # containing spaces), unlike a plain f-string
+        conninfo = psycopg2.extensions.make_dsn(
+            host=self.host, user=self.user, port=self.port, dbname=self.db_name
+        )
+
+        for fname in self.dump_files:
+            subprocess.check_call(
+                [
+                    "psql",
+                    "--quiet",
+                    "--no-psqlrc",
+                    "-v",
+                    "ON_ERROR_STOP=1",
+                    "-d",
+                    conninfo,
+                    "-f",
+                    fname,
+                ]
+            )
+
+    def db_reset(self):
+        """Empty the database without dropping it.
+
+        All base tables of the ``public`` schema except ``dbversion`` are truncated
+        (with ``CASCADE``) and all sequences of that schema are restarted, in a
+        single committed transaction.
+
+        """
+        cnx = psycopg2.connect(
+            dbname=self.db_name, user=self.user, host=self.host, port=self.port,
+        )
+        try:
+            # "with cnx" commits (or rolls back on error) but does not close the
+            # connection: the explicit close() below avoids leaking one per reset
+            with cnx, cnx.cursor() as cur:
+                # views are excluded: TRUNCATE only accepts tables
+                cur.execute(
+                    "SELECT table_name FROM information_schema.tables "
+                    "WHERE table_schema = %s AND table_type = %s",
+                    ("public", "BASE TABLE"),
+                )
+                tables = {table for (table,) in cur.fetchall()} - {"dbversion"}
+                if tables:
+                    # a single statement for all the tables; names are quoted
+                    # as identifiers so that e.g. mixed-case names work too
+                    cur.execute(
+                        psycopg2.sql.SQL("TRUNCATE TABLE {} CASCADE").format(
+                            psycopg2.sql.SQL(", ").join(
+                                [psycopg2.sql.Identifier(t) for t in sorted(tables)]
+                            )
+                        )
+                    )
+
+                cur.execute(
+                    "SELECT sequence_name FROM information_schema.sequences "
+                    "WHERE sequence_schema = %s",
+                    ("public",),
+                )
+                for (seq,) in cur.fetchall():
+                    cur.execute(
+                        psycopg2.sql.SQL("ALTER SEQUENCE {} RESTART").format(
+                            psycopg2.sql.Identifier(seq)
+                        )
+                    )
+        finally:
+            cnx.close()
+
+    def init(self):
+        """Create and initialize the database, or reset it if it already exists.
+
+        On first use the database is created and the dump files are loaded into it
+        (:meth:`db_setup`); afterwards it is only reset (:meth:`db_reset`).
+
+        """
+        with self.cursor() as cur:
+            cur.execute(
+                "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,)
+            )
+            db_exists = cur.fetchone()[0] == 1
+            if db_exists:
+                # re-allow connections in case they were disabled (the stock
+                # DatabaseJanitor.drop() does so before dropping the database)
+                cur.execute(
+                    "UPDATE pg_database SET datallowconn=true WHERE datname = %s;",
+                    (self.db_name,),
+                )
+                self.db_reset()
+                return
+
+        # the database does not exist yet: create it, then load the dump files
+        with self.cursor() as cur:
+            cur.execute(
+                psycopg2.sql.SQL("CREATE DATABASE {};").format(
+                    psycopg2.sql.Identifier(self.db_name)
+                )
+            )
+        self.db_setup()
+
+    def drop(self):
+        """Do nothing: the database is deliberately kept between tests.
+
+        The original DatabaseJanitor implementation prevents new connections from
+        happening, destroys currently opened connections and finally drops the
+        database. We do not want to drop the db (for performance reasons): its
+        tables get truncated instead, the next time :meth:`init` is called.
+
+        """
+        pass
diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py
new file mode 100644
--- /dev/null
+++ b/swh/core/db/pytest_plugin.py
@@ -0,0 +1,76 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import logging
+
+import psycopg2
+import pytest
+from pytest_postgresql import factories
+
+from swh.core.db.db_utils import SWHDatabaseJanitor
+
+logger = logging.getLogger(__name__)
+
+
+# the postgresql_fact factory fixture below is mostly a copy of the code
+# from pytest-postgresql. We need a custom version here to be able to
+# specify our version of the database janitor (SWHDatabaseJanitor).
+def postgresql_fact(process_fixture_name, db_name=None, dump_files=""):
+    """Build a fixture connecting to a database managed by SWHDatabaseJanitor.
+
+    The database is created and initialized from ``dump_files`` only if it does not
+    exist yet; otherwise it is reset instead of being dropped and re-created.
+
+    Args:
+        process_fixture_name: name of the pytest-postgresql process fixture
+          providing the server (e.g. ``"postgresql_proc"``)
+        db_name: database name; defaults to the ``dbname`` value of the
+          pytest-postgresql configuration
+        dump_files: glob pattern of the SQL files initializing the database
+
+    Returns:
+        a function-scoped fixture yielding a psycopg2 connection to the database
+
+    """
+
+    @pytest.fixture
+    def postgresql_factory(request):
+        """
+        Fixture factory for PostgreSQL.
+
+        :param FixtureRequest request: fixture request object
+        :rtype: psycopg2.connection
+        :returns: postgresql client
+        """
+        config = factories.get_config(request)
+        proc_fixture = request.getfixturevalue(process_fixture_name)
+
+        pg_host = proc_fixture.host
+        pg_port = proc_fixture.port
+        pg_user = proc_fixture.user
+        pg_options = proc_fixture.options
+        pg_db = db_name or config["dbname"]
+        with SWHDatabaseJanitor(
+            pg_user,
+            pg_host,
+            pg_port,
+            pg_db,
+            proc_fixture.version,
+            dump_files=dump_files,
+        ):
+            connection = psycopg2.connect(
+                dbname=pg_db,
+                user=pg_user,
+                host=pg_host,
+                port=pg_port,
+                options=pg_options,
+            )
+            try:
+                yield connection
+            finally:
+                # close the connection even if the generator is closed early
+                connection.close()
+
+    return postgresql_factory
diff --git a/swh/core/db/tests/data/0-schema.sql b/swh/core/db/tests/data/0-schema.sql
new file mode 100644
--- /dev/null
+++ b/swh/core/db/tests/data/0-schema.sql
@@ -0,0 +1,16 @@
+-- schema version table which won't get truncated
+create table dbversion
+(
+    version int primary key,
+    release timestamptz,
+    description text
+);
+
+-- a fun table which will get truncated for each test
+create table fun
+(
+    time timestamptz not null
+);
+
+-- one sequence to check for reset as well
+create sequence serial;
diff --git a/swh/core/db/tests/data/1-data.sql b/swh/core/db/tests/data/1-data.sql
new file mode 100644
--- /dev/null
+++ b/swh/core/db/tests/data/1-data.sql
@@ -0,0 +1,12 @@
+-- insert some values in dbversion
+insert into dbversion(version, release, description) values (1, '2016-02-22 15:56:28.358587+00', 'Work In Progress');
+insert into dbversion(version, release, description) values (2, '2016-02-24 18:05:54.887217+00', 'Work In Progress');
+insert into dbversion(version, release, description) values (3, '2016-10-21 14:10:18.629763+00', 'Work In Progress');
+insert into dbversion(version, release, description) values (4, '2017-08-08 19:01:11.723113+00', 'Work In Progress');
+insert into dbversion(version, release, description) values (7, '2018-03-30 12:58:39.256679+00', 'Work In Progress');
+
+insert into fun(time) values ('2020-10-19 09:00:00.666999+00');
+insert into fun(time) values ('2020-10-18 09:00:00.666999+00');
+insert into fun(time) values ('2020-10-17 09:00:00.666999+00');
+
+select nextval('serial');
diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/test_db_utils.py
new file mode 100644
--- /dev/null
+++ b/swh/core/db/tests/test_db_utils.py
@@ -0,0 +1,75 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import os
+
+import pytest
+
+from swh.core.db import BaseDb
+from swh.core.db.pytest_plugin import postgresql_fact
+
+SQL_DIR = os.path.join(os.path.dirname(__file__), "data")
+
+
+postgres_fun = postgresql_fact(
+    "postgresql_proc", db_name="fun", dump_files=f"{SQL_DIR}/*.sql"
+)
+
+
+@pytest.fixture()
+def db_conn(postgres_fun):
+    """BaseDb connected to the "fun" test database, closed at teardown."""
+    db = BaseDb.connect(postgres_fun.dsn)
+    yield db
+    # close explicitly rather than relying on garbage collection: the tests never
+    # commit, so an open connection may still hold locks that would block the
+    # TRUNCATE run by the next test's database reset
+    db.conn.close()
+
+
+def test_smoke_test_db_is_up(db_conn):
+    """This ensures the db is created and configured according to its dump files.
+
+    """
+    with db_conn.cursor() as cur:
+        cur.execute("select count(*) from dbversion")
+        nb_rows = cur.fetchone()[0]
+        assert nb_rows == 5
+
+        cur.execute("select count(*) from fun")
+        nb_rows = cur.fetchone()[0]
+        assert nb_rows == 3
+
+        # 1-data.sql already called nextval('serial') once, so the next value is 2
+        cur.execute("select nextval('serial')")
+        val = cur.fetchone()[0]
+        assert val == 2
+
+
+def test_smoke_test_db_is_still_up_and_got_reset(db_conn):
+    """This ensures that in another test, the db is still up and created (but not
+    configured again). This time, most of the data has been reset:
+    - table dbversion was left as is
+    - the other tables from the schema (here only "fun") got truncated
+    - the sequences got restarted
+
+    Note: this only passes if the database was already created earlier in the
+    session (here by test_smoke_test_db_is_up, which runs first); run on its own,
+    it would find the freshly loaded data instead.
+    """
+    with db_conn.cursor() as cur:
+        # db version is excluded from the reset
+        cur.execute("select count(*) from dbversion")
+        nb_rows = cur.fetchone()[0]
+        assert nb_rows == 5
+
+        # table and sequence are reset
+        cur.execute("select count(*) from fun")
+        nb_rows = cur.fetchone()[0]
+        assert nb_rows == 0
+
+        cur.execute("select nextval('serial')")
+        val = cur.fetchone()[0]
+        assert val == 1