diff --git a/requirements-db-pytestplugin.txt b/requirements-db-pytestplugin.txt
index 9f25d5b..54e57b4 100644
--- a/requirements-db-pytestplugin.txt
+++ b/requirements-db-pytestplugin.txt
@@ -1,2 +1,2 @@
 # requirements for swh.core.db.pytest_plugin
-pytest-postgresql < 3.0.0
+pytest-postgresql
diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py
index c90b13a..3958772 100755
--- a/swh/core/cli/db.py
+++ b/swh/core/cli/db.py
@@ -1,332 +1,335 @@
 #!/usr/bin/env python3
 # Copyright (C) 2018-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import logging
 from os import environ, path
 from typing import Collection, Dict, Optional, Tuple
 import warnings

 import click

 from swh.core.cli import CONTEXT_SETTINGS
 from swh.core.cli import swh as swh_cli_group

 warnings.filterwarnings("ignore")  # noqa prevent psycopg from telling us sh*t

 logger = logging.getLogger(__name__)


 @swh_cli_group.group(name="db", context_settings=CONTEXT_SETTINGS)
 @click.option(
     "--config-file",
     "-C",
     default=None,
     type=click.Path(exists=True, dir_okay=False),
     help="Configuration file.",
 )
 @click.pass_context
 def db(ctx, config_file):
     """Software Heritage database generic tools."""
     from swh.core.config import read as config_read

     ctx.ensure_object(dict)
     if config_file is None:
         config_file = environ.get("SWH_CONFIG_FILENAME")
     cfg = config_read(config_file)
     ctx.obj["config"] = cfg


 @db.command(name="create", context_settings=CONTEXT_SETTINGS)
 @click.argument("module", required=True)
 @click.option(
+    "--dbname",
     "--db-name",
     "-d",
     help="Database name.",
     default="softwareheritage-dev",
     show_default=True,
 )
 @click.option(
     "--template",
     "-T",
     help="Template database from which to build this database.",
     default="template1",
     show_default=True,
 )
-def db_create(module, db_name, template):
+def db_create(module, dbname, template):
     """Create a database for the Software Heritage <module>, and potentially
     execute superuser-level initialization steps.

     Example::

         swh db create -d swh-test storage

     If you want to specify non-default postgresql connection parameters,
     please provide them using standard environment variables or by means of
     a properly crafted libpq connection URI. See psql(1) man page (section
     ENVIRONMENTS) for details.

     Note: this command requires a postgresql connection with superuser
     permissions.

     Examples::

         PGPORT=5434 swh db create indexer
         swh db create -d postgresql://superuser:passwd@pghost:5433/swh-storage storage

     """
-    logger.debug("db_create %s dn_name=%s", module, db_name)
-    create_database_for_package(module, db_name, template)
+    logger.debug("db_create %s dbname=%s", module, dbname)
+    create_database_for_package(module, dbname, template)
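Click derives a parameter's Python name from the first long option listed, so stacking `"--dbname"` ahead of `"--db-name"` renames the function argument while the old flag keeps working as an alias. A minimal standalone sketch of that pattern (the `connect` command below is hypothetical, not part of this patch):

```python
import click


@click.command()
@click.option("--dbname", "--db-name", "-d", default="softwareheritage-dev")
def connect(dbname):
    # The argument is named after the first long option ("--dbname"),
    # but "--db-name" and "-d" parse into the same value.
    click.echo(f"connecting to {dbname}")


if __name__ == "__main__":
    connect()  # e.g. `python connect.py --db-name swh-test`
```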
 @db.command(name="init-admin", context_settings=CONTEXT_SETTINGS)
 @click.argument("module", required=True)
 @click.option(
+    "--dbname",
     "--db-name",
     "-d",
     help="Database name.",
     default="softwareheritage-dev",
     show_default=True,
 )
-def db_init_admin(module: str, db_name: str) -> None:
+def db_init_admin(module: str, dbname: str) -> None:
     """Execute superuser-level initialization steps (e.g pg extensions, admin
     functions, ...)

     Example::

         PGPASSWORD=... swh db init-admin -d swh-test scheduler

     If you want to specify non-default postgresql connection parameters,
     please provide them using standard environment variables or by means of
     a properly crafted libpq connection URI. See psql(1) man page (section
     ENVIRONMENTS) for details.

     Note: this command requires a postgresql connection with superuser
     permissions (e.g postgres, swh-admin, ...)

     Examples::

         PGPORT=5434 swh db init-admin scheduler
         swh db init-admin -d postgresql://superuser:passwd@pghost:5433/swh-scheduler \
             scheduler

     """
-    logger.debug("db_init_admin %s db_name=%s", module, db_name)
-    init_admin_extensions(module, db_name)
+    logger.debug("db_init_admin %s dbname=%s", module, dbname)
+    init_admin_extensions(module, dbname)


 @db.command(name="init", context_settings=CONTEXT_SETTINGS)
 @click.argument("module", required=True)
 @click.option(
+    "--dbname",
     "--db-name",
     "-d",
     help="Database name.",
     default="softwareheritage-dev",
     show_default=True,
 )
 @click.option(
     "--flavor", help="Database flavor.", default=None,
 )
-def db_init(module, db_name, flavor):
+def db_init(module, dbname, flavor):
     """Initialize a database for the Software Heritage <module>.

     Example::

         swh db init -d swh-test storage

     If you want to specify non-default postgresql connection parameters,
     please provide them using standard environment variables.
     See psql(1) man page (section ENVIRONMENTS) for details.

     Examples::

         PGPORT=5434 swh db init indexer
         swh db init -d postgresql://user:passwd@pghost:5433/swh-storage storage
         swh db init --flavor read_replica -d swh-storage storage

     """
-    logger.debug("db_init %s flavor=%s dn_name=%s", module, flavor, db_name)
+    logger.debug("db_init %s flavor=%s dbname=%s", module, flavor, dbname)

     initialized, dbversion, dbflavor = populate_database_for_package(
-        module, db_name, flavor
+        module, dbname, flavor
     )

     # TODO: Ideally migrate the version from db_version to the latest
     # db version

     click.secho(
         "DONE database for {} {}{} at version {}".format(
             module,
             "initialized" if initialized else "exists",
             f" (flavor {dbflavor})" if dbflavor is not None else "",
             dbversion,
         ),
         fg="green",
         bold=True,
     )

     if flavor is not None and dbflavor != flavor:
         click.secho(
             f"WARNING requested flavor '{flavor}' != recorded flavor '{dbflavor}'",
             fg="red",
             bold=True,
         )
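The three commands are meant to be run in sequence: `create` (superuser), `init-admin` (superuser), then `init`. A sketch of driving that sequence in-process with Click's test runner, as the CLI tests further down do; the DSN is an assumed example and presumes a reachable PostgreSQL with sufficient privileges:

```python
from click.testing import CliRunner

from swh.core.cli.db import db as swhdb

runner = CliRunner()
conninfo = "postgresql://postgres@localhost:5432/swh-test"  # assumed DSN
for step in ("create", "init-admin", "init"):
    result = runner.invoke(swhdb, [step, "storage", "--dbname", conninfo])
    assert result.exit_code == 0, result.output
```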
""" from swh.core.db.db_utils import swh_db_flavor, swh_db_version current_version = swh_db_version(conninfo) if current_version is not None: dbflavor = swh_db_flavor(conninfo) return False, current_version, dbflavor sqlfiles = get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" not in fname] execute_sqlfiles(sqlfiles, conninfo, flavor) current_version = swh_db_version(conninfo) assert current_version is not None dbflavor = swh_db_flavor(conninfo) return True, current_version, dbflavor def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]: """Parse a psycopg2 dsn, falling back to supporting plain database names as well""" import psycopg2 from psycopg2.extensions import parse_dsn as _parse_dsn try: return _parse_dsn(dsn_or_dbname) except psycopg2.ProgrammingError: # psycopg2 failed to parse the DSN; it's probably a database name, # handle it as such return _parse_dsn(f"dbname={dsn_or_dbname}") def init_admin_extensions(modname: str, conninfo: str) -> None: """The remaining initialization process -- running -superuser- SQL files -- is done using the given conninfo, thus connecting to the newly created database """ sqlfiles = get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname] execute_sqlfiles(sqlfiles, conninfo) def create_database_for_package( modname: str, conninfo: str, template: str = "template1" ): """Create the database pointed at with ``conninfo``, and initialize it using -superuser- SQL files found in the package ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string or plain database name for the SQL database template: the name of the database to connect to and use as template to create the new database """ import subprocess from psycopg2.extensions import make_dsn # Use the given conninfo string, but with dbname replaced by the template dbname # for the database creation step creation_dsn = parse_dsn_or_dbname(conninfo) - db_name = creation_dsn["dbname"] + dbname = creation_dsn["dbname"] creation_dsn["dbname"] = template - logger.debug("db_create db_name=%s (from %s)", db_name, template) + logger.debug("db_create dbname=%s (from %s)", dbname, template) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", make_dsn(**creation_dsn), "-c", - f'CREATE DATABASE "{db_name}"', + f'CREATE DATABASE "{dbname}"', ] ) init_admin_extensions(modname, conninfo) def execute_sqlfiles( sqlfiles: Collection[str], conninfo: str, flavor: Optional[str] = None ): """Execute a list of SQL files on the database pointed at with ``conninfo``. 
 def init_admin_extensions(modname: str, conninfo: str) -> None:
     """The remaining initialization process -- running -superuser- SQL files --
     is done using the given conninfo, thus connecting to the newly created
     database

     """
     sqlfiles = get_sql_for_package(modname)
     sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname]
     execute_sqlfiles(sqlfiles, conninfo)


 def create_database_for_package(
     modname: str, conninfo: str, template: str = "template1"
 ):
     """Create the database pointed at with ``conninfo``, and initialize it using
     -superuser- SQL files found in the package ``modname``.

     Args:
         modname: Name of the module of which we're loading the files
         conninfo: connection info string or plain database name for the SQL
             database
         template: the name of the database to connect to and use as template
             to create the new database

     """
     import subprocess

     from psycopg2.extensions import make_dsn

     # Use the given conninfo string, but with dbname replaced by the template
     # dbname for the database creation step
     creation_dsn = parse_dsn_or_dbname(conninfo)
-    db_name = creation_dsn["dbname"]
+    dbname = creation_dsn["dbname"]
     creation_dsn["dbname"] = template
-    logger.debug("db_create db_name=%s (from %s)", db_name, template)
+    logger.debug("db_create dbname=%s (from %s)", dbname, template)

     subprocess.check_call(
         [
             "psql",
             "--quiet",
             "--no-psqlrc",
             "-v",
             "ON_ERROR_STOP=1",
             "-d",
             make_dsn(**creation_dsn),
             "-c",
-            f'CREATE DATABASE "{db_name}"',
+            f'CREATE DATABASE "{dbname}"',
         ]
     )
     init_admin_extensions(modname, conninfo)


 def execute_sqlfiles(
     sqlfiles: Collection[str], conninfo: str, flavor: Optional[str] = None
 ):
     """Execute a list of SQL files on the database pointed at with ``conninfo``.

     Args:
         sqlfiles: List of SQL files to execute
         conninfo: connection info string for the SQL database
         flavor: the database flavor to initialize

     """
     import subprocess

     psql_command = [
         "psql",
         "--quiet",
         "--no-psqlrc",
         "-v",
         "ON_ERROR_STOP=1",
         "-d",
         conninfo,
     ]

     flavor_set = False
     for sqlfile in sqlfiles:
-        logger.debug(f"execute SQL file {sqlfile} db_name={conninfo}")
+        logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}")
         subprocess.check_call(psql_command + ["-f", sqlfile])
         if flavor is not None and not flavor_set and sqlfile.endswith("-flavor.sql"):
             logger.debug("Setting database flavor %s", flavor)
             query = f"insert into dbflavor (flavor) values ('{flavor}')"
             subprocess.check_call(psql_command + ["-c", query])
             flavor_set = True

     if flavor is not None and not flavor_set:
         logger.warn(
             "Asked for flavor %s, but module does not support database flavors",
             flavor,
         )
diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py
index 1d6b1f7..1a77fa8 100644
--- a/swh/core/db/pytest_plugin.py
+++ b/swh/core/db/pytest_plugin.py
@@ -1,177 +1,185 @@
 # Copyright (C) 2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import glob
 import logging
 import subprocess
 from typing import List, Optional, Set, Union

 from _pytest.fixtures import FixtureRequest
 import psycopg2
 import pytest
-from pytest_postgresql import factories
 from pytest_postgresql.janitor import DatabaseJanitor, Version

+try:
+    from pytest_postgresql.config import get_config as pytest_postgresql_get_config
+except ImportError:
+    # pytest_postgresql < 3.0.0
+    from pytest_postgresql.factories import get_config as pytest_postgresql_get_config
+
 from swh.core.utils import numfile_sortkey as sortkey

 logger = logging.getLogger(__name__)
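The try/except keys the compatibility shim on which module provides `get_config`. An equivalent (assumed) alternative would branch on the installed distribution version explicitly, at the cost of depending on `packaging` and Python 3.8+ for `importlib.metadata`:

```python
from importlib.metadata import version

from packaging.version import Version

# Branch on the installed pytest-postgresql version instead of probing imports:
if Version(version("pytest-postgresql")) >= Version("3.0.0"):
    from pytest_postgresql.config import get_config as pytest_postgresql_get_config
else:
    from pytest_postgresql.factories import get_config as pytest_postgresql_get_config
```

Catching `ImportError` keeps the plugin free of those extra requirements.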
""" def __init__( self, user: str, host: str, port: str, - db_name: str, + dbname: str, version: Union[str, float, Version], dump_files: Union[None, str, List[str]] = None, no_truncate_tables: Set[str] = set(), ) -> None: - super().__init__(user, host, port, db_name, version) + super().__init__(user, host, port, dbname, version) + if not hasattr(self, "dbname"): + # pytest_postgresql < 3.0.0 + self.dbname = self.db_name if dump_files is None: self.dump_files = [] elif isinstance(dump_files, str): self.dump_files = sorted(glob.glob(dump_files), key=sortkey) else: self.dump_files = dump_files # do no truncate the following tables self.no_truncate_tables = set(no_truncate_tables) def db_setup(self): conninfo = ( - f"host={self.host} user={self.user} port={self.port} dbname={self.db_name}" + f"host={self.host} user={self.user} port={self.port} dbname={self.dbname}" ) for fname in self.dump_files: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", fname, ] ) def db_reset(self): """Truncate tables (all but self.no_truncate_tables set) and sequences """ with psycopg2.connect( - dbname=self.db_name, user=self.user, host=self.host, port=self.port, + dbname=self.dbname, user=self.user, host=self.host, port=self.port, ) as cnx: with cnx.cursor() as cur: cur.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = %s", ("public",), ) all_tables = set(table for (table,) in cur.fetchall()) tables_to_truncate = all_tables - self.no_truncate_tables for table in tables_to_truncate: cur.execute("TRUNCATE TABLE %s CASCADE" % table) cur.execute( "SELECT sequence_name FROM information_schema.sequences " "WHERE sequence_schema = %s", ("public",), ) seqs = set(seq for (seq,) in cur.fetchall()) for seq in seqs: cur.execute("ALTER SEQUENCE %s RESTART;" % seq) cnx.commit() def init(self): """Initialize db. Create the db if it does not exist. Reset it if it exists.""" with self.cursor() as cur: cur.execute( - "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,) + "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.dbname,) ) db_exists = cur.fetchone()[0] == 1 if db_exists: cur.execute( "UPDATE pg_database SET datallowconn=true WHERE datname = %s;", - (self.db_name,), + (self.dbname,), ) self.db_reset() return # initialize the inexistent db with self.cursor() as cur: - cur.execute('CREATE DATABASE "{}";'.format(self.db_name)) + cur.execute('CREATE DATABASE "{}";'.format(self.dbname)) self.db_setup() def drop(self): """The original DatabaseJanitor implementation prevents new connections from happening, destroys current opened connections and finally drops the database. We actually do not want to drop the db so we instead do nothing and resets (truncate most tables and sequences) the db instead, in order to have some acceptable performance. """ pass # the postgres_fact factory fixture below is mostly a copy of the code # from pytest-postgresql. We need a custom version here to be able to # specify our version of the DBJanitor we use. def postgresql_fact( process_fixture_name: str, - db_name: Optional[str] = None, + dbname: Optional[str] = None, dump_files: Union[str, List[str]] = "", no_truncate_tables: Set[str] = {"dbversion"}, ): @pytest.fixture def postgresql_factory(request: FixtureRequest): """Fixture factory for PostgreSQL. 
 # the postgres_fact factory fixture below is mostly a copy of the code
 # from pytest-postgresql. We need a custom version here to be able to
 # specify our version of the DBJanitor we use.
 def postgresql_fact(
     process_fixture_name: str,
-    db_name: Optional[str] = None,
+    dbname: Optional[str] = None,
     dump_files: Union[str, List[str]] = "",
     no_truncate_tables: Set[str] = {"dbversion"},
 ):
     @pytest.fixture
     def postgresql_factory(request: FixtureRequest):
         """Fixture factory for PostgreSQL.

         :param FixtureRequest request: fixture request object
         :rtype: psycopg2.connection
         :returns: postgresql client
         """
-        config = factories.get_config(request)
+        config = pytest_postgresql_get_config(request)
         proc_fixture = request.getfixturevalue(process_fixture_name)

         pg_host = proc_fixture.host
         pg_port = proc_fixture.port
         pg_user = proc_fixture.user
         pg_options = proc_fixture.options
-        pg_db = db_name or config["dbname"]
+        pg_db = dbname or config["dbname"]

         with SWHDatabaseJanitor(
             pg_user,
             pg_host,
             pg_port,
             pg_db,
             proc_fixture.version,
             dump_files=dump_files,
             no_truncate_tables=no_truncate_tables,
         ):
             connection = psycopg2.connect(
                 dbname=pg_db,
                 user=pg_user,
                 host=pg_host,
                 port=pg_port,
                 options=pg_options,
             )
             yield connection
             connection.close()

     return postgresql_factory
diff --git a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py
index 87990a2..84b0039 100644
--- a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py
+++ b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py
@@ -1,173 +1,173 @@
 # Copyright (C) 2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import glob
 import os

 from swh.core.db import BaseDb
 from swh.core.db.pytest_plugin import postgresql_fact

 SQL_DIR = os.path.join(os.path.dirname(__file__), "data")

 # db with special policy for tables dbversion and people
 postgres_fun = postgresql_fact(
     "postgresql_proc",
-    db_name="fun",
+    dbname="fun",
     dump_files=f"{SQL_DIR}/*.sql",
     no_truncate_tables={"dbversion", "people"},
 )

 postgres_fun2 = postgresql_fact(
     "postgresql_proc",
-    db_name="fun2",
+    dbname="fun2",
     dump_files=sorted(glob.glob(f"{SQL_DIR}/*.sql")),
     no_truncate_tables={"dbversion", "people"},
 )


 def test_smoke_test_fun_db_is_up(postgres_fun):
     """This ensures the db is created and configured according to its dump
     files.

     """
     with BaseDb.connect(postgres_fun.dsn).cursor() as cur:
         cur.execute("select count(*) from dbversion")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 5

         cur.execute("select count(*) from fun")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 3

         cur.execute("select count(*) from people")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 2

         # in data, we requested a value already so it starts at 2
         cur.execute("select nextval('serial')")
         val = cur.fetchone()[0]
         assert val == 2


 def test_smoke_test_fun2_db_is_up(postgres_fun2):
     """This ensures the db is created and configured according to its dump
     files.

     """
     with BaseDb.connect(postgres_fun2.dsn).cursor() as cur:
         cur.execute("select count(*) from dbversion")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 5

         cur.execute("select count(*) from fun")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 3

         cur.execute("select count(*) from people")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 2

         # in data, we requested a value already so it starts at 2
         cur.execute("select nextval('serial')")
         val = cur.fetchone()[0]
         assert val == 2


 def test_smoke_test_fun_db_is_still_up_and_got_reset(postgres_fun):
     """This ensures that within another test, the 'fun' db is still up, created
     (and not configured again).

     This time, most of the data has been reset:
     - except for tables 'dbversion' and 'people' which were left as is
     - the other tables from the schema (here only "fun") got truncated
     - the sequences got truncated as well

     """
     with BaseDb.connect(postgres_fun.dsn).cursor() as cur:
         # db version is excluded from the truncate
         cur.execute("select count(*) from dbversion")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 5

         # people is also allowed not to be truncated
         cur.execute("select count(*) from people")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 2

         # table and sequence are reset
         cur.execute("select count(*) from fun")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 0

         cur.execute("select nextval('serial')")
         val = cur.fetchone()[0]
         assert val == 1


 # db with no special policy for tables truncation, all tables are reset
 postgres_people = postgresql_fact(
     "postgresql_proc",
-    db_name="people",
+    dbname="people",
     dump_files=f"{SQL_DIR}/*.sql",
     no_truncate_tables=set(),
 )


 def test_smoke_test_people_db_up(postgres_people):
     """'people' db is up and configured

     """
     with BaseDb.connect(postgres_people.dsn).cursor() as cur:
         cur.execute("select count(*) from dbversion")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 5

         cur.execute("select count(*) from people")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 2

         cur.execute("select count(*) from fun")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 3

         cur.execute("select nextval('serial')")
         val = cur.fetchone()[0]
         assert val == 2


 def test_smoke_test_people_db_up_and_reset(postgres_people):
     """'people' db is up and got reset on all tables and sequences

     """
     with BaseDb.connect(postgres_people.dsn).cursor() as cur:
         # tables are truncated after the first round
         cur.execute("select count(*) from dbversion")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 0

         # tables are truncated after the first round
         cur.execute("select count(*) from people")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 0

         # table and sequence are reset
         cur.execute("select count(*) from fun")
         nb_rows = cur.fetchone()[0]
         assert nb_rows == 0

         cur.execute("select nextval('serial')")
         val = cur.fetchone()[0]
         assert val == 1


 # db with no initialization step, an empty db
-postgres_no_init = postgresql_fact("postgresql_proc", db_name="something")
+postgres_no_init = postgresql_fact("postgresql_proc", dbname="something")


 def test_smoke_test_db_no_init(postgres_no_init):
     """We can connect to the db nonetheless

     """
     with BaseDb.connect(postgres_no_init.dsn).cursor() as cur:
         cur.execute("select now()")
         data = cur.fetchone()[0]
         assert data is not None
diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py
index d817295..b3fc568 100644
--- a/swh/core/db/tests/test_cli.py
+++ b/swh/core/db/tests/test_cli.py
@@ -1,242 +1,240 @@
 # Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import copy
 import glob
 from os import path

 from click.testing import CliRunner
 import pytest

 from swh.core.cli.db import db as swhdb
 from swh.core.db import BaseDb
 from swh.core.db.pytest_plugin import postgresql_fact
 from swh.core.tests.test_cli import assert_section_contains


 @pytest.fixture
 def cli_runner():
     return CliRunner()


 def test_cli_swh_help(swhmain, cli_runner):
     swhmain.add_command(swhdb)
     result = cli_runner.invoke(swhmain, ["-h"])
     assert result.exit_code == 0
     assert_section_contains(
         result.output, "Commands", "db  Software Heritage database generic tools."
     )


 help_db_snippets = (
     (
         "Usage",
         (
             "Usage: swh db [OPTIONS] COMMAND [ARGS]...",
             "Software Heritage database generic tools.",
         ),
     ),
     (
         "Commands",
         (
             "create      Create a database for the Software Heritage <module>.",
             "init        Initialize a database for the Software Heritage <module>.",
             "init-admin  Execute superuser-level initialization steps",
         ),
     ),
 )


 def test_cli_swh_db_help(swhmain, cli_runner):
     swhmain.add_command(swhdb)
     result = cli_runner.invoke(swhmain, ["db", "-h"])
     assert result.exit_code == 0
     for section, snippets in help_db_snippets:
         for snippet in snippets:
             assert_section_contains(result.output, section, snippet)


 @pytest.fixture()
 def mock_package_sql(mocker, datadir):
     """This bypasses the module manipulation to only return the test data
     files.

     """
     from swh.core.utils import numfile_sortkey as sortkey

     mock_sql_files = mocker.patch("swh.core.cli.db.get_sql_for_package")
     sql_files = sorted(glob.glob(path.join(datadir, "cli", "*.sql")), key=sortkey)
     mock_sql_files.return_value = sql_files
     return mock_sql_files


 # We do not want the truncate behavior for those tests
 test_db = postgresql_fact(
-    "postgresql_proc", db_name="clidb", no_truncate_tables={"dbversion", "origin"}
+    "postgresql_proc", dbname="clidb", no_truncate_tables={"dbversion", "origin"}
 )


 @pytest.fixture
 def swh_db_cli(cli_runner, monkeypatch, test_db):
     """This initializes a cli_runner and sets the correct environment variables
     expected by
-    the cli to run appropriately (when not specifying the --db-name flag)
+    the cli to run appropriately (when not specifying the --dbname flag)

     """
     db_params = test_db.get_dsn_parameters()
     monkeypatch.setenv("PGHOST", db_params["host"])
     monkeypatch.setenv("PGUSER", db_params["user"])
     monkeypatch.setenv("PGPORT", db_params["port"])

     return cli_runner, db_params


 def craft_conninfo(test_db, dbname=None) -> str:
     """Craft conninfo string out of the test_db object. This also allows
     overriding the dbname."""
     db_params = test_db.get_dsn_parameters()
     if dbname:
         params = copy.deepcopy(db_params)
         params["dbname"] = dbname
     else:
         params = db_params
     return "postgresql://{user}@{host}:{port}/{dbname}".format(**params)
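With illustrative DSN parameters, `craft_conninfo` yields something like:

```python
# Assuming test_db.get_dsn_parameters() returns (illustrative values):
#   {"user": "postgres", "host": "127.0.0.1", "port": "11987", "dbname": "clidb"}
craft_conninfo(test_db)            # "postgresql://postgres@127.0.0.1:11987/clidb"
craft_conninfo(test_db, "new-db")  # "postgresql://postgres@127.0.0.1:11987/new-db"
```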
 def test_cli_swh_db_create_and_init_db(cli_runner, test_db, mock_package_sql):
     """Creating a db then initializing it should be ok

     """
     module_name = "something"
     conninfo = craft_conninfo(test_db, "new-db")
     # This creates the db and installs the necessary admin extensions
-    result = cli_runner.invoke(swhdb, ["create", module_name, "--db-name", conninfo])
+    result = cli_runner.invoke(swhdb, ["create", module_name, "--dbname", conninfo])
     assert result.exit_code == 0, f"Unexpected output: {result.output}"

     # This initializes the schema and data
-    result = cli_runner.invoke(swhdb, ["init", module_name, "--db-name", conninfo])
+    result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo])
     assert result.exit_code == 0, f"Unexpected output: {result.output}"

     # the origin value in the scripts uses a hash function (whose implementation
     # uses a function from the pgcrypto extension, installed during the db
     # creation step)
     with BaseDb.connect(conninfo).cursor() as cur:
         cur.execute("select * from origin")
         origins = cur.fetchall()
         assert len(origins) == 1


 def test_cli_swh_db_initialization_fail_without_creation_first(
     cli_runner, test_db, mock_package_sql
 ):
     """Init command on a nonexistent db cannot work

     """
     module_name = "anything"  # it's mocked here
     conninfo = craft_conninfo(test_db, "inexisting-db")

-    result = cli_runner.invoke(swhdb, ["init", module_name, "--db-name", conninfo])
+    result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo])
     # Fails because we cannot connect to a nonexistent db
     assert result.exit_code == 1, f"Unexpected output: {result.output}"
""" module_name = "anything" # it's mocked here conninfo = craft_conninfo(test_db) - result = cli_runner.invoke(swhdb, ["init", module_name, "--db-name", conninfo]) + result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) # Fails as the function `public.digest` is not installed, init-admin calls is needed # first (the next tests show such behavior) assert result.exit_code == 1, f"Unexpected output: {result.output}" def test_cli_swh_db_initialization_works_with_flags( cli_runner, test_db, mock_package_sql ): """Init commands with carefully crafted libpq conninfo works """ module_name = "anything" # it's mocked here conninfo = craft_conninfo(test_db) - result = cli_runner.invoke( - swhdb, ["init-admin", module_name, "--db-name", conninfo] - ) + result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" - result = cli_runner.invoke(swhdb, ["init", module_name, "--db-name", conninfo]) + result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(test_db.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_with_env(swh_db_cli, mock_package_sql, test_db): """Init commands with standard environment variables works """ module_name = "anything" # it's mocked here cli_runner, db_params = swh_db_cli result = cli_runner.invoke( - swhdb, ["init-admin", module_name, "--db-name", db_params["dbname"]] + swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( - swhdb, ["init", module_name, "--db-name", db_params["dbname"]] + swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(test_db.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_idempotent(swh_db_cli, mock_package_sql, test_db): """Multiple runs of the init commands are idempotent """ module_name = "anything" # mocked cli_runner, db_params = swh_db_cli result = cli_runner.invoke( - swhdb, ["init-admin", module_name, "--db-name", db_params["dbname"]] + swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( - swhdb, ["init", module_name, "--db-name", db_params["dbname"]] + swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( - swhdb, ["init-admin", module_name, "--db-name", db_params["dbname"]] + swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( - swhdb, ["init", module_name, "--db-name", db_params["dbname"]] + swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, 
f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(test_db.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py index 2a1585a..3800947 100644 --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -1,417 +1,417 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass import datetime from enum import IntEnum import inspect from string import printable from typing import Any from unittest.mock import MagicMock, Mock import uuid from hypothesis import given, settings, strategies from hypothesis.extra.pytz import timezones import psycopg2 import pytest from typing_extensions import Protocol from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator from swh.core.db.pytest_plugin import postgresql_fact from swh.core.db.tests.conftest import function_scoped_fixture_check # workaround mypy bug https://github.com/python/mypy/issues/5485 class Converter(Protocol): def __call__(self, x: Any) -> Any: ... @dataclass class Field: name: str """Column name""" pg_type: str """Type of the PostgreSQL column""" example: Any """Example value for the static tests""" strategy: strategies.SearchStrategy """Hypothesis strategy to generate these values""" in_wrapper: Converter = lambda x: x """Wrapper to convert this data type for the static tests""" out_converter: Converter = lambda x: x """Converter from the raw PostgreSQL column value to this data type""" # Limit PostgreSQL integer values pg_int = strategies.integers(-2147483648, +2147483647) pg_text = strategies.text( alphabet=strategies.characters( blacklist_categories=["Cs"], # surrogates blacklist_characters=[ "\x00", # pgsql does not support the null codepoint "\r", # pgsql normalizes those ], ), ) pg_bytea = strategies.binary() def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[]""" return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size) def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[][]. The inner lists must all have the same size.""" return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap( lambda n: strategies.lists( pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size ) ) def pg_tstz() -> strategies.SearchStrategy: """Generate values that fit in a PostgreSQL timestamptz. Notes: We're forbidding old datetimes, because until 1956, many timezones had seconds in their "UTC offsets" (see ), which is not representable by PostgreSQL. """ min_value = datetime.datetime(1960, 1, 1, 0, 0, 0) return strategies.datetimes(min_value=min_value, timezones=timezones()) def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate values representable as a PostgreSQL jsonb object (dict).""" return strategies.dictionaries( strategies.text(printable), strategies.recursive( # should use floats() instead of integers(), but PostgreSQL # coerces large integers into floats, making the tests fail. 
 def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy:
     """Generate values representable as a PostgreSQL jsonb object (dict)."""
     return strategies.dictionaries(
         strategies.text(printable),
         strategies.recursive(
             # should use floats() instead of integers(), but PostgreSQL
             # coerces large integers into floats, making the tests fail. We
             # only store ints in our generated data anyway.
             strategies.none()
             | strategies.booleans()
             | strategies.integers(-2147483648, +2147483647)
             | strategies.text(printable),
             lambda children: strategies.lists(children, max_size=max_size)
             | strategies.dictionaries(
                 strategies.text(printable), children, max_size=max_size
             ),
         ),
         min_size=min_size,
         max_size=max_size,
     )


 def tuple_2d_to_list_2d(v):
     """Convert a 2D tuple to a 2D list"""
     return [list(inner) for inner in v]


 def list_2d_to_tuple_2d(v):
     """Convert a 2D list to a 2D tuple"""
     return tuple(tuple(inner) for inner in v)


 class TestIntEnum(IntEnum):
     foo = 1
     bar = 2


 def now():
     return datetime.datetime.now(tz=datetime.timezone.utc)


 FIELDS = (
     Field("i", "int", 1, pg_int),
     Field("txt", "text", "foo", pg_text),
     Field("bytes", "bytea", b"bar", strategies.binary()),
     Field(
         "bytes_array",
         "bytea[]",
         [b"baz1", b"baz2"],
         pg_bytea_a(min_size=0, max_size=5),
     ),
     Field(
         "bytes_tuple",
         "bytea[]",
         (b"baz1", b"baz2"),
         pg_bytea_a(min_size=0, max_size=5).map(tuple),
         in_wrapper=list,
         out_converter=tuple,
     ),
     Field(
         "bytes_2d",
         "bytea[][]",
         [[b"quux1"], [b"quux2"]],
         pg_bytea_a_a(min_size=0, max_size=5),
     ),
     Field(
         "bytes_2d_tuple",
         "bytea[][]",
         ((b"quux1",), (b"quux2",)),
         pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d),
         in_wrapper=tuple_2d_to_list_2d,
         out_converter=list_2d_to_tuple_2d,
     ),
     Field("ts", "timestamptz", now(), pg_tstz(),),
     Field(
         "dict",
         "jsonb",
         {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}},
         pg_jsonb(min_size=0, max_size=5),
         in_wrapper=psycopg2.extras.Json,
     ),
     Field(
         "intenum",
         "int",
         TestIntEnum.foo,
         strategies.sampled_from(TestIntEnum),
         in_wrapper=int,
         out_converter=TestIntEnum,
     ),
     Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()),
     Field(
         "text_list",
         "text[]",
         # All the funky corner cases
         ["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"],
         strategies.lists(pg_text, min_size=0, max_size=5),
     ),
     Field(
         "tstz_list",
         "timestamptz[]",
         [now(), now() + datetime.timedelta(days=1)],
         strategies.lists(pg_tstz(), min_size=0, max_size=5),
     ),
     Field(
         "tstz_range",
         "tstzrange",
         psycopg2.extras.DateTimeTZRange(
             lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)",
         ),
         strategies.tuples(
             # generate two sorted timestamptzs for use as bounds
             strategies.tuples(pg_tstz(), pg_tstz()).map(sorted),
             # and a set of bounds
             strategies.sampled_from(["[]", "()", "[)", "(]"]),
         ).map(
             # and build the actual DateTimeTZRange object from these args
             lambda args: psycopg2.extras.DateTimeTZRange(
                 lower=args[0][0], upper=args[0][1], bounds=args[1],
             )
         ),
     ),
 )

 INIT_SQL = "create table test_table (%s)" % ", ".join(
     f"{field.name} {field.pg_type}" for field in FIELDS
 )

 COLUMNS = tuple(field.name for field in FIELDS)
 INSERT_SQL = "insert into test_table (%s) values (%s)" % (
     ", ".join(COLUMNS),
     ", ".join("%s" for i in range(len(COLUMNS))),
 )

 STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS)
 EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS)

 db_rows = strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS)))


 def convert_lines(cur):
     return [
         tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur
     ]
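For reference, the statements assembled from `FIELDS` expand to something like (abridged):

```python
print(INIT_SQL)
# create table test_table (i int, txt text, bytes bytea, bytes_array bytea[],
#   ..., tstz_list timestamptz[], tstz_range tstzrange)
print(INSERT_SQL)
# insert into test_table (i, txt, bytes, ...) values (%s, %s, %s, ...)
```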
-test_db = postgresql_fact("postgresql_proc", db_name="test-db2")
+test_db = postgresql_fact("postgresql_proc", dbname="test-db2")


 @pytest.fixture
 def db_with_data(test_db, request):
     """Fixture to initialize a db with some data out of the INIT_SQL above

     """
     db = BaseDb.connect(test_db.dsn)
     with db.cursor() as cur:
         psycopg2.extras.register_default_jsonb(cur)
         cur.execute(INIT_SQL)
     yield db
     db.conn.rollback()
     db.conn.close()


 @pytest.mark.db
 def test_db_connect(db_with_data):
     with db_with_data.cursor() as cur:
         psycopg2.extras.register_default_jsonb(cur)
         cur.execute(INSERT_SQL, STATIC_ROW_IN)
         cur.execute("select * from test_table;")
         output = convert_lines(cur)
         assert len(output) == 1
         assert EXPECTED_ROW_OUT == output[0]


 def test_db_initialized(db_with_data):
     with db_with_data.cursor() as cur:
         psycopg2.extras.register_default_jsonb(cur)
         cur.execute(INSERT_SQL, STATIC_ROW_IN)
         cur.execute("select * from test_table;")
         output = convert_lines(cur)
         assert len(output) == 1
         assert EXPECTED_ROW_OUT == output[0]


 def test_db_copy_to_static(db_with_data):
     items = [{field.name: field.example for field in FIELDS}]
     db_with_data.copy_to(items, "test_table", COLUMNS)

     with db_with_data.cursor() as cur:
         cur.execute("select * from test_table;")
         output = convert_lines(cur)
         assert len(output) == 1
         assert EXPECTED_ROW_OUT == output[0]


 @settings(suppress_health_check=function_scoped_fixture_check)
 @given(db_rows)
 def test_db_copy_to(db_with_data, data):
     items = [dict(zip(COLUMNS, item)) for item in data]
     with db_with_data.cursor() as cur:
         cur.execute("TRUNCATE TABLE test_table CASCADE")
     db_with_data.copy_to(items, "test_table", COLUMNS)

     with db_with_data.cursor() as cur:
         cur.execute("select * from test_table;")
         converted_lines = convert_lines(cur)
         assert converted_lines == data


 def test_db_copy_to_thread_exception(db_with_data):
     data = [(2 ** 65, "foo", b"bar")]
     items = [dict(zip(COLUMNS, item)) for item in data]
     with pytest.raises(psycopg2.errors.NumericValueOutOfRange):
         db_with_data.copy_to(items, "test_table", COLUMNS)


 def test_db_transaction(mocker):
     expected_cur = object()

     called = False

     class Storage:
         @db_transaction()
         def endpoint(self, cur=None, db=None):
             nonlocal called
             called = True
             assert cur is expected_cur

     storage = Storage()

     # 'with storage.get_db().transaction() as cur:' should cause
     # 'cur' to be 'expected_cur'
     db_mock = Mock()
     db_mock.transaction.return_value = MagicMock()
     db_mock.transaction.return_value.__enter__.return_value = expected_cur
     mocker.patch.object(storage, "get_db", return_value=db_mock, create=True)
     put_db_mock = mocker.patch.object(storage, "put_db", create=True)

     storage.endpoint()

     assert called
     put_db_mock.assert_called_once_with(db_mock)


 def test_db_transaction__with_generator():
     with pytest.raises(ValueError, match="generator"):

         class Storage:
             @db_transaction()
             def endpoint(self, cur=None, db=None):
                 yield None


 def test_db_transaction_signature():
     """Checks db_transaction removes the 'cur' and 'db' arguments."""

     def f(self, foo, *, bar=None):
         pass

     expected_sig = inspect.signature(f)

     @db_transaction()
     def g(self, foo, *, bar=None, db=None, cur=None):
         pass

     actual_sig = inspect.signature(g)

     assert actual_sig == expected_sig
 def test_db_transaction_generator(mocker):
     expected_cur = object()

     called = False

     class Storage:
         @db_transaction_generator()
         def endpoint(self, cur=None, db=None):
             nonlocal called
             called = True
             assert cur is expected_cur
             yield None

     storage = Storage()

     # 'with storage.get_db().transaction() as cur:' should cause
     # 'cur' to be 'expected_cur'
     db_mock = Mock()
     db_mock.transaction.return_value = MagicMock()
     db_mock.transaction.return_value.__enter__.return_value = expected_cur
     mocker.patch.object(storage, "get_db", return_value=db_mock, create=True)
     put_db_mock = mocker.patch.object(storage, "put_db", create=True)

     list(storage.endpoint())

     assert called
     put_db_mock.assert_called_once_with(db_mock)


 def test_db_transaction_generator__with_nongenerator():
     with pytest.raises(ValueError, match="generator"):

         class Storage:
             @db_transaction_generator()
             def endpoint(self, cur=None, db=None):
                 pass


 def test_db_transaction_generator_signature():
     """Checks db_transaction removes the 'cur' and 'db' arguments."""

     def f(self, foo, *, bar=None):
         pass

     expected_sig = inspect.signature(f)

     @db_transaction_generator()
     def g(self, foo, *, bar=None, db=None, cur=None):
         yield None

     actual_sig = inspect.signature(g)

     assert actual_sig == expected_sig