diff --git a/mypy.ini b/mypy.ini index 5d30edb..696abf0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,45 +1,48 @@ [mypy] namespace_packages = True warn_unused_ignores = True # 3rd party libraries without stubs (yet) [mypy-aiohttp_utils.*] ignore_missing_imports = True [mypy-arrow.*] ignore_missing_imports = True [mypy-celery.*] ignore_missing_imports = True [mypy-decorator.*] ignore_missing_imports = True [mypy-deprecated.*] ignore_missing_imports = True [mypy-django.*] # false positive, only used my hypotesis' extras ignore_missing_imports = True [mypy-iso8601.*] ignore_missing_imports = True [mypy-msgpack.*] ignore_missing_imports = True [mypy-pkg_resources.*] ignore_missing_imports = True [mypy-psycopg2.*] ignore_missing_imports = True [mypy-pytest.*] ignore_missing_imports = True +[mypy-pytest_postgresql.*] +ignore_missing_imports = True + [mypy-requests_mock.*] ignore_missing_imports = True [mypy-systemd.*] ignore_missing_imports = True diff --git a/requirements-db.txt b/requirements-db.txt index 921e04d..d0f0975 100644 --- a/requirements-db.txt +++ b/requirements-db.txt @@ -1,3 +1,4 @@ # requirements for swh.core.db psycopg2 typing-extensions +pytest-postgresql diff --git a/requirements-test-db.txt b/requirements-test-db.txt index cfd42eb..8b13789 100644 --- a/requirements-test-db.txt +++ b/requirements-test-db.txt @@ -1 +1 @@ -pytest-postgresql + diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index 77ddb59..95b79ff 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,252 +1,366 @@ -# Copyright (C) 2015-2019 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import functools +import glob import logging import re -from typing import Optional, Union +import subprocess +from typing import Optional, 
Set, Union import psycopg2 import psycopg2.extensions +from pytest_postgresql.janitor import DatabaseJanitor, Version + +from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value def connect_to_conninfo( db_or_conninfo: Union[str, psycopg2.extensions.connection] ) -> psycopg2.extensions.connection: """Connect to the database passed in argument Args: db_or_conninfo: A database connection, or a database connection info string Returns: a connected database handle Raises: psycopg2.Error if the database doesn't exist """ if isinstance(db_or_conninfo, psycopg2.extensions.connection): return db_or_conninfo if "=" not in db_or_conninfo and "//" not in db_or_conninfo: # Database name db_or_conninfo = f"dbname={db_or_conninfo}" db = psycopg2.connect(db_or_conninfo) return db def swh_db_version( db_or_conninfo: Union[str, psycopg2.extensions.connection] ) -> Optional[int]: """Retrieve the swh version of the database. If the database is not initialized, this logs a warning and returns None. 
Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version of the database, or None if it couldn't be detected """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None try: with db.cursor() as c: query = "select version from dbversion order by dbversion desc limit 1" try: c.execute(query) return c.fetchone()[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get version from `%s`", db_or_conninfo) return None def swh_db_flavor( db_or_conninfo: Union[str, psycopg2.extensions.connection] ) -> Optional[str]: """Retrieve the swh flavor of the database. If the database is not initialized, or the database doesn't support flavors, this returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: The flavor of the database, or None if it could not be detected. """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None try: with db.cursor() as c: query = "select swh_get_dbflavor()" try: c.execute(query) return c.fetchone()[0] except psycopg2.errors.UndefinedFunction: # function not found: no flavor return None except Exception: logger.exception("Could not get flavor from `%s`", db_or_conninfo) return None # The following code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. 
def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] tokens = re.split(br"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. 
``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode(psycopg2.extensions.encodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur + + +class SWHDatabaseJanitor(DatabaseJanitor): + """SWH database janitor implementation with a different setup/teardown policy than + the stock one. Instead of dropping, creating and initializing the database for + each test, it creates and initializes the db once, then truncates the tables (and + sequences) in between tests. + + This is needed to have acceptable test performance. 
+ + """ + + def __init__( + self, + user: str, + host: str, + port: str, + db_name: str, + version: Union[str, float, Version], + dump_files: Optional[str] = None, + no_truncate_tables: Set[str] = set(), + ) -> None: + super().__init__(user, host, port, db_name, version) + if dump_files: + self.dump_files = sorted(glob.glob(dump_files), key=sortkey) + else: + self.dump_files = [] + # do not truncate the following tables + self.no_truncate_tables = set(no_truncate_tables) + + def db_setup(self): + conninfo = ( + f"host={self.host} user={self.user} port={self.port} dbname={self.db_name}" + ) + + for fname in self.dump_files: + subprocess.check_call( + [ + "psql", + "--quiet", + "--no-psqlrc", + "-v", + "ON_ERROR_STOP=1", + "-d", + conninfo, + "-f", + fname, + ] + ) + + def db_reset(self): + """Truncate tables (all but self.no_truncate_tables set) and sequences + + """ + with psycopg2.connect( + dbname=self.db_name, user=self.user, host=self.host, port=self.port, + ) as cnx: + with cnx.cursor() as cur: + cur.execute( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = %s", + ("public",), + ) + all_tables = set(table for (table,) in cur.fetchall()) + tables_to_truncate = all_tables - self.no_truncate_tables + + for table in tables_to_truncate: + cur.execute("TRUNCATE TABLE %s CASCADE" % table) + + cur.execute( + "SELECT sequence_name FROM information_schema.sequences " + "WHERE sequence_schema = %s", + ("public",), + ) + seqs = set(seq for (seq,) in cur.fetchall()) + for seq in seqs: + cur.execute("ALTER SEQUENCE %s RESTART;" % seq) + cnx.commit() + + def init(self): + """Initialize db. Create the db if it does not exist. 
Reset it if it exists.""" + with self.cursor() as cur: + cur.execute( + "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,) + ) + db_exists = cur.fetchone()[0] == 1 + if db_exists: + cur.execute( + "UPDATE pg_database SET datallowconn=true WHERE datname = %s;", + (self.db_name,), + ) + self.db_reset() + return + + # initialize the nonexistent db + with self.cursor() as cur: + cur.execute('CREATE DATABASE "{}";'.format(self.db_name)) + self.db_setup() + + def drop(self): + """The original DatabaseJanitor implementation prevents new connections from happening, + destroys current opened connections and finally drops the database. + + We actually do not want to drop the db so we instead do nothing and reset + (truncate most tables and sequences) the db, in order to have some + acceptable performance. + + """ + pass diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py new file mode 100644 index 0000000..56e3305 --- /dev/null +++ b/swh/core/db/pytest_plugin.py @@ -0,0 +1,62 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import logging +from typing import Optional, Set + +import psycopg2 +import pytest +from pytest_postgresql import factories + +from swh.core.db.db_utils import SWHDatabaseJanitor + +logger = logging.getLogger(__name__) + + +# the postgres_fact factory fixture below is mostly a copy of the code +# from pytest-postgresql. We need a custom version here to be able to +# specify our version of the DBJanitor we use. +def postgresql_fact( + process_fixture_name: str, + db_name: Optional[str] = None, + dump_files: str = "", + no_truncate_tables: Set[str] = {"dbversion"}, +): + @pytest.fixture + def postgresql_factory(request): + """Fixture factory for PostgreSQL. 
+ + :param FixtureRequest request: fixture request object + :rtype: psycopg2.connection + :returns: postgresql client + """ + config = factories.get_config(request) + proc_fixture = request.getfixturevalue(process_fixture_name) + + pg_host = proc_fixture.host + pg_port = proc_fixture.port + pg_user = proc_fixture.user + pg_options = proc_fixture.options + pg_db = db_name or config["dbname"] + with SWHDatabaseJanitor( + pg_user, + pg_host, + pg_port, + pg_db, + proc_fixture.version, + dump_files=dump_files, + no_truncate_tables=no_truncate_tables, + ): + connection = psycopg2.connect( + dbname=pg_db, + user=pg_user, + host=pg_host, + port=pg_port, + options=pg_options, + ) + yield connection + connection.close() + + return postgresql_factory diff --git a/swh/core/db/tests/data/0-schema.sql b/swh/core/db/tests/data/0-schema.sql new file mode 100644 index 0000000..e6b008b --- /dev/null +++ b/swh/core/db/tests/data/0-schema.sql @@ -0,0 +1,19 @@ +-- schema version table which won't get truncated +create table dbversion ( + version int primary key, + release timestamptz, + description text +); + +-- a people table which won't get truncated +create table people ( + fullname text not null +); + +-- a fun table which will get truncated for each test +create table fun ( + time timestamptz not null +); + +-- one sequence to check for reset as well +create sequence serial; diff --git a/swh/core/db/tests/data/1-data.sql b/swh/core/db/tests/data/1-data.sql new file mode 100644 index 0000000..8680263 --- /dev/null +++ b/swh/core/db/tests/data/1-data.sql @@ -0,0 +1,15 @@ +-- insert some values in dbversion +insert into dbversion(version, release, description) values (1, '2016-02-22 15:56:28.358587+00', 'Work In Progress'); +insert into dbversion(version, release, description) values (2, '2016-02-24 18:05:54.887217+00', 'Work In Progress'); +insert into dbversion(version, release, description) values (3, '2016-10-21 14:10:18.629763+00', 'Work In Progress'); +insert into 
dbversion(version, release, description) values (4, '2017-08-08 19:01:11.723113+00', 'Work In Progress'); +insert into dbversion(version, release, description) values (7, '2018-03-30 12:58:39.256679+00', 'Work In Progress'); + +insert into fun(time) values ('2020-10-19 09:00:00.666999+00'); +insert into fun(time) values ('2020-10-18 09:00:00.666999+00'); +insert into fun(time) values ('2020-10-17 09:00:00.666999+00'); + +select nextval('serial'); + +insert into people(fullname) values ('dudess'); +insert into people(fullname) values ('dude'); diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/test_db_utils.py new file mode 100644 index 0000000..d4cb0fa --- /dev/null +++ b/swh/core/db/tests/test_db_utils.py @@ -0,0 +1,142 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import os + +from swh.core.db import BaseDb +from swh.core.db.pytest_plugin import postgresql_fact + +SQL_DIR = os.path.join(os.path.dirname(__file__), "data") + + +# db with special policy for tables dbversion and people +postgres_fun = postgresql_fact( + "postgresql_proc", + db_name="fun", + dump_files=f"{SQL_DIR}/*.sql", + no_truncate_tables={"dbversion", "people"}, +) + + +def test_smoke_test_fun_db_is_up(postgres_fun): + """This ensures the db is created and configured according to its dumps files. 
+ + """ + with BaseDb.connect(postgres_fun.dsn).cursor() as cur: + cur.execute("select count(*) from dbversion") + nb_rows = cur.fetchone()[0] + assert nb_rows == 5 + + cur.execute("select count(*) from fun") + nb_rows = cur.fetchone()[0] + assert nb_rows == 3 + + cur.execute("select count(*) from people") + nb_rows = cur.fetchone()[0] + assert nb_rows == 2 + + # in data, we requested a value already so it starts at 2 + cur.execute("select nextval('serial')") + val = cur.fetchone()[0] + assert val == 2 + + +def test_smoke_test_fun_db_is_still_up_and_got_reset(postgres_fun): + """This ensures that within another test, the 'fun' db is still up, created (and not + configured again). This time, most of the data has been reset: + - except for tables 'dbversion' and 'people' which were left as is + - the other tables from the schema (here only "fun") got truncated + - the sequences got truncated as well + + """ + with BaseDb.connect(postgres_fun.dsn).cursor() as cur: + # db version is excluded from the truncate + cur.execute("select count(*) from dbversion") + nb_rows = cur.fetchone()[0] + assert nb_rows == 5 + + # people is also allowed not to be truncated + cur.execute("select count(*) from people") + nb_rows = cur.fetchone()[0] + assert nb_rows == 2 + + # table and sequence are reset + cur.execute("select count(*) from fun") + nb_rows = cur.fetchone()[0] + assert nb_rows == 0 + + cur.execute("select nextval('serial')") + val = cur.fetchone()[0] + assert val == 1 + + +# db with no special policy for tables truncation, all tables are reset +postgres_people = postgresql_fact( + "postgresql_proc", + db_name="people", + dump_files=f"{SQL_DIR}/*.sql", + no_truncate_tables=set(), +) + + +def test_smoke_test_people_db_up(postgres_people): + """'people' db is up and configured + + """ + with BaseDb.connect(postgres_people.dsn).cursor() as cur: + cur.execute("select count(*) from dbversion") + nb_rows = cur.fetchone()[0] + assert nb_rows == 5 + + cur.execute("select count(*) 
from people") + nb_rows = cur.fetchone()[0] + assert nb_rows == 2 + + cur.execute("select count(*) from fun") + nb_rows = cur.fetchone()[0] + assert nb_rows == 3 + + cur.execute("select nextval('serial')") + val = cur.fetchone()[0] + assert val == 2 + + +def test_smoke_test_people_db_up_and_reset(postgres_people): + """'people' db is up and got reset on every tables and sequences + + """ + with BaseDb.connect(postgres_people.dsn).cursor() as cur: + # tables are truncated after the first round + cur.execute("select count(*) from dbversion") + nb_rows = cur.fetchone()[0] + assert nb_rows == 0 + + # tables are truncated after the first round + cur.execute("select count(*) from people") + nb_rows = cur.fetchone()[0] + assert nb_rows == 0 + + # table and sequence are reset + cur.execute("select count(*) from fun") + nb_rows = cur.fetchone()[0] + assert nb_rows == 0 + + cur.execute("select nextval('serial')") + val = cur.fetchone()[0] + assert val == 1 + + +# db with no initialization step, an empty db +postgres_no_init = postgresql_fact("postgresql_proc", db_name="something") + + +def test_smoke_test_db_no_init(postgres_no_init): + """We can connect to the db nonetheless + + """ + with BaseDb.connect(postgres_no_init.dsn).cursor() as cur: + cur.execute("select now()") + data = cur.fetchone()[0] + assert data is not None