Changeset View
Changeset View
Standalone View
Standalone View
swh/core/db/tests/db_testing.py
# Copyright (C) 2015-2018 The Software Heritage developers | # Copyright (C) 2015-2018 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
import os | import os | ||||
import glob | import glob | ||||
import subprocess | import subprocess | ||||
import psycopg2 | import psycopg2 | ||||
from typing import Dict, Iterable, Optional, Tuple, Union | from typing import Dict, Iterable, Optional, Tuple, Union | ||||
from swh.core.utils import numfile_sortkey as sortkey | from swh.core.utils import numfile_sortkey as sortkey | ||||
# Maps a dump file extension to the restore flavour understood by pg_restore()
# in this module ('psql' for plain SQL scripts, 'pg_dump' for binary dumps).
DB_DUMP_TYPES = {
    ".sql": "psql",
    ".dump": "pg_dump",
}  # type: Dict[str, str]
def swh_db_version(dbname_or_service):
    """Look up the schema version recorded in a swh database.

    The version is read from the ``dbversion`` table through the ``psql``
    command line client.

    Args:
        dbname_or_service (str): The db's name or service

    Returns:
        Optional[int]: the db's version, or None when it cannot be read
        (typically because the db is not initialized)

    """
    sql = "select version from dbversion order by dbversion desc limit 1"
    psql_argv = [
        "psql",
        "--tuples-only",
        "--no-psqlrc",
        "--quiet",
        "-v",
        "ON_ERROR_STOP=1",
        "--command=" + sql,
        dbname_or_service,
    ]
    try:
        completed = subprocess.run(
            psql_argv,
            check=True,
            stdout=subprocess.PIPE,
            universal_newlines=True,
        )
        return int(completed.stdout.strip())
    except Exception:  # db not initialized
        # Any failure (psql error, unparsable output, ...) is reported as
        # "no version" rather than propagated.
        return None
def pg_restore(dbname, dumpfile, dumptype="pg_dump"):
    """Load a dump file into an existing database.

    Args:
        dbname: name of the DB to restore into
        dumpfile: path of the dump file
        dumptype: one of 'pg_dump' (for binary dumps), 'psql' (for SQL dumps)

    Raises:
        ValueError: if dumptype is not one of the supported values
        subprocess.CalledProcessError: if the restore command exits with a
            non-zero status

    """
    if dumptype == "pg_dump":
        cmd = [
            "pg_restore",
            "--no-owner",
            "--no-privileges",
            "--dbname",
            dbname,
            dumpfile,
        ]
    elif dumptype == "psql":
        cmd = [
            "psql",
            "--quiet",
            "--no-psqlrc",
            "-v",
            "ON_ERROR_STOP=1",
            "-f",
            dumpfile,
            dbname,
        ]
    else:
        # An explicit exception (rather than an assert) keeps the check
        # active under `python -O`, where a bad dumptype would otherwise
        # silently restore nothing.
        raise ValueError(
            "unsupported dumptype %r, expected 'pg_dump' or 'psql'" % (dumptype,)
        )

    subprocess.check_call(cmd)
def pg_dump(dbname, dumpfile):
    """Dump database ``dbname`` into ``dumpfile`` using ``pg_dump -Fc``.

    Ownership and privilege statements are left out of the dump. Raises
    subprocess.CalledProcessError if the command exits with a non-zero status.
    """
    argv = ["pg_dump", "--no-owner", "--no-privileges", "-Fc"]
    argv += ["-f", dumpfile, dbname]
    subprocess.check_call(argv)
def pg_dropdb(dbname):
    """Drop database ``dbname`` with the ``dropdb`` command.

    Raises subprocess.CalledProcessError if the command exits with a non-zero
    status.
    """
    argv = ["dropdb", dbname]
    subprocess.check_call(argv)
def pg_createdb(dbname, check=True):
    """Create database ``dbname`` with the ``createdb`` command.

    With ``check=True`` (the historical behavior) a failing ``createdb``,
    for instance because the db already exists, raises an exception. With
    ``check=False`` the failure is ignored silently. When the db does not
    exist yet, it is created in both cases.

    """
    argv = ["createdb", dbname]
    subprocess.run(argv, check=check)
def db_create(dbname, dumps=None):
    """create the test DB and load the test data dumps into it

    dumps is an iterable of couples (dump_file, dump_type); None (the
    default) means that no dump is loaded.

    If the creation fails, the db is assumed to be a leftover of a previous
    run: it is dropped and the creation is attempted one more time.

    context: setUpClass

    Returns:
        the name of the created db

    """
    try:
        pg_createdb(dbname)
    except subprocess.CalledProcessError:  # try recovering once, in case
        pg_dropdb(dbname)  # the db already existed
        pg_createdb(dbname)

    # `dumps` defaults to None, which is not iterable: fall back to "nothing
    # to restore" instead of raising TypeError.
    for dump, dtype in dumps or ():
        pg_restore(dbname, dump, dtype)

    return dbname
def db_destroy(dbname):
    """Drop the test DB named ``dbname``.

    Thin wrapper around pg_dropdb().

    context: tearDownClass

    """
    pg_dropdb(dbname)
def db_connect(dbname):
    """Open a psycopg2 connection to the test DB, along with a cursor on it.

    context: setUp

    Returns:
        a dict holding the connection under key 'conn' and the cursor under
        key 'cursor'

    """
    connection = psycopg2.connect("dbname=" + dbname)
    cursor = connection.cursor()
    return {"conn": connection, "cursor": cursor}
def db_close(conn):
    """Abort the transaction in progress and disconnect from the test DB.

    Nothing is done when the connection is already closed.

    context: tearDown

    """
    if conn.closed:
        return

    conn.rollback()
    conn.close()
class DbTestConn:
    """Context manager giving access to a connection on an existing test DB.

    On entry, a connection and a cursor are opened through db_connect() and
    exposed as the ``conn`` and ``cursor`` attributes (the raw result of
    db_connect() is kept in ``db_setup``). On exit, the connection is handed
    to db_close().

    """

    def __init__(self, dbname):
        self.dbname = dbname

    def __enter__(self):
        setup = db_connect(self.dbname)
        self.db_setup = setup
        self.conn, self.cursor = setup["conn"], setup["cursor"]
        return self

    def __exit__(self, *_):
        db_close(self.conn)
class DbTestContext:
    """Context manager owning the lifetime of a test DB.

    The db is created (and the configured dumps loaded) through db_create()
    on entry, and removed through db_destroy() on exit.

    """

    def __init__(self, name="softwareheritage-test", dumps=None):
        self.dbname, self.dumps = name, dumps

    def __enter__(self):
        db_create(self.dbname, self.dumps)
        return self

    def __exit__(self, *_):
        db_destroy(self.dbname)
class DbTestFixture: | class DbTestFixture: | ||||
"""Mix this in a test subject class to get DB testing support. | """Mix this in a test subject class to get DB testing support. | ||||
▲ Show 20 Lines • Show All 44 Lines • ▼ Show 20 Lines | class DbTestFixture: | ||||
""" | """ | ||||
_DB_DUMP_LIST = {} # type: Dict[str, Iterable[Tuple[str, str]]] | _DB_DUMP_LIST = {} # type: Dict[str, Iterable[Tuple[str, str]]] | ||||
_DB_LIST = {} # type: Dict[str, DbTestContext] | _DB_LIST = {} # type: Dict[str, DbTestContext] | ||||
DB_TEST_FIXTURE_IMPORTED = True | DB_TEST_FIXTURE_IMPORTED = True | ||||
@classmethod
def add_db(cls, name="softwareheritage-test", dumps=None):
    """Register a test DB, and the dumps to load into it, under ``name``.

    The registration is stored in the class-level ``_DB_DUMP_LIST`` mapping,
    which setUpClass() reads to create the databases.
    """
    cls._DB_DUMP_LIST.update({name: dumps})
@classmethod
def setUpClass(cls):
    """Create every DB registered through add_db(), then chain to super().

    Each DbTestContext is entered by hand and kept in ``_DB_LIST`` under the
    db name.
    """
    for db_name, db_dumps in cls._DB_DUMP_LIST.items():
        context = DbTestContext(db_name, db_dumps)
        cls._DB_LIST[db_name] = context
        context.__enter__()
    super().setUpClass()
Show All 16 Lines | def tearDown(self): | ||||
for name in self._DB_LIST.keys(): | for name in self._DB_LIST.keys(): | ||||
self.test_db[name].__exit__() | self.test_db[name].__exit__() | ||||
def reset_db_tables(self, name, excluded=None):
    """Empty the tables of the ``public`` schema of test DB ``name``.

    Args:
        name: key of the db in ``self.test_db``
        excluded: optional iterable of table names to leave untouched

    The table list is read from ``information_schema.tables``; every
    remaining table is truncated with ``cascade`` and the result is
    committed on the db's connection.
    """
    db = self.test_db[name]
    conn = db.conn
    cursor = db.cursor

    cursor.execute(
        """SELECT table_name FROM information_schema.tables
           WHERE table_schema = %s""",
        ("public",),
    )

    tables = {table for (table,) in cursor.fetchall()}
    if excluded is not None:
        tables -= set(excluded)

    # Sorted so that the statements are issued in a reproducible order.
    for table in sorted(tables):
        # information_schema reports the exact stored name; quote it (and
        # double any embedded quote) so that mixed-case, reserved-word or
        # otherwise special names are truncated correctly.
        quoted = '"%s"' % table.replace('"', '""')
        cursor.execute("truncate table %s cascade" % quoted)

    conn.commit()
class SingleDbTestFixture(DbTestFixture): | class SingleDbTestFixture(DbTestFixture): | ||||
"""Simplified fixture like DbTest but that can only handle a single DB. | """Simplified fixture like DbTest but that can only handle a single DB. | ||||
Gives access to shortcuts like self.cursor and self.conn. | Gives access to shortcuts like self.cursor and self.conn. | ||||
Show All 15 Lines | class SingleDbTestFixture(DbTestFixture): | ||||
The test case class will then have the following attributes, accessible via | The test case class will then have the following attributes, accessible via | ||||
self: | self: | ||||
dbname: name of the test database | dbname: name of the test database | ||||
conn: psycopg2 connection object | conn: psycopg2 connection object | ||||
cursor: open psycopg2 cursor to the DB | cursor: open psycopg2 cursor to the DB | ||||
""" | """ | ||||
TEST_DB_NAME = 'softwareheritage-test' | TEST_DB_NAME = "softwareheritage-test" | ||||
TEST_DB_DUMP = None # type: Optional[Union[str, Iterable[str]]] | TEST_DB_DUMP = None # type: Optional[Union[str, Iterable[str]]] | ||||
@classmethod
def setUpClass(cls):
    """Register the single test DB described by the class attributes.

    ``TEST_DB_DUMP`` may be None, one glob pattern or an iterable of glob
    patterns. Each pattern is expanded, its matches are ordered with
    ``sortkey``, and every file is paired with the dump type that
    ``DB_DUMP_TYPES`` associates to its extension.
    """
    cls.dbname = cls.TEST_DB_NAME  # XXX to kill?

    patterns = cls.TEST_DB_DUMP
    if patterns is None:
        patterns = []
    elif isinstance(patterns, str):
        patterns = [patterns]

    typed_dumps = [
        (path, DB_DUMP_TYPES[os.path.splitext(path)[1]])
        for pattern in patterns
        for path in sorted(glob.glob(pattern), key=sortkey)
    ]

    cls.add_db(name=cls.TEST_DB_NAME, dumps=typed_dumps)
    super().setUpClass()
def setUp(self, *args, **kwargs):
    """Run the parent setUp, then expose the single DB's conn and cursor.

    The connection and cursor of the db registered under ``TEST_DB_NAME``
    become available as ``self.conn`` and ``self.cursor``.
    """
    super().setUp(*args, **kwargs)

    handle = self.test_db[self.TEST_DB_NAME]
    self.conn, self.cursor = handle.conn, handle.cursor