def apply_options(cursor, options):
    """Applies the given postgresql client options to the given cursor.

    For each option, the current value is read with ``SHOW``; when it differs
    from the requested value, the option is changed with ``SET LOCAL``.

    Args:
        cursor: DB cursor used to run the ``SHOW`` / ``SET LOCAL`` statements
        options (dict): mapping of option name to the value to apply

    Returns:
        dict: the previous value of every option that was actually changed,
        keyed by option name (empty when nothing changed).

    """
    old_options = {}
    for option, value in options.items():
        # The option name is interpolated into the statement text; only the
        # value is passed as a query parameter.
        cursor.execute('SHOW %s' % option)
        old_value = cursor.fetchall()[0][0]
        if old_value != value:
            cursor.execute('SET LOCAL %s TO %%s' % option, (value,))
            old_options[option] = old_value
    return old_options


def db_transaction(**client_options):
    """decorator to execute Backend methods within DB transactions

    The decorated method must accept a `cur` and `db` keyword argument

    Client options are passed as `set` options to the postgresql server

    When the caller provides a (truthy) `cur`, the method runs on that cursor
    and the options that were changed are set back to their previous values
    once the method has returned. Otherwise a transaction is opened on
    `self.get_db()` and the method receives that `db` and its cursor.

    Raises:
        ValueError: at decoration time, if the method is a generator function
    """
    def decorator(meth, __client_options=client_options):
        if inspect.isgeneratorfunction(meth):
            raise ValueError(
                'Use db_transaction_generator for generator functions.')

        @functools.wraps(meth)
        def _meth(self, *args, **kwargs):
            cur = kwargs.get('cur')
            if cur:
                old_options = apply_options(cur, __client_options)
                ret = meth(self, *args, **kwargs)
                apply_options(cur, old_options)
                return ret

            # `db` and `cur` are passed explicitly below: drop any value the
            # caller gave for them (typically a forwarded `cur=None`), which
            # would otherwise make the call fail with "got multiple values
            # for keyword argument".
            kwargs.pop('cur', None)
            kwargs.pop('db', None)
            db = self.get_db()
            with db.transaction() as cur:
                apply_options(cur, __client_options)
                return meth(self, *args, db=db, cur=cur, **kwargs)
        return _meth

    return decorator


def db_transaction_generator(**client_options):
    """decorator to execute Backend methods within DB transactions, while
    returning a generator

    The decorated method must accept a `cur` and `db` keyword argument

    Client options are passed as `set` options to the postgresql server

    Behaves like `db_transaction`, except that the results of the decorated
    generator are yielded one by one; the transaction (or the changed
    options) is only finalized once the generator is exhausted.

    Raises:
        ValueError: at decoration time, if the method is not a generator
            function
    """
    def decorator(meth, __client_options=client_options):
        if not inspect.isgeneratorfunction(meth):
            raise ValueError(
                'Use db_transaction for non-generator functions.')

        @functools.wraps(meth)
        def _meth(self, *args, **kwargs):
            cur = kwargs.get('cur')
            if cur:
                old_options = apply_options(cur, __client_options)
                yield from meth(self, *args, **kwargs)
                apply_options(cur, old_options)
            else:
                # Same as in db_transaction: avoid passing `db`/`cur` twice
                # when the caller forwarded an empty value for them.
                kwargs.pop('cur', None)
                kwargs.pop('db', None)
                db = self.get_db()
                with db.transaction() as cur:
                    apply_options(cur, __client_options)
                    yield from meth(self, *args, db=db, cur=cur, **kwargs)
        return _meth
    return decorator
class BaseDb:
    """Base class for swh.*.*Db.

    cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb

    """

    @classmethod
    def connect(cls, *args, **kwargs):
        """factory method to create a DB proxy

        Accepts all arguments of psycopg2.connect; only some specific
        possibilities are reported below.

        Args:
            connstring: libpq2 connection string

        """
        conn = psycopg2.connect(*args, **kwargs)
        return cls(conn)

    @classmethod
    def from_pool(cls, pool):
        """factory method to create a DB proxy from a connection pool

        The connection is obtained with `pool.getconn()`; it is handed back
        with `pool.putconn()` when the proxy is garbage collected.

        """
        return cls(pool.getconn(), pool=pool)

    def _cursor(self, cur_arg):
        """get a cursor: from cur_arg if given, or a fresh one otherwise

        meant to avoid boilerplate if/then/else in methods that proxy stored
        procedures

        """
        if cur_arg is not None:
            return cur_arg
        else:
            return self.conn.cursor()

    def __init__(self, conn, pool=None):
        """create a DB proxy

        Args:
            conn: psycopg2 connection to the SWH DB
            pool: psycopg2 pool of connections

        """
        self.conn = conn
        self.pool = pool

    def __del__(self):
        # __del__ also runs on instances whose __init__ never completed, in
        # which case the attributes do not exist: look them up defensively.
        pool = getattr(self, 'pool', None)
        if pool:
            pool.putconn(self.conn)

    @contextmanager
    def transaction(self):
        """context manager to execute within a DB transaction

        The transaction is committed when the block exits normally. If the
        block raises, the transaction is rolled back (unless the connection
        is already closed) and the exception is re-raised.

        Yields:
            a psycopg2 cursor

        """
        with self.conn.cursor() as cur:
            try:
                yield cur
                self.conn.commit()
            except Exception:
                if not self.conn.closed:
                    self.conn.rollback()
                raise

    def copy_to(self, items, tblname, columns, cur=None, item_cb=None):
        """Copy items' entries to table tblname with columns information.

        The rows are serialized as CSV and streamed through a pipe to a
        helper thread running ``COPY ... FROM STDIN CSV`` on the cursor.

        Args:
            items (iterable of dict): entries to copy over to tblname
            tblname (str): Destination table's name
            columns ([str]): keys to access data in items and also the
              column names in the destination table.
            cur: cursor to run the COPY on; a fresh cursor is used if None
            item_cb (fn): optional function to apply to items's entry

        Raises:
            Exception: whatever was raised while running the COPY in the
              helper thread, re-raised once that thread has finished.

        """
        def escape(data):
            if data is None:
                return ''
            if isinstance(data, bytes):
                return '\\x%s' % binascii.hexlify(data).decode('ascii')
            elif isinstance(data, str):
                return '"%s"' % data.replace('"', '""')
            elif isinstance(data, datetime.datetime):
                # We escape twice to make sure the string generated by
                # isoformat gets escaped
                return escape(data.isoformat())
            elif isinstance(data, dict):
                return escape(json.dumps(data))
            elif isinstance(data, list):
                return escape("{%s}" % ','.join(escape(d) for d in data))
            elif isinstance(data, psycopg2.extras.Range):
                # We escape twice here too, so that we make sure
                # everything gets passed to copy properly
                return escape(
                    '%s%s,%s%s' % (
                        '[' if data.lower_inc else '(',
                        '-infinity' if data.lower_inf else escape(data.lower),
                        'infinity' if data.upper_inf else escape(data.upper),
                        ']' if data.upper_inc else ')',
                    )
                )
            elif isinstance(data, enum.IntEnum):
                return escape(int(data))
            else:
                # We don't escape here to make sure we pass literals properly
                return str(data)

        read_file, write_file = os.pipe()
        # Filled by the helper thread: an exception raised in a thread is
        # not propagated to the thread that joins it.
        writer_errors = []

        def writer():
            # Wrapping the read end first guarantees it gets closed whatever
            # happens next, so the producer below fails with a broken pipe
            # instead of blocking on a pipe nobody reads.
            with open(read_file, 'r') as f:
                try:
                    cursor = self._cursor(cur)
                    cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % (
                        tblname, ', '.join(columns)), f)
                except Exception as exc:
                    writer_errors.append(exc)

        write_thread = threading.Thread(target=writer)
        write_thread.start()

        try:
            with open(write_file, 'w') as f:
                for d in items:
                    if item_cb is not None:
                        item_cb(d)
                    line = [escape(d.get(k)) for k in columns]
                    f.write(','.join(line))
                    f.write('\n')
        finally:
            # No problem bubbling up exceptions, but we still need to make sure
            # we finish copying, even though we're probably going to cancel the
            # transaction.
            write_thread.join()
            if writer_errors:
                # The COPY itself failed: report that error rather than the
                # broken pipe it may have caused on the producer side.
                raise writer_errors[0]

    def mktemp(self, tblname, cur=None):
        """Run the `swh_mktemp` stored procedure for tblname."""
        self._cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,))
def execute_values_generator(cur, sql, argslist, template=None, page_size=100):
    '''Execute a statement using SQL ``VALUES`` with a sequence of parameters.
    Rows returned by the query are returned through a generator.
    You need to consume the generator for the queries to be executed!

    :param cur: the cursor to use to execute the query.
    :param sql: the query to execute. It must contain a single ``%s``
        placeholder, which will be replaced by a `VALUES list`__.
        Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.
    :param argslist: sequence of sequences or dictionaries with the arguments
        to send to the query. The type and content must be consistent with
        *template*.
    :param template: the snippet to merge to every item in *argslist* to
        compose the query.

        - If the *argslist* items are sequences it should contain positional
          placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)"`` if
          there are constants value...).
        - If the *argslist* items are mappings it should contain named
          placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``).

        If not specified, assume the arguments are sequence and use a simple
        positional template (i.e. ``(%s, %s, ...)``), with the number of
        placeholders sniffed by the first element in *argslist*.
    :param page_size: maximum number of *argslist* items to include in every
        statement. If there are more items the function will execute more than
        one statement. Must be at least 1.
    :raises ValueError: if *page_size* is lower than 1 (raised when the
        generator is first consumed).

    .. __: https://www.postgresql.org/docs/current/static/queries-values.html

    After the execution of the function the `cursor.rowcount` property will
    **not** contain a total result.
    '''
    # A page must hold at least one item: the statement is built from the
    # items of each page, and the default template from the first of them.
    if page_size < 1:
        raise ValueError('page_size must be a positive integer')

    # we can't just use sql % vals because vals is bytes: if sql is bytes
    # there will be some decoding error because of stupid codec used, and Py3
    # doesn't implement % on bytes.
    if not isinstance(sql, bytes):
        sql = sql.encode(
            psycopg2.extensions.encodings[cur.connection.encoding]
        )
    pre, post = _split_sql(sql)

    for page in _paginate(argslist, page_size=page_size):
        if template is None:
            template = b'(' + b','.join([b'%s'] * len(page[0])) + b')'
        parts = pre[:]
        for args in page:
            parts.append(cur.mogrify(template, args))
            parts.append(b',')
        # swap the trailing comma for the snippets following the placeholder
        parts[-1:] = post
        cur.execute(b''.join(parts))
        yield from cur
def test_log(swh_db_logger, postgresql):
    """Check that records emitted through the logger land in the `log` table.

    Two records are sent (an info one carrying `swh_*` extra fields, and a
    bare warning), then the table is read back and both rows are looked up.
    """
    modname = swh_db_logger.name

    swh_db_logger.info(
        'notice', extra={'swh_type': 'test entry', 'swh_data': 42})
    swh_db_logger.warning('warning')

    with postgresql.cursor() as cur:
        cur.execute('SELECT level, message, data, src_module FROM log')
        rows = cur.fetchall()

    expected_rows = [
        ('info', 'notice', {'type': 'test entry', 'data': 42}, modname),
        ('warning', 'warning', {}, modname),
    ]
    for expected in expected_rows:
        assert expected in rows