diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py
--- a/swh/core/db/__init__.py
+++ b/swh/core/db/__init__.py
@@ -52,6 +52,18 @@
         return str(data)
 
 
+def typecast_bytea(value, cur):
+    """Typecast a PostgreSQL ``bytea`` value to :class:`bytes`.
+
+    Delegates the decoding to psycopg2's ``BINARY`` typecaster (which
+    yields a :class:`memoryview`) and converts its result to bytes.
+    SQL NULL (``value is None``) is returned as ``None``.
+    """
+    if value is None:
+        return None
+    return psycopg2.BINARY(value, cur).tobytes()
+
+
 class BaseDb:
     """Base class for swh.*.*Db.
 
@@ -60,6 +72,26 @@
     """
 
     @classmethod
+    def adapt_conn(cls, conn):
+        """Makes psycopg2 use 'bytes' to decode bytea instead of
+        'memoryview', for this connection.
+
+        The typecasters are registered with ``conn`` as their scope, so
+        other connections and psycopg2's global defaults are unaffected.
+        """
+        # bytea (17) and bytea[] (1001) have fixed OIDs in PostgreSQL's
+        # system catalog; querying the server for them would open an
+        # implicit transaction (in psycopg2's default non-autocommit mode)
+        # on every connect() and every from_pool() checkout.
+        t_bytes = psycopg2.extensions.new_type(
+            (17,), "bytea", typecast_bytea)
+        psycopg2.extensions.register_type(t_bytes, conn)
+
+        t_bytes_array = psycopg2.extensions.new_array_type(
+            (1001,), "bytea[]", t_bytes)
+        psycopg2.extensions.register_type(t_bytes_array, conn)
+
+    @classmethod
     def connect(cls, *args, **kwargs):
         """factory method to create a DB proxy
 
@@ -71,11 +103,14 @@
         """
 
         conn = psycopg2.connect(*args, **kwargs)
+        cls.adapt_conn(conn)
         return cls(conn)
 
     @classmethod
     def from_pool(cls, pool):
-        return cls(pool.getconn(), pool=pool)
+        conn = pool.getconn()
+        cls.adapt_conn(conn)
+        return cls(conn, pool=pool)
 
     def __init__(self, conn, pool=None):
         """create a DB proxy
diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py
--- a/swh/core/db/db_utils.py
+++ b/swh/core/db/db_utils.py
@@ -44,34 +44,6 @@
     return value
 
 
-def entry_to_bytes(entry):
-    """Convert an entry coming from the database to bytes"""
-    if isinstance(entry, memoryview):
-        return entry.tobytes()
-    if isinstance(entry, list):
-        return [entry_to_bytes(value) for value in entry]
-    return entry
-
-
-def line_to_bytes(line):
-    """Convert a line coming from the database to bytes"""
-    if not line:
-        return line
-    if isinstance(line, dict):
-        return {k: entry_to_bytes(v) for k, v in line.items()}
-    return line.__class__(entry_to_bytes(entry) for entry in line)
-
-
-def cursor_to_bytes(cursor):
-    """Yield all the data from a cursor as bytes"""
-    yield from (line_to_bytes(line) for line in cursor)
-
-
-def execute_values_to_bytes(*args, **kwargs):
-    for line in execute_values_generator(*args, **kwargs):
-        yield line_to_bytes(line)
-
-
 def _paginate(seq, page_size):
     """Consume an iterable and return it in chunks.
     Every chunk is at most `page_size`. Never return an empty chunk.