diff --git a/swh/storage/common.py b/swh/storage/common.py --- a/swh/storage/common.py +++ b/swh/storage/common.py @@ -3,78 +3,4 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import inspect -import functools - - -def apply_options(cursor, options): - """Applies the given postgresql client options to the given cursor. - - Returns a dictionary with the old values if they changed.""" - old_options = {} - for option, value in options.items(): - cursor.execute('SHOW %s' % option) - old_value = cursor.fetchall()[0][0] - if old_value != value: - cursor.execute('SET LOCAL %s TO %%s' % option, (value,)) - old_options[option] = old_value - return old_options - - -def db_transaction(**client_options): - """decorator to execute Storage methods within DB transactions - - The decorated method must accept a `cur` and `db` keyword argument - - Client options are passed as `set` options to the postgresql server - """ - def decorator(meth, __client_options=client_options): - if inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction_generator for generator functions.') - - @functools.wraps(meth) - def _meth(self, *args, **kwargs): - if 'cur' in kwargs and kwargs['cur']: - cur = kwargs['cur'] - old_options = apply_options(cur, __client_options) - ret = meth(self, *args, **kwargs) - apply_options(cur, old_options) - return ret - else: - db = self.get_db() - with db.transaction() as cur: - apply_options(cur, __client_options) - return meth(self, *args, db=db, cur=cur, **kwargs) - return _meth - - return decorator - - -def db_transaction_generator(**client_options): - """decorator to execute Storage methods within DB transactions, while - returning a generator - - The decorated method must accept a `cur` and `db` keyword argument - - Client options are passed as `set` options to the postgresql server - """ - def decorator(meth, __client_options=client_options): - if not inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction for non-generator functions.') - - @functools.wraps(meth) - def _meth(self, *args, **kwargs): - if 'cur' in kwargs and kwargs['cur']: - cur = kwargs['cur'] - old_options = apply_options(cur, __client_options) - yield from meth(self, *args, **kwargs) - apply_options(cur, old_options) - else: - db = self.get_db() - with db.transaction() as cur: - apply_options(cur, __client_options) - yield from meth(self, *args, db=db, cur=cur, **kwargs) - return _meth - return decorator +from swh.core.db.common import * # noqa diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -1,224 +1,13 @@ -# Copyright (C) 2015-2018 The Software Heritage developers +# Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import binascii -import datetime -import enum -import functools -import json -import os import select -import threading -from contextlib import contextmanager - -import psycopg2 -import psycopg2.extras - -from .db_utils import execute_values_generator - -TMP_CONTENT_TABLE = 'tmp_content' - - -psycopg2.extras.register_uuid() - - -def stored_procedure(stored_proc): - """decorator to execute remote stored procedure, specified as argument - - Generally, the body of the decorated function should be empty. If it is - not, the stored procedure will be executed first; the function body then. - - """ - def wrap(meth): - @functools.wraps(meth) - def _meth(self, *args, **kwargs): - cur = kwargs.get('cur', None) - self._cursor(cur).execute('SELECT %s()' % stored_proc) - meth(self, *args, **kwargs) - return _meth - return wrap - - -def jsonize(value): - """Convert a value to a psycopg2 JSON object if necessary""" - if isinstance(value, dict): - return psycopg2.extras.Json(value) - - return value - - -def entry_to_bytes(entry): - """Convert an entry coming from the database to bytes""" - if isinstance(entry, memoryview): - return entry.tobytes() - if isinstance(entry, list): - return [entry_to_bytes(value) for value in entry] - return entry - - -def line_to_bytes(line): - """Convert a line coming from the database to bytes""" - if not line: - return line - if isinstance(line, dict): - return {k: entry_to_bytes(v) for k, v in line.items()} - return line.__class__(entry_to_bytes(entry) for entry in line) - - -def cursor_to_bytes(cursor): - """Yield all the data from a cursor as bytes""" - yield from (line_to_bytes(line) for line in cursor) - - -def execute_values_to_bytes(*args, **kwargs): - for line in execute_values_generator(*args, **kwargs): - yield line_to_bytes(line) - - -class BaseDb: - """Base class for swh.storage.*Db. - - cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb - - """ - - @classmethod - def connect(cls, *args, **kwargs): - """factory method to create a DB proxy - - Accepts all arguments of psycopg2.connect; only some specific - possibilities are reported below. - - Args: - connstring: libpq2 connection string - - """ - conn = psycopg2.connect(*args, **kwargs) - return cls(conn) - - @classmethod - def from_pool(cls, pool): - return cls(pool.getconn(), pool=pool) - - def _cursor(self, cur_arg): - """get a cursor: from cur_arg if given, or a fresh one otherwise - - meant to avoid boilerplate if/then/else in methods that proxy stored - procedures - - """ - if cur_arg is not None: - return cur_arg - # elif self.cur is not None: - # return self.cur - else: - return self.conn.cursor() - - def __init__(self, conn, pool=None): - """create a DB proxy - - Args: - conn: psycopg2 connection to the SWH DB - pool: psycopg2 pool of connections - - """ - self.conn = conn - self.pool = pool - - def __del__(self): - if self.pool: - self.pool.putconn(self.conn) - - @contextmanager - def transaction(self): - """context manager to execute within a DB transaction - - Yields: - a psycopg2 cursor - - """ - with self.conn.cursor() as cur: - try: - yield cur - self.conn.commit() - except Exception: - if not self.conn.closed: - self.conn.rollback() - raise - - def copy_to(self, items, tblname, columns, cur=None, item_cb=None): - """Copy items' entries to table tblname with columns information. - - Args: - items (dict): dictionary of data to copy over tblname - tblname (str): Destination table's name - columns ([str]): keys to access data in items and also the - column names in the destination table. - item_cb (fn): optional function to apply to items's entry - - """ - def escape(data): - if data is None: - return '' - if isinstance(data, bytes): - return '\\x%s' % binascii.hexlify(data).decode('ascii') - elif isinstance(data, str): - return '"%s"' % data.replace('"', '""') - elif isinstance(data, datetime.datetime): - # We escape twice to make sure the string generated by - # isoformat gets escaped - return escape(data.isoformat()) - elif isinstance(data, dict): - return escape(json.dumps(data)) - elif isinstance(data, list): - return escape("{%s}" % ','.join(escape(d) for d in data)) - elif isinstance(data, psycopg2.extras.Range): - # We escape twice here too, so that we make sure - # everything gets passed to copy properly - return escape( - '%s%s,%s%s' % ( - '[' if data.lower_inc else '(', - '-infinity' if data.lower_inf else escape(data.lower), - 'infinity' if data.upper_inf else escape(data.upper), - ']' if data.upper_inc else ')', - ) - ) - elif isinstance(data, enum.IntEnum): - return escape(int(data)) - else: - # We don't escape here to make sure we pass literals properly - return str(data) - - read_file, write_file = os.pipe() - - def writer(): - cursor = self._cursor(cur) - with open(read_file, 'r') as f: - cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % ( - tblname, ', '.join(columns)), f) - - write_thread = threading.Thread(target=writer) - write_thread.start() - - try: - with open(write_file, 'w') as f: - for d in items: - if item_cb is not None: - item_cb(d) - line = [escape(d.get(k)) for k in columns] - f.write(','.join(line)) - f.write('\n') - finally: - # No problem bubbling up exceptions, but we still need to make sure - # we finish copying, even though we're probably going to cancel the - # transaction. - write_thread.join() - - def mktemp(self, tblname, cur=None): - self._cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,)) +from swh.core.db.db import * # noqa, bw compat +from swh.core.db.db import BaseDb, stored_procedure, execute_values_to_bytes +from swh.core.db.db import cursor_to_bytes, line_to_bytes, jsonize class Db(BaseDb): diff --git a/swh/storage/db_utils.py b/swh/storage/db_utils.py deleted file mode 100644 --- a/swh/storage/db_utils.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (C) 2015-2018 The Software Heritage developers -# See the AUTHORS file at the top-level directory of this distribution -# -# This code has been imported from psycopg2, version 2.7.4, -# https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd -# and modified by Software Heritage. -# -# Original file: lib/extras.py -# -# psycopg2 is free software: you can redistribute it and/or modify it under the -# terms of the GNU Lesser General Public License as published by the Free -# Software Foundation, either version 3 of the License, or (at your option) any -# later version. - - -import re - -import psycopg2.extensions - - -def _paginate(seq, page_size): - """Consume an iterable and return it in chunks. - Every chunk is at most `page_size`. Never return an empty chunk. - """ - page = [] - it = iter(seq) - while 1: - try: - for i in range(page_size): - page.append(next(it)) - yield page - page = [] - except StopIteration: - if page: - yield page - return - - -def _split_sql(sql): - """Split *sql* on a single ``%s`` placeholder. - Split on the %s, perform %% replacement and return pre, post lists of - snippets. - """ - curr = pre = [] - post = [] - tokens = re.split(br'(%.)', sql) - for token in tokens: - if len(token) != 2 or token[:1] != b'%': - curr.append(token) - continue - - if token[1:] == b's': - if curr is pre: - curr = post - else: - raise ValueError( - "the query contains more than one '%s' placeholder") - elif token[1:] == b'%': - curr.append(b'%') - else: - raise ValueError("unsupported format character: '%s'" - % token[1:].decode('ascii', 'replace')) - - if curr is pre: - raise ValueError("the query doesn't contain any '%s' placeholder") - - return pre, post - - -def execute_values_generator(cur, sql, argslist, template=None, page_size=100): - '''Execute a statement using SQL ``VALUES`` with a sequence of parameters. - Rows returned by the query are returned through a generator. - You need to consume the generator for the queries to be executed! - - :param cur: the cursor to use to execute the query. - :param sql: the query to execute. It must contain a single ``%s`` - placeholder, which will be replaced by a `VALUES list`__. - Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. - :param argslist: sequence of sequences or dictionaries with the arguments - to send to the query. The type and content must be consistent with - *template*. - :param template: the snippet to merge to every item in *argslist* to - compose the query. - - - If the *argslist* items are sequences it should contain positional - placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there - are constants value...). - - If the *argslist* items are mappings it should contain named - placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). - - If not specified, assume the arguments are sequence and use a simple - positional template (i.e. ``(%s, %s, ...)``), with the number of - placeholders sniffed by the first element in *argslist*. - :param page_size: maximum number of *argslist* items to include in every - statement. If there are more items the function will execute more than - one statement. - :param yield_from_cur: Whether to yield results from the cursor in this - function directly. - - .. __: https://www.postgresql.org/docs/current/static/queries-values.html - - After the execution of the function the `cursor.rowcount` property will - **not** contain a total result. - ''' - # we can't just use sql % vals because vals is bytes: if sql is bytes - # there will be some decoding error because of stupid codec used, and Py3 - # doesn't implement % on bytes. - if not isinstance(sql, bytes): - sql = sql.encode( - psycopg2.extensions.encodings[cur.connection.encoding] - ) - pre, post = _split_sql(sql) - - for page in _paginate(argslist, page_size=page_size): - if template is None: - template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' - parts = pre[:] - for args in page: - parts.append(cur.mogrify(template, args)) - parts.append(b',') - parts[-1:] = post - cur.execute(b''.join(parts)) - yield from cur