diff --git a/swh/core/tests/db_testing.py b/swh/core/tests/db_testing.py
--- a/swh/core/tests/db_testing.py
+++ b/swh/core/tests/db_testing.py
@@ -4,9 +4,14 @@
 # See top-level LICENSE file for more information
 
 import os
+import glob
 import psycopg2
 import subprocess
 
+from swh.core.utils import numfile_sortkey as sortkey
+
+DB_DUMP_TYPES = {'.sql': 'psql', '.dump': 'pg_dump'}
+
 
 def pg_restore(dbname, dumpfile, dumptype='pg_dump'):
     """
@@ -40,8 +45,10 @@
     subprocess.check_call(['createdb', dbname])
 
 
-def db_create(dbname, dump=None, dumptype='pg_dump'):
-    """create the test DB and load the test data dump into it
+def db_create(dbname, dumps=None):
+    """create the test DB and load the test data dumps into it
+
+    dumps is an optional iterable of (dump_file, dump_type) pairs.
 
     context: setUpClass
 
@@ -51,8 +58,8 @@
     except subprocess.CalledProcessError:  # try recovering once, in case
         pg_dropdb(dbname)                  # the db already existed
         pg_createdb(dbname)
-    if dump:
-        pg_restore(dbname, dump, dumptype)
+    for dump, dtype in dumps or ():  # dumps=None: create an empty DB
+        pg_restore(dbname, dump, dtype)
     return dbname
 
 
@@ -104,16 +111,13 @@
 
 
 class DbTestContext:
-    def __init__(self, name='softwareheritage-test', dump=None,
-                 dump_type='pg_dump'):
+    def __init__(self, name='softwareheritage-test', dumps=None):
         self.dbname = name
-        self.dump = dump
-        self.dump_type = dump_type
+        self.dumps = dumps
 
     def __enter__(self):
         db_create(dbname=self.dbname,
-                  dump=self.dump,
-                  dumptype=self.dump_type)
+                  dumps=self.dumps)
         return self
 
     def __exit__(self, *_):
@@ -174,14 +178,13 @@
     DB_TEST_FIXTURE_IMPORTED = True
 
     @classmethod
-    def add_db(cls, name='softwareheritage-test', dump=None,
-               dump_type='pg_dump'):
-        cls._DB_DUMP_LIST[name] = (dump, dump_type)
+    def add_db(cls, name='softwareheritage-test', dumps=None):
+        cls._DB_DUMP_LIST[name] = dumps
 
     @classmethod
     def setUpClass(cls):
-        for name, (dump, dump_type) in cls._DB_DUMP_LIST.items():
-            cls._DB_LIST[name] = DbTestContext(name, dump, dump_type)
+        for name, dumps in cls._DB_DUMP_LIST.items():
+            cls._DB_LIST[name] = DbTestContext(name, dumps)
             cls._DB_LIST[name].__enter__()
         super().setUpClass()
 
@@ -247,17 +250,27 @@
     TEST_DB_NAME = 'softwareheritage-test'
     TEST_DB_DUMP = None
     TEST_DB_DUMP_TYPE = None
-    DB_DUMP_TYPES = {'.sql': 'psql', '.dump': 'pg_dump'}
+    DB_DUMP_TYPES = DB_DUMP_TYPES  # subclasses may still override it
 
     @classmethod
     def setUpClass(cls):
         cls.dbname = cls.TEST_DB_NAME
-        dump_type = (cls.TEST_DB_DUMP_TYPE or
-                     cls.DB_DUMP_TYPES[os.path.splitext(cls.TEST_DB_DUMP)[-1]])
+
+        # TEST_DB_DUMP may be a glob pattern matching several dumps: load
+        # them ordered by the numeric prefix of their base name.
+        dump_files = []
+        if cls.TEST_DB_DUMP:
+            dump_files = sorted(
+                glob.glob(cls.TEST_DB_DUMP),
+                key=lambda path: sortkey(os.path.basename(path)))
+
+        # an explicit TEST_DB_DUMP_TYPE overrides the extension lookup
+        dump_files = [(x, cls.TEST_DB_DUMP_TYPE or
+                       cls.DB_DUMP_TYPES[os.path.splitext(x)[1]])
+                      for x in dump_files]
 
         cls.add_db(name=cls.TEST_DB_NAME,
-                   dump=cls.TEST_DB_DUMP,
-                   dump_type=dump_type)
+                   dumps=dump_files)
         super().setUpClass()
 
     def setUp(self):
diff --git a/swh/core/utils.py b/swh/core/utils.py
--- a/swh/core/utils.py
+++ b/swh/core/utils.py
@@ -6,6 +6,7 @@
 import os
 import itertools
 import codecs
+import re
 
 from contextlib import contextmanager
 
@@ -101,3 +102,18 @@
 
     """
     return path1.split(path0)[1]
+
+
+def numfile_sortkey(fname):
+    """Sort key for filenames of the form ``nnxxx.ext``.
+
+    ``nn`` is an optional run of leading digits: names are ordered
+    numerically on it, then lexically on the remainder ``xxx.ext``.
+    Names without a leading number are keyed as if numbered 99.
+
+    Typically used to sort sql/nn-swh-xxx.sql files. Pass a base name,
+    not a full path: only digits at the very start are considered.
+    """
+    num, rem = re.match(r'(\d*)(.*)', fname).groups()
+    # test ``num`` itself: int('0') is falsy and must not map to 99
+    return (int(num) if num else 99, rem)