diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py index ad64b8b..cab7ddb 100644 --- a/swh/core/db/__init__.py +++ b/swh/core/db/__init__.py @@ -1,167 +1,193 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import binascii import datetime import enum import json import os import threading from contextlib import contextmanager import psycopg2 import psycopg2.extras psycopg2.extras.register_uuid() def escape(data): if data is None: return '' if isinstance(data, bytes): return '\\x%s' % binascii.hexlify(data).decode('ascii') elif isinstance(data, str): return '"%s"' % data.replace('"', '""') elif isinstance(data, datetime.datetime): # We escape twice to make sure the string generated by # isoformat gets escaped return escape(data.isoformat()) elif isinstance(data, dict): return escape(json.dumps(data)) elif isinstance(data, list): return escape("{%s}" % ','.join(escape(d) for d in data)) elif isinstance(data, psycopg2.extras.Range): # We escape twice here too, so that we make sure # everything gets passed to copy properly return escape( '%s%s,%s%s' % ( '[' if data.lower_inc else '(', '-infinity' if data.lower_inf else escape(data.lower), 'infinity' if data.upper_inf else escape(data.upper), ']' if data.upper_inc else ')', ) ) elif isinstance(data, enum.IntEnum): return escape(int(data)) else: # We don't escape here to make sure we pass literals properly return str(data) +def typecast_bytea(value, cur): + if value is not None: + data = psycopg2.BINARY(value, cur) + return data.tobytes() + + class BaseDb: """Base class for swh.*.*Db. cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb """ + @classmethod + def adapt_conn(cls, conn): + """Makes psycopg2 use 'bytes' to decode bytea instead of + 'memoryview', for this connection.""" + cur = conn.cursor() + cur.execute("SELECT null::bytea, null::bytea[]") + bytea_oid = cur.description[0][1] + bytea_array_oid = cur.description[1][1] + + t_bytes = psycopg2.extensions.new_type( + (bytea_oid,), "bytea", typecast_bytea) + psycopg2.extensions.register_type(t_bytes, conn) + + t_bytes_array = psycopg2.extensions.new_array_type( + (bytea_array_oid,), "bytea[]", t_bytes) + psycopg2.extensions.register_type(t_bytes_array, conn) + @classmethod def connect(cls, *args, **kwargs): """factory method to create a DB proxy Accepts all arguments of psycopg2.connect; only some specific possibilities are reported below. Args: connstring: libpq2 connection string """ conn = psycopg2.connect(*args, **kwargs) + cls.adapt_conn(conn) return cls(conn) @classmethod def from_pool(cls, pool): - return cls(pool.getconn(), pool=pool) + conn = pool.getconn() + cls.adapt_conn(conn) + return cls(conn, pool=pool) def __init__(self, conn, pool=None): """create a DB proxy Args: conn: psycopg2 connection to the SWH DB pool: psycopg2 pool of connections """ self.conn = conn self.pool = pool def __del__(self): if self.pool: self.pool.putconn(self.conn) def cursor(self, cur_arg=None): """get a cursor: from cur_arg if given, or a fresh one otherwise meant to avoid boilerplate if/then/else in methods that proxy stored procedures """ if cur_arg is not None: return cur_arg else: return self.conn.cursor() _cursor = cursor # for bw compat @contextmanager def transaction(self): """context manager to execute within a DB transaction Yields: a psycopg2 cursor """ with self.conn.cursor() as cur: try: yield cur self.conn.commit() except Exception: if not self.conn.closed: self.conn.rollback() raise def copy_to(self, items, tblname, columns, cur=None, item_cb=None, default_values={}): """Copy items' entries to table tblname with columns information. Args: items (dict): dictionary of data to copy over tblname. tblname (str): destination table's name. columns ([str]): keys to access data in items and also the column names in the destination table. default_values (dict): dictionnary of default values to use when inserting entried int the tblname table. cur: a db cursor; if not given, a new cursor will be created. item_cb (fn): optional function to apply to items's entry. """ read_file, write_file = os.pipe() def writer(): cursor = self.cursor(cur) with open(read_file, 'r') as f: cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % ( tblname, ', '.join(columns)), f) write_thread = threading.Thread(target=writer) write_thread.start() try: with open(write_file, 'w') as f: for d in items: if item_cb is not None: item_cb(d) line = [escape(d.get(k, default_values.get(k))) for k in columns] f.write(','.join(line)) f.write('\n') finally: # No problem bubbling up exceptions, but we still need to make sure # we finish copying, even though we're probably going to cancel the # transaction. write_thread.join() def mktemp(self, tblname, cur=None): self.cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,)) diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index 41fbdd7..451fb58 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,177 +1,149 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # # This code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. import re import functools import psycopg2.extensions def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get('cur', None) self._cursor(cur).execute('SELECT %s()' % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value -def entry_to_bytes(entry): - """Convert an entry coming from the database to bytes""" - if isinstance(entry, memoryview): - return entry.tobytes() - if isinstance(entry, list): - return [entry_to_bytes(value) for value in entry] - return entry - - -def line_to_bytes(line): - """Convert a line coming from the database to bytes""" - if not line: - return line - if isinstance(line, dict): - return {k: entry_to_bytes(v) for k, v in line.items()} - return line.__class__(entry_to_bytes(entry) for entry in line) - - -def cursor_to_bytes(cursor): - """Yield all the data from a cursor as bytes""" - yield from (line_to_bytes(line) for line in cursor) - - -def execute_values_to_bytes(*args, **kwargs): - for line in execute_values_generator(*args, **kwargs): - yield line_to_bytes(line) - - def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] tokens = re.split(br'(%.)', sql) for token in tokens: if len(token) != 2 or token[:1] != b'%': curr.append(token) continue if token[1:] == b's': if curr is pre: curr = post else: raise ValueError( "the query contains more than one '%s' placeholder") elif token[1:] == b'%': curr.append(b'%') else: raise ValueError("unsupported format character: '%s'" % token[1:].decode('ascii', 'replace')) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): '''Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. ''' # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode( psycopg2.extensions.encodings[cur.connection.encoding] ) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b',') parts[-1:] = post cur.execute(b''.join(parts)) yield from cur