Changeset View
Changeset View
Standalone View
Standalone View
swh/core/db/pytest_plugin.py
# Copyright (C) 2020 The Software Heritage developers | # Copyright (C) 2020 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
import glob | import glob | ||||
import logging | import logging | ||||
import subprocess | import subprocess | ||||
from typing import List, Optional, Set, Union | from typing import List, Optional, Set, Union | ||||
from _pytest.fixtures import FixtureRequest | from _pytest.fixtures import FixtureRequest | ||||
import psycopg2 | import psycopg2 | ||||
import pytest | import pytest | ||||
from pytest_postgresql import factories | |||||
from pytest_postgresql.janitor import DatabaseJanitor, Version | from pytest_postgresql.janitor import DatabaseJanitor, Version | ||||
try: | |||||
from pytest_postgresql.config import get_config as pytest_postgresql_get_config | |||||
except ImportError: | |||||
# pytest_postgresql < 3.0.0 | |||||
from pytest_postgresql.factories import get_config as pytest_postgresql_get_config | |||||
from swh.core.utils import numfile_sortkey as sortkey | from swh.core.utils import numfile_sortkey as sortkey | ||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
class SWHDatabaseJanitor(DatabaseJanitor): | class SWHDatabaseJanitor(DatabaseJanitor): | ||||
"""SWH database janitor implementation with a a different setup/teardown policy than | """SWH database janitor implementation with a a different setup/teardown policy than | ||||
than the stock one. Instead of dropping, creating and initializing the database for | than the stock one. Instead of dropping, creating and initializing the database for | ||||
each test, it creates and initializes the db once, then truncates the tables (and | each test, it creates and initializes the db once, then truncates the tables (and | ||||
sequences) in between tests. | sequences) in between tests. | ||||
This is needed to have acceptable test performances. | This is needed to have acceptable test performances. | ||||
""" | """ | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
user: str, | user: str, | ||||
host: str, | host: str, | ||||
port: str, | port: str, | ||||
db_name: str, | dbname: str, | ||||
version: Union[str, float, Version], | version: Union[str, float, Version], | ||||
dump_files: Union[None, str, List[str]] = None, | dump_files: Union[None, str, List[str]] = None, | ||||
no_truncate_tables: Set[str] = set(), | no_truncate_tables: Set[str] = set(), | ||||
) -> None: | ) -> None: | ||||
super().__init__(user, host, port, db_name, version) | super().__init__(user, host, port, dbname, version) | ||||
if not hasattr(self, "dbname"): | |||||
# pytest_postgresql < 3.0.0 | |||||
self.dbname = self.db_name | |||||
if dump_files is None: | if dump_files is None: | ||||
self.dump_files = [] | self.dump_files = [] | ||||
elif isinstance(dump_files, str): | elif isinstance(dump_files, str): | ||||
self.dump_files = sorted(glob.glob(dump_files), key=sortkey) | self.dump_files = sorted(glob.glob(dump_files), key=sortkey) | ||||
else: | else: | ||||
self.dump_files = dump_files | self.dump_files = dump_files | ||||
# do no truncate the following tables | # do no truncate the following tables | ||||
self.no_truncate_tables = set(no_truncate_tables) | self.no_truncate_tables = set(no_truncate_tables) | ||||
def db_setup(self): | def db_setup(self): | ||||
conninfo = ( | conninfo = ( | ||||
f"host={self.host} user={self.user} port={self.port} dbname={self.db_name}" | f"host={self.host} user={self.user} port={self.port} dbname={self.dbname}" | ||||
) | ) | ||||
for fname in self.dump_files: | for fname in self.dump_files: | ||||
subprocess.check_call( | subprocess.check_call( | ||||
[ | [ | ||||
"psql", | "psql", | ||||
"--quiet", | "--quiet", | ||||
"--no-psqlrc", | "--no-psqlrc", | ||||
"-v", | "-v", | ||||
"ON_ERROR_STOP=1", | "ON_ERROR_STOP=1", | ||||
"-d", | "-d", | ||||
conninfo, | conninfo, | ||||
"-f", | "-f", | ||||
fname, | fname, | ||||
] | ] | ||||
) | ) | ||||
def db_reset(self): | def db_reset(self): | ||||
"""Truncate tables (all but self.no_truncate_tables set) and sequences | """Truncate tables (all but self.no_truncate_tables set) and sequences | ||||
""" | """ | ||||
with psycopg2.connect( | with psycopg2.connect( | ||||
dbname=self.db_name, user=self.user, host=self.host, port=self.port, | dbname=self.dbname, user=self.user, host=self.host, port=self.port, | ||||
) as cnx: | ) as cnx: | ||||
with cnx.cursor() as cur: | with cnx.cursor() as cur: | ||||
cur.execute( | cur.execute( | ||||
"SELECT table_name FROM information_schema.tables " | "SELECT table_name FROM information_schema.tables " | ||||
"WHERE table_schema = %s", | "WHERE table_schema = %s", | ||||
("public",), | ("public",), | ||||
) | ) | ||||
all_tables = set(table for (table,) in cur.fetchall()) | all_tables = set(table for (table,) in cur.fetchall()) | ||||
Show All 11 Lines | def db_reset(self): | ||||
for seq in seqs: | for seq in seqs: | ||||
cur.execute("ALTER SEQUENCE %s RESTART;" % seq) | cur.execute("ALTER SEQUENCE %s RESTART;" % seq) | ||||
cnx.commit() | cnx.commit() | ||||
def init(self): | def init(self): | ||||
"""Initialize db. Create the db if it does not exist. Reset it if it exists.""" | """Initialize db. Create the db if it does not exist. Reset it if it exists.""" | ||||
with self.cursor() as cur: | with self.cursor() as cur: | ||||
cur.execute( | cur.execute( | ||||
"SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,) | "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.dbname,) | ||||
) | ) | ||||
db_exists = cur.fetchone()[0] == 1 | db_exists = cur.fetchone()[0] == 1 | ||||
if db_exists: | if db_exists: | ||||
cur.execute( | cur.execute( | ||||
"UPDATE pg_database SET datallowconn=true WHERE datname = %s;", | "UPDATE pg_database SET datallowconn=true WHERE datname = %s;", | ||||
(self.db_name,), | (self.dbname,), | ||||
) | ) | ||||
self.db_reset() | self.db_reset() | ||||
return | return | ||||
# initialize the inexistent db | # initialize the inexistent db | ||||
with self.cursor() as cur: | with self.cursor() as cur: | ||||
cur.execute('CREATE DATABASE "{}";'.format(self.db_name)) | cur.execute('CREATE DATABASE "{}";'.format(self.dbname)) | ||||
self.db_setup() | self.db_setup() | ||||
def drop(self): | def drop(self): | ||||
"""The original DatabaseJanitor implementation prevents new connections from happening, | """The original DatabaseJanitor implementation prevents new connections from happening, | ||||
destroys current opened connections and finally drops the database. | destroys current opened connections and finally drops the database. | ||||
We actually do not want to drop the db so we instead do nothing and resets | We actually do not want to drop the db so we instead do nothing and resets | ||||
(truncate most tables and sequences) the db instead, in order to have some | (truncate most tables and sequences) the db instead, in order to have some | ||||
acceptable performance. | acceptable performance. | ||||
""" | """ | ||||
pass | pass | ||||
# the postgres_fact factory fixture below is mostly a copy of the code | # the postgres_fact factory fixture below is mostly a copy of the code | ||||
# from pytest-postgresql. We need a custom version here to be able to | # from pytest-postgresql. We need a custom version here to be able to | ||||
# specify our version of the DBJanitor we use. | # specify our version of the DBJanitor we use. | ||||
def postgresql_fact( | def postgresql_fact( | ||||
process_fixture_name: str, | process_fixture_name: str, | ||||
db_name: Optional[str] = None, | dbname: Optional[str] = None, | ||||
dump_files: Union[str, List[str]] = "", | dump_files: Union[str, List[str]] = "", | ||||
no_truncate_tables: Set[str] = {"dbversion"}, | no_truncate_tables: Set[str] = {"dbversion"}, | ||||
): | ): | ||||
@pytest.fixture | @pytest.fixture | ||||
def postgresql_factory(request: FixtureRequest): | def postgresql_factory(request: FixtureRequest): | ||||
"""Fixture factory for PostgreSQL. | """Fixture factory for PostgreSQL. | ||||
:param FixtureRequest request: fixture request object | :param FixtureRequest request: fixture request object | ||||
:rtype: psycopg2.connection | :rtype: psycopg2.connection | ||||
:returns: postgresql client | :returns: postgresql client | ||||
""" | """ | ||||
config = factories.get_config(request) | config = pytest_postgresql_get_config(request) | ||||
proc_fixture = request.getfixturevalue(process_fixture_name) | proc_fixture = request.getfixturevalue(process_fixture_name) | ||||
pg_host = proc_fixture.host | pg_host = proc_fixture.host | ||||
pg_port = proc_fixture.port | pg_port = proc_fixture.port | ||||
pg_user = proc_fixture.user | pg_user = proc_fixture.user | ||||
pg_options = proc_fixture.options | pg_options = proc_fixture.options | ||||
pg_db = db_name or config["dbname"] | pg_db = dbname or config["dbname"] | ||||
with SWHDatabaseJanitor( | with SWHDatabaseJanitor( | ||||
pg_user, | pg_user, | ||||
pg_host, | pg_host, | ||||
pg_port, | pg_port, | ||||
pg_db, | pg_db, | ||||
proc_fixture.version, | proc_fixture.version, | ||||
dump_files=dump_files, | dump_files=dump_files, | ||||
no_truncate_tables=no_truncate_tables, | no_truncate_tables=no_truncate_tables, | ||||
Show All 12 Lines |