diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py
--- a/swh/core/db/__init__.py
+++ b/swh/core/db/__init__.py
@@ -3,7 +3,6 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import binascii
 import datetime
 import enum
 import json
@@ -25,37 +24,50 @@
 
 
 def escape(data):
+    """Render the given data as a postgresql string"""
+
+    def clean_string(string: str) -> str:
+        return (
+            string.replace("\\", "\\\\")
+            .replace("\t", "\\t")
+            .replace("\n", "\\n")
+            .replace("\r", "\\r")
+        )
+
     if data is None:
-        return ""
+        return "\\N"
     if isinstance(data, bytes):
-        return "\\x%s" % binascii.hexlify(data).decode("ascii")
+        return "\\\\x%s" % data.hex()
     elif isinstance(data, str):
-        return '"%s"' % data.replace('"', '""')
+        return clean_string(data)
     elif isinstance(data, datetime.datetime):
-        # We escape twice to make sure the string generated by
-        # isoformat gets escaped
-        return escape(data.isoformat())
+        return clean_string(data.isoformat())
     elif isinstance(data, dict):
-        return escape(json.dumps(data))
+        return clean_string(json.dumps(data))
     elif isinstance(data, list):
-        return escape("{%s}" % ",".join(escape(d) for d in data))
+        escaped = []
+        for v in data:
+            if v is None:
+                escaped.append("NULL")
+            elif v == "NULL":
+                escaped.append('"NULL"')
+            elif isinstance(v, list):
+                # Nested arrays are annoying
+                escaped.append(escape(v))
+            else:
+                escaped.append(escape(v).replace("\\", "\\\\"))
+        return "{%s}" % ",".join(escaped)
     elif isinstance(data, psycopg2.extras.Range):
-        # We escape twice here too, so that we make sure
-        # everything gets passed to copy properly
-        return escape(
-            "%s%s,%s%s"
-            % (
-                "[" if data.lower_inc else "(",
-                "-infinity" if data.lower_inf else escape(data.lower),
-                "infinity" if data.upper_inf else escape(data.upper),
-                "]" if data.upper_inc else ")",
-            )
+        return "%s%s,%s%s" % (
+            "[" if data.lower_inc else "(",
+            "-infinity" if data.lower_inf else escape(data.lower),
+            "infinity" if data.upper_inf else escape(data.upper),
+            "]" if data.upper_inc else ")",
         )
     elif isinstance(data, enum.IntEnum):
-        return escape(int(data))
+        return str(int(data))
     else:
-        # We don't escape here to make sure we pass literals properly
-        return str(data)
+        return clean_string(str(data))
 
 
 def typecast_bytea(value, cur):
@@ -172,7 +184,7 @@
             with open(read_file, "r") as f:
                 try:
                     cursor.copy_expert(
-                        "COPY %s (%s) FROM STDIN CSV" % (tblname, ", ".join(columns)), f
+                        "COPY %s (%s) FROM STDIN" % (tblname, ", ".join(columns)), f
                     )
                 except Exception:
                     # Tell the main thread about the exception
@@ -200,7 +212,7 @@
                             e,
                         )
                         raise e from None
-                    f.write(",".join(line))
+                    f.write("\t".join(line))
                     f.write("\n")
 
         finally:
diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -3,13 +3,16 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import datetime
 import inspect
 import os.path
+from string import printable
 import tempfile
 import unittest
 from unittest.mock import Mock, MagicMock
 
 from hypothesis import strategies, given
+from hypothesis.extra.pytz import timezones
 import psycopg2
 import pytest
 
@@ -26,12 +29,43 @@
 
 INIT_SQL = """
 create table test_table (
-    i int,
-    txt text,
-    bytes bytea
+    i int,
+    txt text,
+    bytes bytea,
+    bytes_arr bytea[],
+    bytes_arr2 bytea[][],
+    ts timestamptz,
+    dict jsonb
 );
 """
 
+COLUMNS = ["i", "txt", "bytes", "bytes_arr", "bytes_arr2", "ts", "dict"]
+STATIC_ROW_OUT = [
+    "foo",
+    b"bar",
+    [b"baz", b"baz"],
+    [[b"quux"]],
+    datetime.datetime.now(tz=datetime.timezone.utc),
+    {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}},
+]
+STATIC_ROW_IN = tuple(
+    psycopg2.extras.Json(v) if isinstance(v, dict) else v for v in STATIC_ROW_OUT
+)
+
+
+def aware_datetimes():
+    # datetimes in Software Heritage are not used for software artifacts
+    # (which may be much older than 2000), but only for objects like scheduler
+    # task runs, and origin visits, which were created by Software Heritage,
+    # so at least in 2015.
+    # We're forbidding old datetimes, because until 1956, many timezones had seconds
+    # in their "UTC offsets" (see
+    # ), which is not
+    # representable by postgresql.
+    min_value = datetime.datetime(2000, 1, 1, 0, 0, 0)
+    return strategies.datetimes(min_value=min_value, timezones=timezones())
+
+
 db_rows = strategies.lists(
     strategies.tuples(
         strategies.integers(-2147483648, +2147483647),
@@ -45,6 +79,35 @@
             ),
         ),
         strategies.binary(),
+        # bytea[]
+        strategies.lists(strategies.binary()),
+        # bytea[][] needs a uniform size >= 1 for the subarrays.
+        strategies.integers(min_value=1, max_value=10).flatmap(
+            lambda n: strategies.lists(
+                strategies.lists(strategies.binary(), min_size=n, max_size=n)
+            )
+        ),
+        # timestamptz
+        aware_datetimes(),
+        # jsonb
+        strategies.dictionaries(
+            strategies.text(printable),
+            strategies.recursive(
+                # should use floats() instead of integers(), but postgresql
+                # coerces large integers into floats, making the tests fail. We
+                # only store ints in our generated data anyway.
+                strategies.none()
+                | strategies.booleans()
+                | strategies.integers(-2147483648, +2147483647)
+                | strategies.text(printable),
+                lambda children: strategies.lists(children, max_size=3)
+                | strategies.dictionaries(
+                    strategies.text(printable), children, max_size=3
+                ),
+            ),
+            min_size=0,
+            max_size=5,
+        ),
     )
 )
 
@@ -55,10 +118,15 @@
     try:
         db = BaseDb.connect("dbname=%s" % db_name)
         with db.cursor() as cur:
+            psycopg2.extras.register_default_jsonb(cur)
             cur.execute(INIT_SQL)
-            cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar"))
+            cur.execute(
+                "insert into test_table values (1, %s);"
+                % (",".join("%s" for i in range(len(COLUMNS) - 1))),
+                STATIC_ROW_IN,
+            )
             cur.execute("select * from test_table;")
-            assert list(cur) == [(1, "foo", b"bar")]
+            assert list(cur) == [tuple([1] + STATIC_ROW_OUT)]
     finally:
         db_close(db.conn)
         db_destroy(db_name)
@@ -84,35 +152,49 @@
 
     def test_initialized(self):
        cur = self.db.cursor()
-        cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar"))
+        psycopg2.extras.register_default_jsonb(cur)
+        cur.execute(
+            "insert into test_table values (1, %s);"
+            % (",".join("%s" for i in range(len(COLUMNS) - 1))),
+            STATIC_ROW_IN,
+        )
         cur.execute("select * from test_table;")
-        self.assertEqual(list(cur), [(1, "foo", b"bar")])
+        assert list(cur) == [tuple([1] + STATIC_ROW_OUT)]
 
     def test_reset_tables(self):
         cur = self.db.cursor()
-        cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar"))
+        psycopg2.extras.register_default_jsonb(cur)
+        cur.execute(
+            "insert into test_table values (1, %s);"
+            % (",".join("%s" for i in range(len(COLUMNS) - 1))),
+            STATIC_ROW_IN,
+        )
         self.reset_db_tables("test-db")
         cur.execute("select * from test_table;")
         self.assertEqual(list(cur), [])
 
     @given(db_rows)
     def test_copy_to(self, data):
-        # the table is not reset between runs by hypothesis
-        self.reset_db_tables("test-db")
+        try:
+            # the table is not reset between runs by hypothesis
+            self.reset_db_tables("test-db")
 
-        items = [dict(zip(["i", "txt", "bytes"], item)) for item in data]
-        self.db.copy_to(items, "test_table", ["i", "txt", "bytes"])
+            items = [dict(zip(COLUMNS, item)) for item in data]
+            print(items)
+            self.db.copy_to(items, "test_table", COLUMNS)
 
-        cur = self.db.cursor()
-        cur.execute("select * from test_table;")
-        self.assertCountEqual(list(cur), data)
+            cur = self.db.cursor()
+            cur.execute("select * from test_table;")
+            self.assertCountEqual(list(cur), data)
+        finally:
+            self.db.conn.rollback()
 
     def test_copy_to_thread_exception(self):
         data = [(2 ** 65, "foo", b"bar")]
-        items = [dict(zip(["i", "txt", "bytes"], item)) for item in data]
+        items = [dict(zip(COLUMNS, item)) for item in data]
 
         with self.assertRaises(psycopg2.errors.NumericValueOutOfRange):
-            self.db.copy_to(items, "test_table", ["i", "txt", "bytes"])
+            self.db.copy_to(items, "test_table", COLUMNS)
 
 
 def test_db_transaction(mocker):
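For context on the change: the default (text) format of PostgreSQL's `COPY ... FROM STDIN` expects one row per line, columns separated by tabs, `\N` for NULL, and backslash escapes for `\`, tab, newline and carriage return, which is what the reworked `escape()` together with the `"\t".join(line)` writer now emits. Below is a minimal sketch of what one COPY line looks like under these rules; the import path matches the file modified above, but the sample row and the expected output are illustrative only, derived by hand from the new `escape()` branches rather than taken from the test suite.

```python
import datetime

from swh.core.db import escape  # the function reworked in this diff

# Hypothetical row for a table with columns (i int, txt text, bytes bytea, x int):
row = [
    1,            # falls into the generic branch -> clean_string(str(1)) == "1"
    "tab\there",  # str branch -> the tab character becomes the two characters \t
    b"\x00\xff",  # bytes branch -> \\x00ff; COPY unescapes it to \x00ff, pgsql's bytea hex form
    None,         # NULL -> \N
]

# copy_to() writes one tab-separated line per item, terminated by "\n",
# and feeds it to "COPY test_table (...) FROM STDIN" (text format, no longer CSV).
line = "\t".join(escape(v) for v in row) + "\n"
print(repr(line))
# Expected something like: '1\ttab\\there\t\\\\x00ff\t\\N\n'
```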