Page Menu | Home | Software Heritage

D1053.diff
No One | Temporary

D1053.diff

diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -50,32 +50,10 @@
return value
-def entry_to_bytes(entry):
- """Convert an entry coming from the database to bytes"""
- if isinstance(entry, memoryview):
- return entry.tobytes()
- if isinstance(entry, list):
- return [entry_to_bytes(value) for value in entry]
- return entry
-
-
-def line_to_bytes(line):
- """Convert a line coming from the database to bytes"""
- if not line:
- return line
- if isinstance(line, dict):
- return {k: entry_to_bytes(v) for k, v in line.items()}
- return line.__class__(entry_to_bytes(entry) for entry in line)
-
-
-def cursor_to_bytes(cursor):
- """Yield all the data from a cursor as bytes"""
- yield from (line_to_bytes(line) for line in cursor)
-
-
-def execute_values_to_bytes(*args, **kwargs):
- for line in execute_values_generator(*args, **kwargs):
- yield line_to_bytes(line)
+def typecast_bytea(value, cur):
+ if value is not None:
+ data = psycopg2.BINARY(value, cur)
+ return data.tobytes()
class BaseDb:
@@ -86,6 +64,23 @@
"""
@classmethod
+ def adapt_conn(cls, conn):
+ """Makes psycopg2 use 'bytes' to decode bytea instead of
+ 'memoryview', for this connection."""
+ cur = conn.cursor()
+ cur.execute("SELECT null::bytea, null::bytea[]")
+ bytea_oid = cur.description[0][1]
+ bytea_array_oid = cur.description[1][1]
+
+ t_bytes = psycopg2.extensions.new_type(
+ (bytea_oid,), "bytea", typecast_bytea)
+ psycopg2.extensions.register_type(t_bytes, conn)
+
+ t_bytes_array = psycopg2.extensions.new_array_type(
+ (bytea_array_oid,), "bytea[]", t_bytes)
+ psycopg2.extensions.register_type(t_bytes_array, conn)
+
+ @classmethod
def connect(cls, *args, **kwargs):
"""factory method to create a DB proxy
@@ -97,11 +92,14 @@
"""
conn = psycopg2.connect(*args, **kwargs)
+ cls.adapt_conn(conn)
return cls(conn)
@classmethod
def from_pool(cls, pool):
- return cls(pool.getconn(), pool=pool)
+ conn = pool.getconn()
+ cls.adapt_conn(conn)
+ return cls(conn, pool=pool)
def _cursor(self, cur_arg):
"""get a cursor: from cur_arg if given, or a fresh one otherwise
@@ -281,7 +279,7 @@
def content_get_metadata_from_sha1s(self, sha1s, cur=None):
cur = self._cursor(cur)
- yield from execute_values_to_bytes(
+ yield from execute_values_generator(
cur, """
select t.sha1, %s from (values %%s) as t (sha1)
left join content using (sha1)
@@ -299,7 +297,7 @@
order by sha1
limit %%s""" % ', '.join(self.content_get_metadata_keys)
cur.execute(query, (start, end, limit))
- yield from cursor_to_bytes(cur)
+ yield from cur
content_hash_keys = ['sha1', 'sha1_git', 'sha256', 'blake2s256']
@@ -312,7 +310,7 @@
for key in self.content_hash_keys
)
- yield from execute_values_to_bytes(
+ yield from execute_values_generator(
cur, """
SELECT %s
FROM (VALUES %%s) as t(%s)
@@ -327,7 +325,7 @@
def content_missing_per_sha1(self, sha1s, cur=None):
cur = self._cursor(cur)
- yield from execute_values_to_bytes(cur, """
+ yield from execute_values_generator(cur, """
SELECT t.sha1 FROM (VALUES %s) AS t(sha1)
WHERE NOT EXISTS (
SELECT 1 FROM content c WHERE c.sha1 = t.sha1
@@ -339,7 +337,7 @@
cur.execute("""SELECT sha1, sha1_git, sha256, blake2s256
FROM swh_skipped_content_missing()""")
- yield from cursor_to_bytes(cur)
+ yield from cur
def snapshot_exists(self, snapshot_id, cur=None):
"""Check whether a snapshot with the given id exists"""
@@ -366,7 +364,7 @@
cur.execute(query, (snapshot_id,))
- yield from cursor_to_bytes(cur)
+ yield from cur
snapshot_get_cols = ['snapshot_id', 'name', 'target', 'target_type']
@@ -382,7 +380,7 @@
cur.execute(query, (snapshot_id, branches_from, branches_count,
target_types))
- yield from cursor_to_bytes(cur)
+ yield from cur
def snapshot_get_by_origin_visit(self, origin_id, visit_id, cur=None):
cur = self._cursor(cur)
@@ -393,7 +391,7 @@
cur.execute(query, (origin_id, visit_id))
ret = cur.fetchone()
if ret:
- return line_to_bytes(ret)[0]
+ return ret[0]
content_find_cols = ['sha1', 'sha1_git', 'sha256', 'blake2s256', 'length',
'ctime', 'status']
@@ -420,7 +418,7 @@
LIMIT 1""" % ','.join(self.content_find_cols),
(sha1, sha1_git, sha256, blake2s256))
- content = line_to_bytes(cur.fetchone())
+ content = cur.fetchone()
if set(content) == {None}:
return None
else:
@@ -428,7 +426,7 @@
def directory_missing_from_list(self, directories, cur=None):
cur = self._cursor(cur)
- yield from execute_values_to_bytes(
+ yield from execute_values_generator(
cur, """
SELECT id FROM (VALUES %s) as t(id)
WHERE NOT EXISTS (
@@ -444,14 +442,14 @@
cols = ', '.join(self.directory_ls_cols)
query = 'SELECT %s FROM swh_directory_walk_one(%%s)' % cols
cur.execute(query, (directory,))
- yield from cursor_to_bytes(cur)
+ yield from cur
def directory_walk(self, directory, cur=None):
cur = self._cursor(cur)
cols = ', '.join(self.directory_ls_cols)
query = 'SELECT %s FROM swh_directory_walk(%%s)' % cols
cur.execute(query, (directory,))
- yield from cursor_to_bytes(cur)
+ yield from cur
def directory_entry_get_by_path(self, directory, paths, cur=None):
"""Retrieve a directory entry by path.
@@ -467,12 +465,12 @@
data = cur.fetchone()
if set(data) == {None}:
return None
- return line_to_bytes(data)
+ return data
def revision_missing_from_list(self, revisions, cur=None):
cur = self._cursor(cur)
- yield from execute_values_to_bytes(
+ yield from execute_values_generator(
cur, """
SELECT id FROM (VALUES %s) as t(id)
WHERE NOT EXISTS (
@@ -552,7 +550,7 @@
cur.execute(query, args)
- yield from cursor_to_bytes(cur)
+ yield from cur
def origin_visit_get(self, origin_id, visit_id, cur=None):
"""Retrieve information on visit visit_id of origin origin_id.
@@ -579,7 +577,7 @@
r = cur.fetchall()
if not r:
return None
- return line_to_bytes(r[0])
+ return r[0]
def origin_visit_exists(self, origin_id, visit_id, cur=None):
"""Check whether an origin visit with the given ids exists"""
@@ -625,7 +623,7 @@
r = cur.fetchone()
if not r:
return None
- return line_to_bytes(r)
+ return r
@staticmethod
def mangle_query_key(key, main_table):
@@ -657,7 +655,7 @@
for k in self.revision_get_cols
)
- yield from execute_values_to_bytes(
+ yield from execute_values_generator(
cur, """
SELECT %s FROM (VALUES %%s) as t(id)
LEFT JOIN revision ON t.id = revision.id
@@ -674,7 +672,7 @@
""" % ', '.join(self.revision_get_cols)
cur.execute(query, (root_revisions, limit))
- yield from cursor_to_bytes(cur)
+ yield from cur
revision_shortlog_cols = ['id', 'parents']
@@ -686,11 +684,11 @@
""" % ', '.join(self.revision_shortlog_cols)
cur.execute(query, (root_revisions, limit))
- yield from cursor_to_bytes(cur)
+ yield from cur
def release_missing_from_list(self, releases, cur=None):
cur = self._cursor(cur)
- yield from execute_values_to_bytes(
+ yield from execute_values_generator(
cur, """
SELECT id FROM (VALUES %s) as t(id)
WHERE NOT EXISTS (
@@ -703,7 +701,7 @@
def object_find_by_sha1_git(self, ids, cur=None):
cur = self._cursor(cur)
- yield from execute_values_to_bytes(
+ yield from execute_values_generator(
cur, """
WITH t (id) AS (VALUES %s),
known_objects as ((
@@ -819,7 +817,7 @@
cur.execute(query, (type, url))
data = cur.fetchone()
if data:
- return line_to_bytes(data)
+ return data
return None
def origin_get(self, id, cur=None):
@@ -835,7 +833,7 @@
cur.execute(query, (id,))
data = cur.fetchone()
if data:
- return line_to_bytes(data)
+ return data
return None
def origin_search(self, url_pattern, offset=0, limit=50,
@@ -876,7 +874,7 @@
query_params = (url_pattern, offset, limit)
cur.execute(query, query_params)
- yield from cursor_to_bytes(cur)
+ yield from cur
person_cols = ['fullname', 'name', 'email']
person_get_cols = person_cols + ['id']
@@ -892,7 +890,7 @@
WHERE id IN %%s""" % ', '.join(self.person_get_cols)
cur.execute(query, (tuple(ids),))
- yield from cursor_to_bytes(cur)
+ yield from cur
release_add_cols = [
'id', 'target', 'target_type', 'date', 'date_offset',
@@ -908,7 +906,7 @@
for k in self.release_get_cols
)
- yield from execute_values_to_bytes(
+ yield from execute_values_generator(
cur, """
SELECT %s FROM (VALUES %%s) as t(id)
LEFT JOIN release ON t.id = release.id
@@ -966,7 +964,7 @@
cur.execute(query, (origin_id, provider_type))
- yield from cursor_to_bytes(cur)
+ yield from cur
tool_cols = ['id', 'name', 'version', 'configuration']
@@ -978,7 +976,7 @@
cur = self._cursor(cur)
cur.execute("SELECT %s from swh_tool_add()" % (
','.join(self.tool_cols), ))
- yield from cursor_to_bytes(cur)
+ yield from cur
def tool_get(self, name, version, configuration, cur=None):
cur = self._cursor(cur)
@@ -993,7 +991,7 @@
data = cur.fetchone()
if not data:
return None
- return line_to_bytes(data)
+ return data
metadata_provider_cols = ['id', 'provider_name', 'provider_type',
'provider_url', 'metadata']
@@ -1021,7 +1019,7 @@
data = cur.fetchone()
if not data:
return None
- return line_to_bytes(data)
+ return data
def metadata_provider_get_by(self, provider_name, provider_url,
cur=None):
@@ -1036,4 +1034,4 @@
data = cur.fetchone()
if not data:
return None
- return line_to_bytes(data)
+ return data
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -2214,7 +2214,7 @@
datum = cur.fetchone()
self.assertEqual(
- (datum[0].tobytes(), datum[1].tobytes(), datum[2].tobytes(),
+ (datum[0], datum[1], datum[2],
datum[3], datum[4]),
(cont['sha1'], cont['sha1_git'], cont['sha256'],
cont['length'], 'visible'))
@@ -2240,7 +2240,7 @@
datum = cur.fetchone()
self.assertEqual(
- (datum[0].tobytes(), datum[1].tobytes(), datum[2].tobytes(),
+ (datum[0], datum[1], datum[2],
datum[3], datum[4], datum[5], datum[6]),
(cont['sha1'], cont['sha1_git'], cont['sha256'],
cont['length'], 'visible', cont['test'], cont['test2']))

File Metadata

Mime Type: text/plain
Expires: Dec 21 2024, 11:09 AM (11 w, 4 d ago)
Storage Engine: blob
Storage Format: Raw Data
Storage Handle: 3224150

Event Timeline