diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -50,32 +50,16 @@
     return value
 
 
-def entry_to_bytes(entry):
-    """Convert an entry coming from the database to bytes"""
-    if isinstance(entry, memoryview):
-        return entry.tobytes()
-    if isinstance(entry, list):
-        return [entry_to_bytes(value) for value in entry]
-    return entry
+def typecast_bytea(value, cur):
+    """Decode a bytea value to bytes (None stays None)."""
+    if value is not None:
+        data = psycopg2.BINARY(value, cur)
+        return data.tobytes()
 
 
-def line_to_bytes(line):
-    """Convert a line coming from the database to bytes"""
-    if not line:
-        return line
-    if isinstance(line, dict):
-        return {k: entry_to_bytes(v) for k, v in line.items()}
-    return line.__class__(entry_to_bytes(entry) for entry in line)
-
-
-def cursor_to_bytes(cursor):
-    """Yield all the data from a cursor as bytes"""
-    yield from (line_to_bytes(line) for line in cursor)
-
-
-def execute_values_to_bytes(*args, **kwargs):
-    for line in execute_values_generator(*args, **kwargs):
-        yield line_to_bytes(line)
+# adapt_conn builds the bytea[] caster with new_array_type, so no separate
+# array decoder is needed; keep the name as an alias, not a duplicate body.
+typecast_bytea_array = typecast_bytea
 
 
 class BaseDb:
@@ -86,6 +70,23 @@
     """
 
     @classmethod
+    def adapt_conn(cls, conn):
+        """Makes psycopg2 use 'bytes' to decode bytea instead of
+        'memoryview', for this connection."""
+        # Use psycopg2's built-in OIDs: querying the server for them
+        # would leave the connection inside an open transaction.
+        bytea_oids = psycopg2.BINARY.values
+        bytea_array_oids = psycopg2.extensions.BINARYARRAY.values
+
+        t_bytes = psycopg2.extensions.new_type(
+            bytea_oids, "bytea", typecast_bytea)
+        psycopg2.extensions.register_type(t_bytes, conn)
+
+        t_bytes_array = psycopg2.extensions.new_array_type(
+            bytea_array_oids, "bytea[]", t_bytes)
+        psycopg2.extensions.register_type(t_bytes_array, conn)
+
+    @classmethod
     def connect(cls, *args, **kwargs):
         """factory method to create a DB proxy
 
@@ -97,11 +98,14 @@
         """
         conn = psycopg2.connect(*args, **kwargs)
+        cls.adapt_conn(conn)
         return cls(conn)
 
     @classmethod
     def 
from_pool(cls, pool): - return cls(pool.getconn(), pool=pool) + conn = pool.getconn() + cls.adapt_conn(conn) + return cls(conn, pool=pool) def _cursor(self, cur_arg): """get a cursor: from cur_arg if given, or a fresh one otherwise @@ -281,7 +285,7 @@ def content_get_metadata_from_sha1s(self, sha1s, cur=None): cur = self._cursor(cur) - yield from execute_values_to_bytes( + yield from execute_values_generator( cur, """ select t.sha1, %s from (values %%s) as t (sha1) left join content using (sha1) @@ -299,7 +303,7 @@ order by sha1 limit %%s""" % ', '.join(self.content_get_metadata_keys) cur.execute(query, (start, end, limit)) - yield from cursor_to_bytes(cur) + yield from cur content_hash_keys = ['sha1', 'sha1_git', 'sha256', 'blake2s256'] @@ -312,7 +316,7 @@ for key in self.content_hash_keys ) - yield from execute_values_to_bytes( + yield from execute_values_generator( cur, """ SELECT %s FROM (VALUES %%s) as t(%s) @@ -327,7 +331,7 @@ def content_missing_per_sha1(self, sha1s, cur=None): cur = self._cursor(cur) - yield from execute_values_to_bytes(cur, """ + yield from execute_values_generator(cur, """ SELECT t.sha1 FROM (VALUES %s) AS t(sha1) WHERE NOT EXISTS ( SELECT 1 FROM content c WHERE c.sha1 = t.sha1 @@ -339,7 +343,7 @@ cur.execute("""SELECT sha1, sha1_git, sha256, blake2s256 FROM swh_skipped_content_missing()""") - yield from cursor_to_bytes(cur) + yield from cur def snapshot_exists(self, snapshot_id, cur=None): """Check whether a snapshot with the given id exists""" @@ -366,7 +370,7 @@ cur.execute(query, (snapshot_id,)) - yield from cursor_to_bytes(cur) + yield from cur snapshot_get_cols = ['snapshot_id', 'name', 'target', 'target_type'] @@ -382,7 +386,7 @@ cur.execute(query, (snapshot_id, branches_from, branches_count, target_types)) - yield from cursor_to_bytes(cur) + yield from cur def snapshot_get_by_origin_visit(self, origin_id, visit_id, cur=None): cur = self._cursor(cur) @@ -393,7 +397,7 @@ cur.execute(query, (origin_id, visit_id)) ret = cur.fetchone() 
if ret: - return line_to_bytes(ret)[0] + return ret[0] content_find_cols = ['sha1', 'sha1_git', 'sha256', 'blake2s256', 'length', 'ctime', 'status'] @@ -420,7 +424,7 @@ LIMIT 1""" % ','.join(self.content_find_cols), (sha1, sha1_git, sha256, blake2s256)) - content = line_to_bytes(cur.fetchone()) + content = cur.fetchone() if set(content) == {None}: return None else: @@ -428,7 +432,7 @@ def directory_missing_from_list(self, directories, cur=None): cur = self._cursor(cur) - yield from execute_values_to_bytes( + yield from execute_values_generator( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( @@ -444,14 +448,14 @@ cols = ', '.join(self.directory_ls_cols) query = 'SELECT %s FROM swh_directory_walk_one(%%s)' % cols cur.execute(query, (directory,)) - yield from cursor_to_bytes(cur) + yield from cur def directory_walk(self, directory, cur=None): cur = self._cursor(cur) cols = ', '.join(self.directory_ls_cols) query = 'SELECT %s FROM swh_directory_walk(%%s)' % cols cur.execute(query, (directory,)) - yield from cursor_to_bytes(cur) + yield from cur def directory_entry_get_by_path(self, directory, paths, cur=None): """Retrieve a directory entry by path. @@ -467,12 +471,12 @@ data = cur.fetchone() if set(data) == {None}: return None - return line_to_bytes(data) + return data def revision_missing_from_list(self, revisions, cur=None): cur = self._cursor(cur) - yield from execute_values_to_bytes( + yield from execute_values_generator( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( @@ -552,7 +556,7 @@ cur.execute(query, args) - yield from cursor_to_bytes(cur) + yield from cur def origin_visit_get(self, origin_id, visit_id, cur=None): """Retrieve information on visit visit_id of origin origin_id. 
@@ -579,7 +583,7 @@ r = cur.fetchall() if not r: return None - return line_to_bytes(r[0]) + return r[0] def origin_visit_exists(self, origin_id, visit_id, cur=None): """Check whether an origin visit with the given ids exists""" @@ -625,7 +629,7 @@ r = cur.fetchone() if not r: return None - return line_to_bytes(r) + return r @staticmethod def mangle_query_key(key, main_table): @@ -657,7 +661,7 @@ for k in self.revision_get_cols ) - yield from execute_values_to_bytes( + yield from execute_values_generator( cur, """ SELECT %s FROM (VALUES %%s) as t(id) LEFT JOIN revision ON t.id = revision.id @@ -674,7 +678,7 @@ """ % ', '.join(self.revision_get_cols) cur.execute(query, (root_revisions, limit)) - yield from cursor_to_bytes(cur) + yield from cur revision_shortlog_cols = ['id', 'parents'] @@ -686,11 +690,11 @@ """ % ', '.join(self.revision_shortlog_cols) cur.execute(query, (root_revisions, limit)) - yield from cursor_to_bytes(cur) + yield from cur def release_missing_from_list(self, releases, cur=None): cur = self._cursor(cur) - yield from execute_values_to_bytes( + yield from execute_values_generator( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( @@ -703,7 +707,7 @@ def object_find_by_sha1_git(self, ids, cur=None): cur = self._cursor(cur) - yield from execute_values_to_bytes( + yield from execute_values_generator( cur, """ WITH t (id) AS (VALUES %s), known_objects as (( @@ -819,7 +823,7 @@ cur.execute(query, (type, url)) data = cur.fetchone() if data: - return line_to_bytes(data) + return data return None def origin_get(self, id, cur=None): @@ -835,7 +839,7 @@ cur.execute(query, (id,)) data = cur.fetchone() if data: - return line_to_bytes(data) + return data return None def origin_search(self, url_pattern, offset=0, limit=50, @@ -876,7 +880,7 @@ query_params = (url_pattern, offset, limit) cur.execute(query, query_params) - yield from cursor_to_bytes(cur) + yield from cur person_cols = ['fullname', 'name', 'email'] person_get_cols = person_cols + 
['id'] @@ -892,7 +896,7 @@ WHERE id IN %%s""" % ', '.join(self.person_get_cols) cur.execute(query, (tuple(ids),)) - yield from cursor_to_bytes(cur) + yield from cur release_add_cols = [ 'id', 'target', 'target_type', 'date', 'date_offset', @@ -908,7 +912,7 @@ for k in self.release_get_cols ) - yield from execute_values_to_bytes( + yield from execute_values_generator( cur, """ SELECT %s FROM (VALUES %%s) as t(id) LEFT JOIN release ON t.id = release.id @@ -966,7 +970,7 @@ cur.execute(query, (origin_id, provider_type)) - yield from cursor_to_bytes(cur) + yield from cur tool_cols = ['id', 'name', 'version', 'configuration'] @@ -978,7 +982,7 @@ cur = self._cursor(cur) cur.execute("SELECT %s from swh_tool_add()" % ( ','.join(self.tool_cols), )) - yield from cursor_to_bytes(cur) + yield from cur def tool_get(self, name, version, configuration, cur=None): cur = self._cursor(cur) @@ -993,7 +997,7 @@ data = cur.fetchone() if not data: return None - return line_to_bytes(data) + return data metadata_provider_cols = ['id', 'provider_name', 'provider_type', 'provider_url', 'metadata'] @@ -1021,7 +1025,7 @@ data = cur.fetchone() if not data: return None - return line_to_bytes(data) + return data def metadata_provider_get_by(self, provider_name, provider_url, cur=None): @@ -1036,4 +1040,4 @@ data = cur.fetchone() if not data: return None - return line_to_bytes(data) + return data diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -2214,7 +2214,7 @@ datum = cur.fetchone() self.assertEqual( - (datum[0].tobytes(), datum[1].tobytes(), datum[2].tobytes(), + (datum[0], datum[1], datum[2], datum[3], datum[4]), (cont['sha1'], cont['sha1_git'], cont['sha256'], cont['length'], 'visible')) @@ -2240,7 +2240,7 @@ datum = cur.fetchone() self.assertEqual( - (datum[0].tobytes(), datum[1].tobytes(), datum[2].tobytes(), + (datum[0], datum[1], datum[2], datum[3], datum[4], datum[5], 
datum[6]), (cont['sha1'], cont['sha1_git'], cont['sha256'], cont['length'], 'visible', cont['test'], cont['test2']))