Page Menu
Home
Software Heritage
Search
Configure Global Search
Log In
Files
F9312396
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
10 KB
Subscribers
None
View Options
diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py
index 2192c7d..9f46544 100644
--- a/swh/core/db/__init__.py
+++ b/swh/core/db/__init__.py
@@ -1,207 +1,217 @@
# Copyright (C) 2015-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import binascii
import datetime
import enum
import json
import logging
import os
+import sys
import threading
from contextlib import contextmanager
import psycopg2
import psycopg2.extras
# Module-level logger, used by BaseDb.copy_to to report values it cannot
# encode.
logger = logging.getLogger(__name__)
# Import-time side effect: registers psycopg2's UUID support. It is called with
# no connection/cursor argument, so presumably it applies to every connection
# in the process -- see psycopg2.extras.register_uuid docs to confirm.
psycopg2.extras.register_uuid()
def escape(data):
    """Render a Python value as one field of a ``COPY ... FROM STDIN CSV`` row.

    Args:
        data: the value to render. Handled types are ``None``, ``bytes``
            (hex ``\\x...`` bytea notation), ``str`` (double-quoted, with
            embedded quotes doubled), ``datetime.datetime`` (ISO format),
            ``dict`` (JSON-encoded), ``list`` (``{...}`` array literal whose
            elements are escaped recursively), ``psycopg2.extras.Range`` and
            ``enum.IntEnum`` (rendered as its integer value). Any other value
            is rendered with ``str()`` and left unquoted.

    Returns:
        str: the field text. ``None`` becomes the empty unquoted field, which
        is PostgreSQL's default NULL marker in CSV mode.
    """
    if data is None:
        return ''
    if isinstance(data, bytes):
        return '\\x%s' % binascii.hexlify(data).decode('ascii')
    elif isinstance(data, str):
        return '"%s"' % data.replace('"', '""')
    elif isinstance(data, datetime.datetime):
        # We escape twice to make sure the string generated by
        # isoformat gets escaped
        return escape(data.isoformat())
    elif isinstance(data, dict):
        return escape(json.dumps(data))
    elif isinstance(data, list):
        return escape("{%s}" % ','.join(escape(d) for d in data))
    elif isinstance(data, psycopg2.extras.Range):
        if data.isempty:
            # An empty range has no bounds at all (lower/upper are None and
            # the *_inf flags are False); formatting it like a bounded range
            # would yield '(,)', which PostgreSQL reads as the *unbounded*
            # range. Use the dedicated 'empty' literal instead.
            return escape('empty')
        # We escape twice here too, so that we make sure
        # everything gets passed to copy properly
        return escape(
            '%s%s,%s%s' % (
                '[' if data.lower_inc else '(',
                '-infinity' if data.lower_inf else escape(data.lower),
                'infinity' if data.upper_inf else escape(data.upper),
                ']' if data.upper_inc else ')',
            )
        )
    elif isinstance(data, enum.IntEnum):
        return escape(int(data))
    else:
        # We don't escape here to make sure we pass literals properly
        return str(data)
def typecast_bytea(value, cur):
    """psycopg2 typecaster turning a ``bytea`` value into ``bytes``.

    ``psycopg2.BINARY`` does the actual decoding; its result is converted
    with ``.tobytes()``. SQL NULL (``value is None``) yields ``None``.
    """
    if value is None:
        return None
    return psycopg2.BINARY(value, cur).tobytes()
class BaseDb:
    """Base class for swh.*.*Db.

    cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb

    Wraps a psycopg2 connection (optionally borrowed from a pool) and offers
    cursor and transaction helpers plus a bulk CSV ``COPY`` loader.
    """

    @classmethod
    def adapt_conn(cls, conn):
        """Makes psycopg2 use 'bytes' to decode bytea instead of
        'memoryview', for this connection.

        The OIDs of ``bytea`` and ``bytea[]`` are looked up on the server, and
        typecasters built on :func:`typecast_bytea` are registered on
        ``conn`` only.
        """
        # Short-lived cursor for the OID lookup; closed once we have the OIDs.
        with conn.cursor() as cur:
            cur.execute("SELECT null::bytea, null::bytea[]")
            bytea_oid = cur.description[0][1]
            bytea_array_oid = cur.description[1][1]

        t_bytes = psycopg2.extensions.new_type(
            (bytea_oid,), "bytea", typecast_bytea)
        psycopg2.extensions.register_type(t_bytes, conn)

        t_bytes_array = psycopg2.extensions.new_array_type(
            (bytea_array_oid,), "bytea[]", t_bytes)
        psycopg2.extensions.register_type(t_bytes_array, conn)

    @classmethod
    def connect(cls, *args, **kwargs):
        """factory method to create a DB proxy

        Accepts all arguments of psycopg2.connect; only some specific
        possibilities are reported below.

        Args:
            connstring: libpq2 connection string

        If building the proxy fails, the freshly opened connection is closed
        before the error propagates.
        """
        conn = psycopg2.connect(*args, **kwargs)
        try:
            return cls(conn)
        except Exception:
            conn.close()
            raise

    @classmethod
    def from_pool(cls, pool):
        """Create a DB proxy around a connection taken from ``pool``.

        If building the proxy fails (e.g. in :meth:`adapt_conn`), the
        connection is handed back to the pool before the error propagates,
        instead of being leaked.
        """
        conn = pool.getconn()
        try:
            return cls(conn, pool=pool)
        except Exception:
            pool.putconn(conn)
            raise

    def __init__(self, conn, pool=None):
        """create a DB proxy

        Args:
            conn: psycopg2 connection to the SWH DB
            pool: psycopg2 pool of connections
        """
        self.adapt_conn(conn)
        self.conn = conn
        self.pool = pool

    def put_conn(self):
        """Give the wrapped connection back to its pool, if it came from one."""
        if self.pool:
            self.pool.putconn(self.conn)

    def cursor(self, cur_arg=None):
        """get a cursor: from cur_arg if given, or a fresh one otherwise

        meant to avoid boilerplate if/then/else in methods that proxy stored
        procedures
        """
        if cur_arg is not None:
            return cur_arg
        else:
            return self.conn.cursor()

    _cursor = cursor  # for bw compat

    @contextmanager
    def transaction(self):
        """context manager to execute within a DB transaction

        Commits when the body completes; on any exception, rolls back (if the
        connection is still open) and re-raises.

        Yields:
            a psycopg2 cursor
        """
        with self.conn.cursor() as cur:
            try:
                yield cur
                self.conn.commit()
            except Exception:
                if not self.conn.closed:
                    self.conn.rollback()
                raise

    def copy_to(self, items, tblname, columns,
                cur=None, item_cb=None, default_values=None):
        """Copy items' entries to table tblname with columns information.

        Rows are streamed through an OS pipe: the calling thread encodes each
        item as a CSV line (via :func:`escape`) into the write end, while a
        helper thread runs ``COPY ... FROM STDIN CSV`` on the read end.

        Args:
            items (List[dict]): dictionaries of data to copy over tblname.
            tblname (str): destination table's name.
            columns ([str]): keys to access data in items and also the
              column names in the destination table.
            cur: a db cursor; if not given, a new cursor will be created
              (and closed once the copy is done).
            item_cb (fn): optional function to apply to items's entry.
            default_values (dict): dictionary of default values to use when
              inserting entries into the tblname table. ``None`` means no
              defaults.

        Raises:
            Exception: any error raised by the ``COPY`` in the helper thread
              (e.g. a psycopg2 error) is re-raised in the calling thread;
              errors from ``item_cb`` or from encoding an item propagate too.
        """
        if default_values is None:
            default_values = {}

        read_file, write_file = os.pipe()
        copy_exc = None

        def writer():
            nonlocal copy_exc
            # Open the read end first and keep everything that may fail
            # inside the ``with``: the read end must be closed whatever
            # happens, otherwise the main thread could block forever writing
            # to a full pipe that nobody reads.
            with open(read_file, 'r', encoding='utf-8') as f:
                cursor = None
                try:
                    cursor = self.cursor(cur)
                    # NOTE: tblname/columns are interpolated verbatim, so they
                    # must be trusted identifiers, never user input.
                    cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % (
                        tblname, ', '.join(columns)), f)
                except Exception as e:
                    # Tell the main thread about the exception
                    copy_exc = e
                finally:
                    # Only close a cursor we created ourselves.
                    if cur is None and cursor is not None:
                        cursor.close()

        write_thread = threading.Thread(target=writer)
        write_thread.start()

        try:
            # Same explicit encoding as the reader so text round-trips through
            # the pipe regardless of the process locale.
            with open(write_file, 'w', encoding='utf-8') as f:
                for d in items:
                    if item_cb is not None:
                        item_cb(d)
                    line = []
                    for k in columns:
                        # Looked up outside the try so the error log below
                        # always refers to the value that failed to encode.
                        value = d.get(k, default_values.get(k))
                        try:
                            line.append(escape(value))
                        except Exception as e:
                            logger.error(
                                'Could not escape value `%r` for column `%s`: '
                                'Received exception: `%s`',
                                value, k, e
                            )
                            raise e from None
                    f.write(','.join(line))
                    f.write('\n')
        finally:
            # No problem bubbling up exceptions, but we still need to make sure
            # we finish copying, even though we're probably going to cancel the
            # transaction.
            write_thread.join()
            if copy_exc is not None:
                # postgresql returned an error, let's raise it (the exception
                # object keeps its original traceback).
                raise copy_exc

    def mktemp(self, tblname, cur=None):
        """Call the ``swh_mktemp`` stored procedure with ``tblname``."""
        self.cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,))
diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
index 355a384..e599ed3 100644
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -1,102 +1,110 @@
# Copyright (C) 2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import os.path
import tempfile
import unittest
from hypothesis import strategies, given
+import psycopg2
import pytest
from swh.core.db import BaseDb
from .db_testing import (
SingleDbTestFixture, db_create, db_destroy, db_close,
)
# Schema of the scratch table created in the test databases (executed by
# test_connect, and written as the dump file in TestDb.setUpClass): one int,
# one text and one bytea column.
INIT_SQL = '''
create table test_table
(
i int,
txt text,
bytes bytea
);
'''
# Hypothesis strategy producing lists of (i, txt, bytes) rows matching the
# columns of test_table.
db_rows = strategies.lists(
    strategies.tuples(
        # bounds of the signed 32-bit ``int`` column
        strategies.integers(min_value=-2**31, max_value=2**31 - 1),
        strategies.text(alphabet=strategies.characters(
            blacklist_categories=['Cs'],  # surrogates
            blacklist_characters=[
                '\x00',  # pgsql does not support the null codepoint
                '\r',  # pgsql normalizes those
            ],
        )),
        strategies.binary(),
    )
)
@pytest.mark.db
def test_connect():
    """BaseDb.connect yields a working proxy whose cursors decode bytea as
    bytes."""
    db_name = db_create('test-db2', dumps=[])
    try:
        # Connect outside the inner try: if this fails there is no connection
        # to close, and referencing ``db`` in a finally would raise
        # UnboundLocalError and hide the real error.
        db = BaseDb.connect('dbname=%s' % db_name)
        try:
            with db.cursor() as cur:
                cur.execute(INIT_SQL)
                cur.execute("insert into test_table values (1, %s, %s);",
                            ('foo', b'bar'))
                cur.execute("select * from test_table;")
                assert list(cur) == [(1, 'foo', b'bar')]
        finally:
            db_close(db.conn)
    finally:
        db_destroy(db_name)
@pytest.mark.db
class TestDb(SingleDbTestFixture, unittest.TestCase):
    """Tests for BaseDb against a scratch database initialized with
    INIT_SQL."""

    TEST_DB_NAME = 'test-db'

    @classmethod
    def setUpClass(cls):
        # Write the schema to a throwaway SQL file and point the fixture's
        # dump glob at it while the fixture's class setup runs.
        with tempfile.TemporaryDirectory() as td:
            with open(os.path.join(td, 'init.sql'), 'a') as fd:
                fd.write(INIT_SQL)
            cls.TEST_DB_DUMP = os.path.join(td, '*.sql')
            super().setUpClass()

    def setUp(self):
        super().setUp()
        self.db = BaseDb(self.conn)

    def test_initialized(self):
        cur = self.db.cursor()
        cur.execute("insert into test_table values (1, %s, %s);",
                    ('foo', b'bar'))
        cur.execute("select * from test_table;")
        self.assertEqual(list(cur), [(1, 'foo', b'bar')])

    def test_reset_tables(self):
        cur = self.db.cursor()
        cur.execute("insert into test_table values (1, %s, %s);",
                    ('foo', b'bar'))
        self.reset_db_tables('test-db')
        cur.execute("select * from test_table;")
        self.assertEqual(list(cur), [])

    @given(db_rows)
    def test_copy_to(self, data):
        # the table is not reset between runs by hypothesis
        self.reset_db_tables('test-db')

        items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data]
        self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes'])

        cur = self.db.cursor()
        cur.execute('select * from test_table;')
        self.assertCountEqual(list(cur), data)

    def test_copy_to_thread_exception(self):
        # 2**65 does not fit in the ``int`` column: the server rejects the
        # COPY run by copy_to's helper thread, and copy_to must re-raise that
        # error in the calling thread.
        data = [(2**65, 'foo', b'bar')]

        items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data]
        with self.assertRaises(psycopg2.errors.NumericValueOutOfRange):
            self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes'])

        # Unless the connection is in autocommit mode, the failed COPY leaves
        # its transaction aborted; roll back so later statements on this
        # shared connection (e.g. fixture teardown) are not rejected.
        self.conn.rollback()
File Metadata
Details
Attached
Mime Type
text/x-diff
Expires
Thu, Jul 3, 10:51 AM (1 w, 4 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3251899
Attached To
rDCORE Foundations and core functionalities
Event Timeline
Log In to Comment