diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py index 86f81be..b3f8482 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -1,208 +1,209 @@ #!/usr/bin/env python3 # Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging from os import environ, path from typing import Tuple import warnings import click from swh.core.cli import CONTEXT_SETTINGS warnings.filterwarnings("ignore") # noqa prevent psycopg from telling us sh*t logger = logging.getLogger(__name__) @click.group(name="db", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False), help="Configuration file.", ) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools. """ from swh.core.config import read as config_read ctx.ensure_object(dict) if config_file is None: config_file = environ.get("SWH_CONFIG_FILENAME") cfg = config_read(config_file) ctx.obj["config"] = cfg @db.command(name="init", context_settings=CONTEXT_SETTINGS) @click.pass_context def init(ctx): """Initialize the database for every Software Heritage module found in the configuration file. For every configuration section in the config file that: 1. has the name of an existing swh package, 2. has credentials for a local db access, it will run the initialization scripts from the swh package against the given database. Example for the config file:: \b storage: cls: local args: db: postgresql:///?service=swh-storage objstorage: cls: remote args: url: http://swh-objstorage:5003/ the command: swh db -C /path/to/config.yml init will initialize the database for the `storage` section using initialization scripts from the `swh.storage` package. """ for modname, cfg in ctx.obj["config"].items(): if cfg.get("cls") == "local" and cfg.get("args", {}).get("db"): try: initialized, dbversion = populate_database_for_package( modname, cfg["args"]["db"] ) except click.BadParameter: logger.info( "Failed to load/find sql initialization files for %s", modname ) click.secho( "DONE database for {} {} at version {}".format( modname, "initialized" if initialized else "exists", dbversion ), fg="green", bold=True, ) @click.command(context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) @click.option( "--create-db/--no-create-db", "-C", help="Attempt to create the database.", default=False, ) def db_init(module, db_name, create_db): """Initialize a database for the Software Heritage . By default, does not attempt to create the database. Example: swh db-init -d swh-test storage If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables. See psql(1) man page (section ENVIRONMENTS) for details. Example: PGPORT=5434 swh db-init indexer """ logger.debug("db_init %s dn_name=%s", module, db_name) if create_db: from swh.core.db.tests.db_testing import pg_createdb # Create the db (or fail silently if already existing) pg_createdb(db_name, check=False) initialized, dbversion = populate_database_for_package(module, db_name) # TODO: Ideally migrate the version from db_version to the latest # db version click.secho( "DONE database for {} {} at version {}".format( module, "initialized" if initialized else "exists", dbversion ), fg="green", bold=True, ) def get_sql_for_package(modname): import glob from importlib import import_module from swh.core.utils import numfile_sortkey as sortkey if not modname.startswith("swh."): modname = "swh.{}".format(modname) try: m = import_module(modname) except ImportError: raise click.BadParameter("Unable to load module {}".format(modname)) sqldir = path.join(path.dirname(m.__file__), "sql") if not path.isdir(sqldir): raise click.BadParameter( "Module {} does not provide a db schema " "(no sql/ dir)".format(modname) ) return list(sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey)) def populate_database_for_package(modname: str, conninfo: str) -> Tuple[bool, int]: """Populate the database, pointed at with `conninfo`, using the SQL files found in the package `modname`. Args: modname: Name of the module of which we're loading the files conninfo: connection info string for the SQL database Returns: Tuple with two elements: whether the database has been initialized; the current version of the database. """ import subprocess - from swh.core.db.tests.db_testing import swh_db_version + from swh.core.db.db_utils import swh_db_version current_version = swh_db_version(conninfo) if current_version is not None: return False, current_version sqlfiles = get_sql_for_package(modname) for sqlfile in sqlfiles: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", sqlfile, ] ) current_version = swh_db_version(conninfo) + assert current_version is not None return True, current_version diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index f21989f..5b7e64d 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,153 +1,198 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import functools +import logging import re +from typing import Optional, Union +import psycopg2 import psycopg2.extensions +logger = logging.getLogger(__name__) + def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value +def swh_db_version( + db_or_conninfo: Union[str, psycopg2.extensions.connection] +) -> Optional[int]: + """Retrieve the swh version if any. In case of the db not initialized, + this returns None. Otherwise, this returns the db's version. + + Args: + db_or_conninfo: A database connection, or a database connection info string + + Returns: + Optional[Int]: Either the db's version or None + + """ + + if isinstance(db_or_conninfo, psycopg2.extensions.connection): + db = db_or_conninfo + else: + try: + if "=" not in db_or_conninfo: + # Database name + db_or_conninfo = f"dbname={db_or_conninfo}" + db = psycopg2.connect(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + # Database not initialized + return None + + try: + with db.cursor() as c: + query = "select version from dbversion order by dbversion desc limit 1" + try: + c.execute(query) + return c.fetchone()[0] + except psycopg2.errors.UndefinedTable: + return None + except Exception: + logger.exception("Could not get version from `%s`", db_or_conninfo) + return None + + # The following code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] tokens = re.split(br"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode(psycopg2.extensions.encodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur diff --git a/swh/core/db/tests/db_testing.py b/swh/core/db/tests/db_testing.py index 9f6c01b..8284175 100644 --- a/swh/core/db/tests/db_testing.py +++ b/swh/core/db/tests/db_testing.py @@ -1,342 +1,305 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import os import subprocess from typing import Dict, Iterable, Optional, Tuple, Union import psycopg2 from swh.core.utils import numfile_sortkey as sortkey DB_DUMP_TYPES = {".sql": "psql", ".dump": "pg_dump"} # type: Dict[str, str] -def swh_db_version(dbname_or_service): - """Retrieve the swh version if any. In case of the db not initialized, - this returns None. Otherwise, this returns the db's version. - - Args: - dbname_or_service (str): The db's name or service - - Returns: - Optional[Int]: Either the db's version or None - - """ - query = "select version from dbversion order by dbversion desc limit 1" - cmd = [ - "psql", - "--tuples-only", - "--no-psqlrc", - "--quiet", - "-v", - "ON_ERROR_STOP=1", - "--command=%s" % query, - dbname_or_service, - ] - - try: - r = subprocess.run( - cmd, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True, - ) - result = int(r.stdout.strip()) - except Exception: # db not initialized - result = None - return result - - def pg_restore(dbname, dumpfile, dumptype="pg_dump"): """ Args: dbname: name of the DB to restore into dumpfile: path of the dump file dumptype: one of 'pg_dump' (for binary dumps), 'psql' (for SQL dumps) """ assert dumptype in ["pg_dump", "psql"] if dumptype == "pg_dump": subprocess.check_call( [ "pg_restore", "--no-owner", "--no-privileges", "--dbname", dbname, dumpfile, ] ) elif dumptype == "psql": subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-f", dumpfile, dbname, ] ) def pg_dump(dbname, dumpfile): subprocess.check_call( ["pg_dump", "--no-owner", "--no-privileges", "-Fc", "-f", dumpfile, dbname] ) def pg_dropdb(dbname): subprocess.check_call(["dropdb", dbname]) def pg_createdb(dbname, check=True): """Create a db. If check is True and the db already exists, this will raise an exception (original behavior). If check is False and the db already exists, this will fail silently. If the db does not exist, the db will be created. """ subprocess.run(["createdb", dbname], check=check) def db_create(dbname, dumps=None): """create the test DB and load the test data dumps into it dumps is an iterable of couples (dump_file, dump_type). context: setUpClass """ try: pg_createdb(dbname) except subprocess.CalledProcessError: # try recovering once, in case pg_dropdb(dbname) # the db already existed pg_createdb(dbname) for dump, dtype in dumps: pg_restore(dbname, dump, dtype) return dbname def db_destroy(dbname): """destroy the test DB context: tearDownClass """ pg_dropdb(dbname) def db_connect(dbname): """connect to the test DB and open a cursor context: setUp """ conn = psycopg2.connect("dbname=" + dbname) return {"conn": conn, "cursor": conn.cursor()} def db_close(conn): """rollback current transaction and disconnect from the test DB context: tearDown """ if not conn.closed: conn.rollback() conn.close() class DbTestConn: def __init__(self, dbname): self.dbname = dbname def __enter__(self): self.db_setup = db_connect(self.dbname) self.conn = self.db_setup["conn"] self.cursor = self.db_setup["cursor"] return self def __exit__(self, *_): db_close(self.conn) class DbTestContext: def __init__(self, name="softwareheritage-test", dumps=None): self.dbname = name self.dumps = dumps def __enter__(self): db_create(dbname=self.dbname, dumps=self.dumps) return self def __exit__(self, *_): db_destroy(self.dbname) class DbTestFixture: """Mix this in a test subject class to get DB testing support. Use the class method add_db() to add a new database to be tested. Using this will create a DbTestConn entry in the `test_db` dictionary for all the tests, indexed by the name of the database. Example: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): cls.add_db('db_name', DUMP) super().setUpClass() def setUp(self): db = self.test_db['db_name'] print('conn: {}, cursor: {}'.format(db.conn, db.cursor)) To ensure test isolation, each test method of the test case class will execute in its own connection, cursor, and transaction. Note that if you want to define setup/teardown methods, you need to explicitly call super() to ensure that the fixture setup/teardown methods are invoked. Here is an example where all setup/teardown methods are defined in a test case: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): # your add_db() calls here super().setUpClass() # your class setup code here def setUp(self): super().setUp() # your instance setup code here def tearDown(self): # your instance teardown code here super().tearDown() @classmethod def tearDownClass(cls): # your class teardown code here super().tearDownClass() """ _DB_DUMP_LIST = {} # type: Dict[str, Iterable[Tuple[str, str]]] _DB_LIST = {} # type: Dict[str, DbTestContext] DB_TEST_FIXTURE_IMPORTED = True @classmethod def add_db(cls, name="softwareheritage-test", dumps=None): cls._DB_DUMP_LIST[name] = dumps @classmethod def setUpClass(cls): for name, dumps in cls._DB_DUMP_LIST.items(): cls._DB_LIST[name] = DbTestContext(name, dumps) cls._DB_LIST[name].__enter__() super().setUpClass() @classmethod def tearDownClass(cls): super().tearDownClass() for name, context in cls._DB_LIST.items(): context.__exit__() def setUp(self, *args, **kwargs): self.test_db = {} for name in self._DB_LIST.keys(): self.test_db[name] = DbTestConn(name) self.test_db[name].__enter__() super().setUp(*args, **kwargs) def tearDown(self): super().tearDown() for name in self._DB_LIST.keys(): self.test_db[name].__exit__() def reset_db_tables(self, name, excluded=None): db = self.test_db[name] conn = db.conn cursor = db.cursor cursor.execute( """SELECT table_name FROM information_schema.tables WHERE table_schema = %s""", ("public",), ) tables = set(table for (table,) in cursor.fetchall()) if excluded is not None: tables -= set(excluded) for table in tables: cursor.execute("truncate table %s cascade" % table) conn.commit() class SingleDbTestFixture(DbTestFixture): """Simplified fixture like DbTest but that can only handle a single DB. Gives access to shortcuts like self.cursor and self.conn. DO NOT use this with other fixtures that need to access databases, like StorageTestFixture. The class can override the following class attributes: TEST_DB_NAME: name of the DB used for testing TEST_DB_DUMP: DB dump to be restored before running test methods; can be set to None if no restore from dump is required. If the dump file name endswith" - '.sql' it will be loaded via psql, - '.dump' it will be loaded via pg_restore. Other file extensions will be ignored. Can be a string or a list of strings; each path will be expanded using glob pattern matching. The test case class will then have the following attributes, accessible via self: dbname: name of the test database conn: psycopg2 connection object cursor: open psycopg2 cursor to the DB """ TEST_DB_NAME = "softwareheritage-test" TEST_DB_DUMP = None # type: Optional[Union[str, Iterable[str]]] @classmethod def setUpClass(cls): cls.dbname = cls.TEST_DB_NAME # XXX to kill? dump_files = cls.TEST_DB_DUMP if dump_files is None: dump_files = [] elif isinstance(dump_files, str): dump_files = [dump_files] all_dump_files = [] for files in dump_files: all_dump_files.extend(sorted(glob.glob(files), key=sortkey)) all_dump_files = [ (x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) for x in all_dump_files ] cls.add_db(name=cls.TEST_DB_NAME, dumps=all_dump_files) super().setUpClass() def setUp(self, *args, **kwargs): super().setUp(*args, **kwargs) db = self.test_db[self.TEST_DB_NAME] self.conn = db.conn self.cursor = db.cursor