Page MenuHomeSoftware Heritage

D4300.id15219.diff
No OneTemporary

D4300.id15219.diff

diff --git a/mypy.ini b/mypy.ini
--- a/mypy.ini
+++ b/mypy.ini
@@ -38,6 +38,9 @@
[mypy-pytest.*]
ignore_missing_imports = True
+[mypy-pytest_postgresql.*]
+ignore_missing_imports = True
+
[mypy-requests_mock.*]
ignore_missing_imports = True
diff --git a/requirements-db.txt b/requirements-db.txt
--- a/requirements-db.txt
+++ b/requirements-db.txt
@@ -1,3 +1,4 @@
# requirements for swh.core.db
psycopg2
typing-extensions
+pytest-postgresql
diff --git a/requirements-test-db.txt b/requirements-test-db.txt
--- a/requirements-test-db.txt
+++ b/requirements-test-db.txt
@@ -1 +1 @@
-pytest-postgresql
+
diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py
--- a/swh/core/db/db_utils.py
+++ b/swh/core/db/db_utils.py
@@ -1,15 +1,20 @@
-# Copyright (C) 2015-2019 The Software Heritage developers
+# Copyright (C) 2015-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import functools
+import glob
import logging
import re
+import subprocess
from typing import Optional, Union
import psycopg2
import psycopg2.extensions
+from pytest_postgresql.janitor import DatabaseJanitor, Version
+
+from swh.core.utils import numfile_sortkey as sortkey
logger = logging.getLogger(__name__)
@@ -250,3 +255,100 @@
parts[-1:] = post
cur.execute(b"".join(parts))
yield from cur
+
+
+class SWHDatabaseJanitor(DatabaseJanitor):
+ """SwH database janitor implementation with a a different setup/teardown policy than
+ than the stock one. Instead of dropping, creating and initializing the database for
+ each test, it creates and initializes the db only once, then truncates the tables in
+ between tests.
+
+ This is needed to have acceptable test performances.
+
+ """
+
+ def __init__(
+ self,
+ user: str,
+ host: str,
+ port: str,
+ db_name: str,
+ version: Union[str, float, Version],
+ dump_files: str = "*.sql",
+ ) -> None:
+ super().__init__(user, host, port, db_name, version)
+ self.dump_files = sorted(glob.glob(dump_files), key=sortkey)
+
+ def db_setup(self):
+ conninfo = (
+ f"host={self.host} user={self.user} port={self.port} dbname={self.db_name}"
+ )
+
+ for fname in self.dump_files:
+ subprocess.check_call(
+ [
+ "psql",
+ "--quiet",
+ "--no-psqlrc",
+ "-v",
+ "ON_ERROR_STOP=1",
+ "-d",
+ conninfo,
+ "-f",
+ fname,
+ ]
+ )
+
+ def db_reset(self):
+ with psycopg2.connect(
+ dbname=self.db_name, user=self.user, host=self.host, port=self.port,
+ ) as cnx:
+ with cnx.cursor() as cur:
+ cur.execute(
+ "SELECT table_name FROM information_schema.tables "
+ "WHERE table_schema = %s",
+ ("public",),
+ )
+ tables = set(table for (table,) in cur.fetchall()) - {"dbversion"}
+ for table in tables:
+ cur.execute("truncate table %s cascade" % table)
+
+ cur.execute(
+ "SELECT sequence_name FROM information_schema.sequences "
+ "WHERE sequence_schema = %s",
+ ("public",),
+ )
+ seqs = set(seq for (seq,) in cur.fetchall())
+ for seq in seqs:
+ cur.execute("ALTER SEQUENCE %s RESTART;" % seq)
+ cnx.commit()
+
+ def init(self):
+ with self.cursor() as cur:
+ cur.execute(
+ "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,)
+ )
+ db_exists = cur.fetchone()[0] == 1
+ if db_exists:
+ cur.execute(
+ "UPDATE pg_database SET datallowconn=true WHERE datname = %s;",
+ (self.db_name,),
+ )
+ self.db_reset()
+ return
+
+ # initialize the inexistent db
+ with self.cursor() as cur:
+ cur.execute('CREATE DATABASE "{}";'.format(self.db_name))
+ self.db_setup()
+
+ def drop(self):
+ """The original DatabaseJanitor implementatonJanitor prevents new connections from
+ happening, destroys current opened connections and finally drops the
+ database.
+
+ We actually do not want to drop the db so we instead do nothing and truncates
+ the tables instead (in order to have some acceptable performance).
+
+ """
+ pass
diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py
new file mode 100644
--- /dev/null
+++ b/swh/core/db/pytest_plugin.py
@@ -0,0 +1,59 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import logging
+
+import psycopg2
+import pytest
+from pytest_postgresql import factories
+
+from swh.core.db.db_utils import SWHDatabaseJanitor
+
+logger = logging.getLogger(__name__)
+
+
+# The postgresql_fact fixture factory below is mostly a copy of the code
+# from pytest-postgresql. We need a custom version here to be able to
+# specify our own DatabaseJanitor implementation.
+def postgresql_fact(process_fixture_name, db_name=None, dump_files=""):
+ @pytest.fixture
+ def postgresql_factory(request):
+ """
+ Fixture factory for PostgreSQL.
+
+ :param FixtureRequest request: fixture request object
+ :rtype: psycopg2.connection
+ :returns: postgresql client
+ """
+ config = factories.get_config(request)
+ if not psycopg2:
+ raise ImportError("No module named psycopg2. Please install it.")
+ proc_fixture = request.getfixturevalue(process_fixture_name)
+
+ # _, config = try_import('psycopg2', request)
+ pg_host = proc_fixture.host
+ pg_port = proc_fixture.port
+ pg_user = proc_fixture.user
+ pg_options = proc_fixture.options
+ pg_db = db_name or config["dbname"]
+ with SWHDatabaseJanitor(
+ pg_user,
+ pg_host,
+ pg_port,
+ pg_db,
+ proc_fixture.version,
+ dump_files=dump_files,
+ ):
+ connection = psycopg2.connect(
+ dbname=pg_db,
+ user=pg_user,
+ host=pg_host,
+ port=pg_port,
+ options=pg_options,
+ )
+ yield connection
+ connection.close()
+
+ return postgresql_factory
diff --git a/swh/core/db/tests/data/0-schema.sql b/swh/core/db/tests/data/0-schema.sql
new file mode 100644
--- /dev/null
+++ b/swh/core/db/tests/data/0-schema.sql
@@ -0,0 +1,16 @@
+-- schema version table which won't get truncated
+create table dbversion
+(
+ version int primary key,
+ release timestamptz,
+ description text
+);
+
+-- a fun table which will get truncated for each test
+create table fun
+(
+ time timestamptz not null
+);
+
+-- one sequence to check for reset as well
+create sequence serial;
diff --git a/swh/core/db/tests/data/1-data.sql b/swh/core/db/tests/data/1-data.sql
new file mode 100644
--- /dev/null
+++ b/swh/core/db/tests/data/1-data.sql
@@ -0,0 +1,12 @@
+-- insert some values in dbversion
+insert into dbversion(version, release, description) values (1, '2016-02-22 15:56:28.358587+00', 'Work In Progress');
+insert into dbversion(version, release, description) values (2, '2016-02-24 18:05:54.887217+00', 'Work In Progress');
+insert into dbversion(version, release, description) values (3, '2016-10-21 14:10:18.629763+00', 'Work In Progress');
+insert into dbversion(version, release, description) values (4, '2017-08-08 19:01:11.723113+00', 'Work In Progress');
+insert into dbversion(version, release, description) values (7, '2018-03-30 12:58:39.256679+00', 'Work In Progress');
+
+insert into fun(time) values ('2020-10-19 09:00:00.666999+00');
+insert into fun(time) values ('2020-10-18 09:00:00.666999+00');
+insert into fun(time) values ('2020-10-17 09:00:00.666999+00');
+
+select nextval('serial');
diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/test_db_utils.py
new file mode 100644
--- /dev/null
+++ b/swh/core/db/tests/test_db_utils.py
@@ -0,0 +1,67 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import os
+
+import pytest
+
+from swh.core.db import BaseDb
+from swh.core.db.pytest_plugin import postgresql_fact
+
+# Directory (next to this test module) holding the SQL dump files.
+SQL_DIR = os.path.join(os.path.dirname(__file__), "data")
+
+
+# Connection fixture for a database named "fun", built on the "postgresql_proc"
+# process fixture and loaded from every *.sql file found in SQL_DIR.
+postgres_fun = postgresql_fact(
+    "postgresql_proc", db_name="fun", dump_files=f"{SQL_DIR}/*.sql"
+)
+
+
+@pytest.fixture()
+def db_conn(postgres_fun):
+ db = BaseDb.connect(postgres_fun.dsn)
+ return db
+
+
+def test_smoke_test_db_is_up(db_conn):
+ """This ensures the db is created and configured according to its dumps files.
+
+ """
+ with db_conn.cursor() as cur:
+ cur.execute("select count(*) from dbversion")
+ nb_rows = cur.fetchone()[0]
+ assert nb_rows == 5
+
+ cur.execute("select count(*) from fun")
+ nb_rows = cur.fetchone()[0]
+ assert nb_rows == 3
+
+ # in data, we requested a value already so it starts at 2
+ cur.execute("select nextval('serial')")
+ val = cur.fetchone()[0]
+ assert val == 2
+
+
+def test_smoke_test_db_is_still_up_and_got_reset(db_conn):
+ """This ensures that within another tests, the db is still up, created (and not
+ configured again). This time, most of the data has been reset:
+ - table dbversion was left as is
+ - the other tables from the schema (here only "fun") got truncated
+ - the sequences got truncated as well
+
+ """
+ with db_conn.cursor() as cur:
+ # db version is excluded from the reset
+ cur.execute("select count(*) from dbversion")
+ nb_rows = cur.fetchone()[0]
+ assert nb_rows == 5
+
+ # table and sequence are reset
+ cur.execute("select count(*) from fun")
+ nb_rows = cur.fetchone()[0]
+ assert nb_rows == 0
+
+ cur.execute("select nextval('serial')")
+ val = cur.fetchone()[0]
+ assert val == 1

File Metadata

Mime Type
text/plain
Expires
Nov 5 2024, 1:15 PM (18 w, 4 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3228598

Event Timeline