# Copyright (C) 2015-2019  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import binascii
import datetime
import enum
import json
import logging
import os
import sys
import threading

from contextlib import contextmanager

import psycopg2
import psycopg2.extras


logger = logging.getLogger(__name__)


psycopg2.extras.register_uuid()

# OIDs of PostgreSQL's built-in ``bytea`` and ``bytea[]`` types. Using the
# fixed values lets adapt_conn() register its typecasters without issuing a
# query on the connection.
_BYTEA_OID = 17
_BYTEA_ARRAY_OID = 1001


def escape(data):
    """Serialize a Python value as one field of a ``COPY ... CSV`` line.

    Args:
        data: the value to serialize. ``None``, ``bytes``, ``str``,
            ``datetime.datetime``, ``dict``, ``list``,
            ``psycopg2.extras.Range`` and ``enum.IntEnum`` get dedicated
            handling; anything else is passed through ``str()``.

    Returns:
        str: the textual representation of ``data``; the empty string
        stands for ``None``.
    """
    if data is None:
        return ''
    if isinstance(data, bytes):
        return '\\x%s' % binascii.hexlify(data).decode('ascii')
    elif isinstance(data, str):
        # CSV quoting: wrap in double quotes, double any embedded quote.
        return '"%s"' % data.replace('"', '""')
    elif isinstance(data, datetime.datetime):
        # We escape twice to make sure the string generated by
        # isoformat gets escaped
        return escape(data.isoformat())
    elif isinstance(data, dict):
        return escape(json.dumps(data))
    elif isinstance(data, list):
        return escape("{%s}" % ','.join(escape(d) for d in data))
    elif isinstance(data, psycopg2.extras.Range):
        if data.isempty:
            # An empty range reports no bounds at all; without this branch
            # it would be rendered as "(,)", which PostgreSQL reads as an
            # unbounded range rather than an empty one.
            return escape('empty')
        # We escape twice here too, so that we make sure
        # everything gets passed to copy properly
        return escape(
            '%s%s,%s%s' % (
                '[' if data.lower_inc else '(',
                '-infinity' if data.lower_inf else escape(data.lower),
                'infinity' if data.upper_inf else escape(data.upper),
                ']' if data.upper_inc else ')',
            )
        )
    elif isinstance(data, enum.IntEnum):
        return escape(int(data))
    else:
        # We don't escape here to make sure we pass literals properly
        return str(data)


def typecast_bytea(value, cur):
    """psycopg2 typecaster returning ``bytes`` for bytea values.

    Args:
        value: the raw value handed over by psycopg2, or None for NULL.
        cur: the cursor the value was read from.

    Returns:
        bytes or None: the decoded value, None when ``value`` is None.
    """
    if value is None:
        return None
    return psycopg2.BINARY(value, cur).tobytes()


class BaseDb:
    """Base class for swh.*.*Db.

    cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb

    """

    @classmethod
    def adapt_conn(cls, conn):
        """Makes psycopg2 use 'bytes' to decode bytea instead of
        'memoryview', for this connection.

        Args:
            conn: the psycopg2 connection to register the typecasters on.
        """
        t_bytes = psycopg2.extensions.new_type(
            (_BYTEA_OID,), "bytea", typecast_bytea)
        psycopg2.extensions.register_type(t_bytes, conn)

        t_bytes_array = psycopg2.extensions.new_array_type(
            (_BYTEA_ARRAY_OID,), "bytea[]", t_bytes)
        psycopg2.extensions.register_type(t_bytes_array, conn)

    @classmethod
    def connect(cls, *args, **kwargs):
        """factory method to create a DB proxy

        Accepts all arguments of psycopg2.connect; only some specific
        possibilities are reported below.

        Args:
            connstring: libpq2 connection string

        Returns:
            an instance of ``cls`` wrapping the new connection.
        """
        conn = psycopg2.connect(*args, **kwargs)
        return cls(conn)

    @classmethod
    def from_pool(cls, pool):
        """factory method to create a DB proxy from a connection pool

        Args:
            pool: psycopg2 pool of connections; one connection is taken
                from it with ``getconn()``.

        Returns:
            an instance of ``cls`` that remembers ``pool`` so that
            put_conn() can hand the connection back.
        """
        conn = pool.getconn()
        return cls(conn, pool=pool)

    def __init__(self, conn, pool=None):
        """create a DB proxy

        Args:
            conn: psycopg2 connection to the SWH DB
            pool: psycopg2 pool of connections

        """
        self.adapt_conn(conn)
        self.conn = conn
        self.pool = pool

    def put_conn(self):
        """Give the connection back to the pool, if there is one."""
        if self.pool:
            self.pool.putconn(self.conn)

    def cursor(self, cur_arg=None):
        """get a cursor: from cur_arg if given, or a fresh one otherwise

        meant to avoid boilerplate if/then/else in methods that proxy stored
        procedures

        """
        if cur_arg is not None:
            return cur_arg
        else:
            return self.conn.cursor()
    _cursor = cursor  # for bw compat

    @contextmanager
    def transaction(self):
        """context manager to execute within a DB transaction

        Commits when the managed block completes, rolls back (if the
        connection is still open) and re-raises when it raises.

        Yields:
            a psycopg2 cursor

        """
        with self.conn.cursor() as cur:
            try:
                yield cur
                self.conn.commit()
            except Exception:
                if not self.conn.closed:
                    self.conn.rollback()
                raise

    def copy_to(self, items, tblname, columns,
                cur=None, item_cb=None, default_values=None):
        """Copy items' entries to table tblname with columns information.

        The rows are serialized as CSV into a pipe by the calling thread
        while a helper thread feeds the other end of the pipe to
        ``COPY ... FROM STDIN``.

        Args:
            items (List[dict]): dictionaries of data to copy over tblname.
            tblname (str): destination table's name. It is interpolated
                verbatim into the COPY statement, so it must be trusted.
            columns ([str]): keys to access data in items and also the
                column names in the destination table. Interpolated
                verbatim as well.
            default_values (dict): dictionary of default values to use when
                inserting entries into the tblname table.
            cur: a db cursor; if not given, a new cursor will be created.
            item_cb (fn): optional function to apply to items's entry.

        Raises:
            Exception: any error raised while running the COPY statement is
                re-raised in the calling thread; errors raised while
                serializing ``items`` propagate as well.

        """
        if default_values is None:
            default_values = {}

        read_file, write_file = os.pipe()
        exc_info = None

        def run_copy():
            nonlocal exc_info
            # The read end is closed when this block exits, whatever
            # happens, so the producer below can never block forever on a
            # pipe nobody reads.
            with open(read_file, 'r') as f:
                try:
                    cursor = self.cursor(cur)
                    cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % (
                        tblname, ', '.join(columns)), f)
                except Exception:
                    # Tell the main thread about the exception
                    exc_info = sys.exc_info()

        copy_thread = threading.Thread(target=run_copy)
        copy_thread.start()

        try:
            with open(write_file, 'w') as f:
                for d in items:
                    if item_cb is not None:
                        item_cb(d)
                    line = []
                    for k in columns:
                        # Reset for each column so the error message below
                        # never reports an unbound or stale value.
                        value = None
                        try:
                            value = d.get(k, default_values.get(k))
                            line.append(escape(value))
                        except Exception as e:
                            logger.error(
                                'Could not escape value `%r` for column `%s`: '
                                'Received exception: `%s`',
                                value, k, e
                            )
                            raise e from None
                    f.write(','.join(line))
                    f.write('\n')
        finally:
            # No problem bubbling up exceptions, but we still need to make sure
            # we finish copying, even though we're probably going to cancel the
            # transaction.
            copy_thread.join()
            if exc_info:
                # postgresql returned an error, let's raise it. This is done
                # here rather than after the try block so that it takes
                # precedence over the BrokenPipeError the producer gets once
                # the failed COPY has closed its end of the pipe.
                raise exc_info[1].with_traceback(exc_info[2])

    def mktemp(self, tblname, cur=None):
        """Run the ``swh_mktemp`` SQL function for ``tblname``.

        Args:
            tblname (str): passed as the argument of ``swh_mktemp``.
            cur: a db cursor; if not given, a new cursor will be created.
        """
        self.cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,))