Page MenuHomeSoftware Heritage

No OneTemporary

diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py
index 2192c7d..9f46544 100644
--- a/swh/core/db/__init__.py
+++ b/swh/core/db/__init__.py
@@ -1,207 +1,217 @@
# Copyright (C) 2015-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import binascii
import datetime
import enum
import json
import logging
import os
+import sys
import threading
from contextlib import contextmanager
import psycopg2
import psycopg2.extras
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Register psycopg2's UUID support at import time (called without a
# connection/cursor argument, so presumably registered globally -- see the
# psycopg2.extras.register_uuid documentation).
psycopg2.extras.register_uuid()
def escape(data):
    """Serialize a Python value as one field of a ``COPY ... FROM STDIN CSV`` row.

    Args:
        data: the value to serialize. Specifically handled types are
            ``None``, ``bytes``, ``str``, ``datetime.datetime``, ``dict``,
            ``list``, ``psycopg2.extras.Range`` and ``enum.IntEnum``; any
            other value is rendered with ``str()``.

    Returns:
        str: the textual representation of the field. ``None`` becomes the
        empty (unquoted) string; strings are double-quoted with embedded
        double quotes doubled; bytes use the ``\\x<hex>`` form.
    """
    if data is None:
        return ''
    if isinstance(data, bytes):
        return '\\x%s' % binascii.hexlify(data).decode('ascii')
    elif isinstance(data, str):
        return '"%s"' % data.replace('"', '""')
    elif isinstance(data, datetime.datetime):
        # We escape twice to make sure the string generated by
        # isoformat gets escaped
        return escape(data.isoformat())
    elif isinstance(data, dict):
        return escape(json.dumps(data))
    elif isinstance(data, list):
        return escape("{%s}" % ','.join(escape(d) for d in data))
    elif isinstance(data, psycopg2.extras.Range):
        if data.isempty:
            # An empty range has no bounds at all; without this special
            # case it would be rendered as '(,)', which PostgreSQL reads
            # as an unbounded range rather than an empty one.
            return escape('empty')
        # We escape twice here too, so that we make sure
        # everything gets passed to copy properly
        return escape(
            '%s%s,%s%s' % (
                '[' if data.lower_inc else '(',
                '-infinity' if data.lower_inf else escape(data.lower),
                'infinity' if data.upper_inf else escape(data.upper),
                ']' if data.upper_inc else ')',
            )
        )
    elif isinstance(data, enum.IntEnum):
        # Render the numeric value, not the enum member's name
        return escape(int(data))
    else:
        # We don't escape here to make sure we pass literals properly
        return str(data)
def typecast_bytea(value, cur):
    """psycopg2 typecaster returning ``bytes`` for a bytea value.

    Args:
        value: the raw value handed over by psycopg2, or ``None``.
        cur: the cursor the value was read from.

    Returns:
        bytes: the decoded value, or ``None`` when ``value`` is ``None``.
    """
    if value is None:
        return None
    # psycopg2.BINARY yields an object exposing tobytes(); convert it so
    # callers get plain bytes.
    data = psycopg2.BINARY(value, cur)
    return data.tobytes()
class BaseDb:
    """Base class for swh.*.*Db.

    cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb

    """

    @classmethod
    def adapt_conn(cls, conn):
        """Makes psycopg2 use 'bytes' to decode bytea instead of
        'memoryview', for this connection."""
        cur = conn.cursor()
        # Look up the OIDs of bytea and bytea[] from the result description
        cur.execute("SELECT null::bytea, null::bytea[]")
        bytea_oid = cur.description[0][1]
        bytea_array_oid = cur.description[1][1]

        t_bytes = psycopg2.extensions.new_type(
            (bytea_oid,), "bytea", typecast_bytea)
        psycopg2.extensions.register_type(t_bytes, conn)

        t_bytes_array = psycopg2.extensions.new_array_type(
            (bytea_array_oid,), "bytea[]", t_bytes)
        psycopg2.extensions.register_type(t_bytes_array, conn)

    @classmethod
    def connect(cls, *args, **kwargs):
        """factory method to create a DB proxy

        Accepts all arguments of psycopg2.connect; only some specific
        possibilities are reported below.

        Args:
            connstring: libpq connection string

        """
        conn = psycopg2.connect(*args, **kwargs)
        return cls(conn)

    @classmethod
    def from_pool(cls, pool):
        """factory method to create a DB proxy from a connection pool

        Args:
            pool: psycopg2 pool of connections; one connection is taken
                from it with ``getconn()``.

        """
        conn = pool.getconn()
        return cls(conn, pool=pool)

    def __init__(self, conn, pool=None):
        """create a DB proxy

        Args:
            conn: psycopg2 connection to the SWH DB
            pool: psycopg2 pool of connections

        """
        self.adapt_conn(conn)
        self.conn = conn
        self.pool = pool

    def put_conn(self):
        """give the connection back to the pool, if there is one"""
        if self.pool:
            self.pool.putconn(self.conn)

    def cursor(self, cur_arg=None):
        """get a cursor: from cur_arg if given, or a fresh one otherwise

        meant to avoid boilerplate if/then/else in methods that proxy stored
        procedures

        """
        if cur_arg is not None:
            return cur_arg
        else:
            return self.conn.cursor()

    _cursor = cursor  # for bw compat

    @contextmanager
    def transaction(self):
        """context manager to execute within a DB transaction

        Commits when the block exits normally; on exception, rolls back
        (unless the connection is already closed) and re-raises.

        Yields:
            a psycopg2 cursor

        """
        with self.conn.cursor() as cur:
            try:
                yield cur
                self.conn.commit()
            except Exception:
                if not self.conn.closed:
                    self.conn.rollback()
                raise

    def copy_to(self, items, tblname, columns,
                cur=None, item_cb=None, default_values=None):
        """Copy items' entries to table tblname with columns information.

        The rows are serialized as CSV in the calling thread and streamed
        through a pipe to a helper thread running ``COPY ... FROM STDIN``.

        Args:
            items (List[dict]): dictionaries of data to copy over tblname.
            tblname (str): destination table's name.
            columns ([str]): keys to access data in items and also the
              column names in the destination table.
            default_values (dict): dictionary of default values to use when
              inserting entries into the tblname table.
            cur: a db cursor; if not given, a new cursor will be created.
            item_cb (fn): optional function to apply to items's entry.

        Raises:
            Exception: whatever ``copy_expert`` raised in the helper thread,
              or, if the copy itself succeeded, whatever was raised while
              serializing the items.

        """
        if default_values is None:
            # Avoid sharing one mutable default dict between calls
            default_values = {}

        read_file, write_file = os.pipe()
        exc_info = None

        def writer():
            nonlocal exc_info
            cursor = self.cursor(cur)
            with open(read_file, 'r') as f:
                try:
                    cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % (
                        tblname, ', '.join(columns)), f)
                except Exception:
                    # Tell the main thread about the exception
                    exc_info = sys.exc_info()

        write_thread = threading.Thread(target=writer)
        write_thread.start()

        try:
            with open(write_file, 'w') as f:
                for d in items:
                    if item_cb is not None:
                        item_cb(d)
                    line = []
                    for k in columns:
                        try:
                            value = d.get(k, default_values.get(k))
                            line.append(escape(value))
                        except Exception as e:
                            logger.error(
                                'Could not escape value `%r` for column `%s`:'
                                'Received exception: `%s`',
                                value, k, e
                            )
                            raise e from None
                    f.write(','.join(line))
                    f.write('\n')
        finally:
            # No problem bubbling up exceptions, but we still need to make sure
            # we finish copying, even though we're probably going to cancel the
            # transaction.
            write_thread.join()
            if exc_info:
                # postgresql returned an error, let's raise it. This is done
                # inside the finally clause so the database error still
                # surfaces when writing to the pipe failed as a consequence
                # (e.g. broken pipe after the reader went away).
                raise exc_info[1].with_traceback(exc_info[2])

    def mktemp(self, tblname, cur=None):
        """call the ``swh_mktemp`` SQL function with tblname as argument"""
        self.cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,))
diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
index 355a384..e599ed3 100644
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -1,102 +1,110 @@
# Copyright (C) 2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import os.path
import tempfile
import unittest
from hypothesis import strategies, given
+import psycopg2
import pytest
from swh.core.db import BaseDb
from .db_testing import (
SingleDbTestFixture, db_create, db_destroy, db_close,
)
# SQL used to initialize the test database: a single table with an integer,
# a text and a bytea column.
INIT_SQL = '''
create table test_table
(
i int,
txt text,
bytes bytea
);
'''
# Hypothesis strategy producing lists of (i, txt, bytes) tuples, one tuple
# per row of test_table. Integers are bounded to -2147483648..2147483647
# (presumably to fit the `int` column); the text alphabet excludes the
# characters listed below.
db_rows = strategies.lists(strategies.tuples(
strategies.integers(-2147483648, +2147483647),
strategies.text(
alphabet=strategies.characters(
blacklist_categories=['Cs'], # surrogates
blacklist_characters=[
'\x00', # pgsql does not support the null codepoint
'\r', # pgsql normalizes those
]
),
),
strategies.binary(),
))
@pytest.mark.db
def test_connect():
    """BaseDb.connect() yields a usable proxy; bytea comes back as bytes."""
    db_name = db_create('test-db2', dumps=[])
    # Stays None if connect() fails, so that the cleanup below does not
    # raise a NameError hiding the original error.
    db = None
    try:
        db = BaseDb.connect('dbname=%s' % db_name)
        with db.cursor() as cur:
            cur.execute(INIT_SQL)
            cur.execute("insert into test_table values (1, %s, %s);",
                        ('foo', b'bar'))
            cur.execute("select * from test_table;")
            assert list(cur) == [(1, 'foo', b'bar')]
    finally:
        if db is not None:
            db_close(db.conn)
        db_destroy(db_name)
@pytest.mark.db
class TestDb(SingleDbTestFixture, unittest.TestCase):
    """Tests of BaseDb against the 'test-db' database set up by the fixture."""

    TEST_DB_NAME = 'test-db'

    @classmethod
    def setUpClass(cls):
        """Write INIT_SQL to a temporary dump and let the fixture load it."""
        with tempfile.TemporaryDirectory() as td:
            with open(os.path.join(td, 'init.sql'), 'a') as fd:
                fd.write(INIT_SQL)

            cls.TEST_DB_DUMP = os.path.join(td, '*.sql')

            # Must run while the temporary directory still exists
            super().setUpClass()

    def setUp(self):
        super().setUp()
        self.db = BaseDb(self.conn)

    def test_initialized(self):
        """A row inserted through a BaseDb cursor can be read back."""
        cur = self.db.cursor()
        cur.execute("insert into test_table values (1, %s, %s);",
                    ('foo', b'bar'))
        cur.execute("select * from test_table;")
        self.assertEqual(list(cur), [(1, 'foo', b'bar')])

    def test_reset_tables(self):
        """reset_db_tables() leaves test_table empty."""
        cur = self.db.cursor()
        cur.execute("insert into test_table values (1, %s, %s);",
                    ('foo', b'bar'))
        self.reset_db_tables('test-db')
        cur.execute("select * from test_table;")
        self.assertEqual(list(cur), [])

    @given(db_rows)
    def test_copy_to(self, data):
        """Rows written with copy_to() are read back unchanged."""
        # the table is not reset between runs by hypothesis
        self.reset_db_tables('test-db')

        items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data]
        self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes'])

        cur = self.db.cursor()
        cur.execute('select * from test_table;')
        self.assertCountEqual(list(cur), data)

    def test_copy_to_thread_exception(self):
        """An error raised by the COPY itself is propagated by copy_to()."""
        # 2**65 is larger than the strategy's upper bound for column `i`
        data = [(2**65, 'foo', b'bar')]

        items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data]
        with self.assertRaises(psycopg2.errors.NumericValueOutOfRange):
            self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes'])

File Metadata

Mime Type
text/x-diff
Expires
Thu, Jul 3, 10:51 AM (1 w, 4 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3251899

Event Timeline