Changeset View
Changeset View
Standalone View
Standalone View
swh/core/db/__init__.py
Show All 20 Lines | |||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
psycopg2.extras.register_uuid() | psycopg2.extras.register_uuid() | ||||
def escape(data):
    """Render a Python value as one field of a ``COPY ... CSV`` input line.

    Conversions, checked in this order:

    - ``None``: empty (unquoted) field;
    - ``bytes``: ``\\x`` followed by the hexadecimal dump of the bytes;
    - ``str``: wrapped in double quotes, embedded double quotes doubled;
    - ``datetime.datetime``: its ISO 8601 form, escaped as a string;
    - ``dict``: its JSON serialization, escaped as a string;
    - ``list``: ``{item,item,...}`` with every item escaped, then the
      whole literal escaped again as a string;
    - ``psycopg2.extras.Range``: bounds literal such as ``[lower,upper)``,
      escaped as a string;
    - ``enum.IntEnum``: its integer value;
    - anything else: ``str(data)``, left unquoted.

    Args:
        data: the value to serialize.

    Returns:
        str: the textual representation of ``data``.
    """
    if data is None:
        return ""

    if isinstance(data, bytes):
        return "\\x%s" % binascii.hexlify(data).decode("ascii")

    if isinstance(data, str):
        doubled_quotes = data.replace('"', '""')
        return '"%s"' % doubled_quotes

    if isinstance(data, datetime.datetime):
        # Second pass through escape() so that the isoformat string is
        # quoted like any other string.
        return escape(data.isoformat())

    if isinstance(data, dict):
        return escape(json.dumps(data))

    if isinstance(data, list):
        members = ",".join(escape(member) for member in data)
        return escape("{%s}" % members)

    if isinstance(data, psycopg2.extras.Range):
        opening = "[" if data.lower_inc else "("
        lower = "-infinity" if data.lower_inf else escape(data.lower)
        upper = "infinity" if data.upper_inf else escape(data.upper)
        closing = "]" if data.upper_inc else ")"
        # Escaped a second time as well, so the assembled range literal
        # reaches COPY as a single quoted field.
        return escape("%s%s,%s%s" % (opening, lower, upper, closing))

    if isinstance(data, enum.IntEnum):
        return escape(int(data))

    # Deliberately unquoted so that literals (numbers, booleans, ...) are
    # passed through as-is.
    return str(data)
Show All 10 Lines | class BaseDb: | ||||
cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb | cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb | ||||
""" | """ | ||||
@classmethod
def adapt_conn(cls, conn):
    """Make psycopg2 decode bytea values as 'bytes' rather than
    'memoryview' on the given connection.

    Two typecasters are registered on ``conn`` only: one for ``bytea``
    and one for ``bytea[]``, the latter built on top of the former.

    Args:
        conn: the psycopg2 connection to register the typecasters on.
    """
    # Type OIDs handed to psycopg2 for the "bytea" and "bytea[]" casters.
    bytea_oids = (17,)
    bytea_array_oids = (1001,)

    bytea_caster = psycopg2.extensions.new_type(bytea_oids, "bytea", typecast_bytea)
    psycopg2.extensions.register_type(bytea_caster, conn)

    bytea_array_caster = psycopg2.extensions.new_array_type(
        bytea_array_oids, "bytea[]", bytea_caster
    )
    psycopg2.extensions.register_type(bytea_array_caster, conn)
@classmethod | @classmethod | ||||
def connect(cls, *args, **kwargs): | def connect(cls, *args, **kwargs): | ||||
"""factory method to create a DB proxy | """factory method to create a DB proxy | ||||
Accepts all arguments of psycopg2.connect; only some specific | Accepts all arguments of psycopg2.connect; only some specific | ||||
possibilities are reported below. | possibilities are reported below. | ||||
Show All 32 Lines | def cursor(self, cur_arg=None): | ||||
meant to avoid boilerplate if/then/else in methods that proxy stored | meant to avoid boilerplate if/then/else in methods that proxy stored | ||||
procedures | procedures | ||||
""" | """ | ||||
if cur_arg is not None: | if cur_arg is not None: | ||||
return cur_arg | return cur_arg | ||||
else: | else: | ||||
return self.conn.cursor() | return self.conn.cursor() | ||||
_cursor = cursor # for bw compat | _cursor = cursor # for bw compat | ||||
@contextmanager
def transaction(self):
    """Context manager running its body inside a DB transaction.

    A fresh cursor is opened on ``self.conn`` and handed to the body.
    The transaction is committed when the body completes; if the body
    (or the commit itself) raises an ``Exception``, the transaction is
    rolled back, unless the connection is already closed, and the
    exception is re-raised.

    Yields:
        a psycopg2 cursor
    """
    with self.conn.cursor() as db_cursor:
        try:
            yield db_cursor
            self.conn.commit()
        except Exception:
            connection_usable = not self.conn.closed
            if connection_usable:
                self.conn.rollback()
            raise
def copy_to(
    self, items, tblname, columns, cur=None, item_cb=None, default_values=None
):
    """Copy items' entries to table tblname with columns information.

    Rows are serialized to CSV on the calling thread and streamed through
    an OS pipe to a helper thread, which feeds them to
    ``COPY <tblname> (<columns>) FROM STDIN CSV``.

    ``tblname`` and ``columns`` are interpolated verbatim into the SQL
    statement, so they must be trusted identifiers.

    Args:
        items (List[dict]): dictionaries of data to copy over tblname.
        tblname (str): destination table's name.
        columns ([str]): keys to access data in items and also the
            column names in the destination table.
        default_values (dict): dictionary of default values to use when
            inserting entries into the tblname table; a key missing from
            both the item and this dictionary is sent as ``None``.
            Defaults to no default values.
        cur: a db cursor; if not given, a new cursor will be created.
        item_cb (fn): optional function called on each item before it is
            serialized.

    Raises:
        Exception: any error raised on the database side (cursor creation
            or the COPY itself) is re-raised in the calling thread; errors
            raised while escaping a value are logged and propagated.
    """
    if default_values is None:
        default_values = {}

    read_file, write_file = os.pipe()
    exc_info = None

    def writer():
        nonlocal exc_info
        # The read end is opened before anything that can fail, so it is
        # always closed when this thread stops. Otherwise the producer
        # below could block forever on a full pipe.
        with open(read_file, "r") as f:
            try:
                cursor = self.cursor(cur)
                cursor.copy_expert(
                    "COPY %s (%s) FROM STDIN CSV" % (tblname, ", ".join(columns)), f
                )
            except Exception:
                # Tell the main thread about the exception
                exc_info = sys.exc_info()

    write_thread = threading.Thread(target=writer)
    write_thread.start()

    try:
        with open(write_file, "w") as f:
            for d in items:
                if item_cb is not None:
                    item_cb(d)
                line = []
                for k in columns:
                    value = d.get(k, default_values.get(k))
                    try:
                        line.append(escape(value))
                    except Exception as e:
                        logger.error(
                            "Could not escape value `%r` for column `%s`:"
                            "Received exception: `%s`",
                            value,
                            k,
                            e,
                        )
                        raise e from None
                f.write(",".join(line))
                f.write("\n")
    finally:
        # No problem bubbling up exceptions, but we still need to make sure
        # we finish copying, even though we're probably going to cancel the
        # transaction.
        write_thread.join()
        if exc_info:
            # The database side failed; its error takes precedence over
            # whatever happened on the producing side (typically a broken
            # pipe caused by that very failure).
            raise exc_info[1].with_traceback(exc_info[2])
def mktemp(self, tblname, cur=None):
    """Call the ``swh_mktemp`` SQL function with ``tblname`` as argument.

    Args:
        tblname (str): value passed to ``swh_mktemp``.
        cur: a db cursor; if not given, one is obtained from
            ``self.cursor``.
    """
    db_cursor = self.cursor(cur)
    db_cursor.execute("SELECT swh_mktemp(%s)", (tblname,))