diff --git a/requirements-test-db.txt b/requirements-test-db.txt --- a/requirements-test-db.txt +++ b/requirements-test-db.txt @@ -1 +1,2 @@ pytest-postgresql +typing-extensions diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py --- a/swh/core/db/__init__.py +++ b/swh/core/db/__init__.py @@ -3,7 +3,6 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import binascii import datetime import enum import json @@ -11,6 +10,7 @@ import os import sys import threading +from typing import Any, Callable, Iterable, Mapping, Optional from contextlib import contextmanager @@ -24,40 +24,88 @@ psycopg2.extras.register_uuid() -def escape(data): +def render_array(data) -> str: + """Render the data as a postgresql array""" + # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO + # "The external text representation of an array value consists of items that are + # interpreted according to the I/O conversion rules for the array's element type, + # plus decoration that indicates the array structure. The decoration consists of + # curly braces ({ and }) around the array value plus delimiter characters between + # adjacent items. The delimiter character is usually a comma (,)" + return "{%s}" % ",".join(render_array_element(e) for e in data) + + +def render_array_element(element) -> str: + """Render an element from an array.""" + if element is None: + # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO + # "If the value written for an element is NULL (in any case variant), the + # element is taken to be NULL." + return "NULL" + elif isinstance(element, (list, tuple)): + # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-INPUT + # "Each val is either a constant of the array element type, or a subarray."
+ return render_array(element) + else: + # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO + # "When writing an array value you can use double quotes around any individual + # array element. [...] Empty strings and strings matching the word NULL must be + # quoted, too. To put a double quote or backslash in a quoted array element + # value, precede it with a backslash." + ret = value_as_pg_text(element) + return '"%s"' % ret.replace("\\", "\\\\").replace('"', '\\"') + + +def value_as_pg_text(data: Any) -> str: + """Render the given data in the postgresql text format. + + NULL values are handled **outside** of this function (either by + :func:`render_array_element`, or by :meth:`BaseDb.copy_to`.) + """ + if data is None: - return "" + raise ValueError("value_as_pg_text doesn't handle NULLs") + if isinstance(data, bytes): - return "\\x%s" % binascii.hexlify(data).decode("ascii") - elif isinstance(data, str): - return '"%s"' % data.replace('"', '""') + return "\\x%s" % data.hex() elif isinstance(data, datetime.datetime): - # We escape twice to make sure the string generated by - # isoformat gets escaped - return escape(data.isoformat()) + return data.isoformat() elif isinstance(data, dict): - return escape(json.dumps(data)) - elif isinstance(data, list): - return escape("{%s}" % ",".join(escape(d) for d in data)) + return json.dumps(data) + elif isinstance(data, (list, tuple)): + return render_array(data) elif isinstance(data, psycopg2.extras.Range): - # We escape twice here too, so that we make sure - # everything gets passed to copy properly - return escape( - "%s%s,%s%s" - % ( - "[" if data.lower_inc else "(", - "-infinity" if data.lower_inf else escape(data.lower), - "infinity" if data.upper_inf else escape(data.upper), - "]" if data.upper_inc else ")", - ) + return "empty" if data.isempty else "%s%s,%s%s" % ( + "[" if data.lower_inc else "(", + "" if data.lower_inf else value_as_pg_text(data.lower), + "" if data.upper_inf else value_as_pg_text(data.upper),
+ "]" if data.upper_inc else ")", ) elif isinstance(data, enum.IntEnum): - return escape(int(data)) + return str(int(data)) else: - # We don't escape here to make sure we pass literals properly return str(data) +def escape_copy_column(column: str) -> str: + """Escape the text representation of a column for use by COPY.""" + # From https://www.postgresql.org/docs/11/sql-copy.html + # File Formats > Text Format + # "Backslash characters (\) can be used in the COPY data to quote data characters + # that might otherwise be taken as row or column delimiters. In particular, the + # following characters must be preceded by a backslash if they appear as part of a + # column value: backslash itself, newline, carriage return, and the current + # delimiter character." + ret = ( + column.replace("\\", "\\\\") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + ) + + return ret + + def typecast_bytea(value, cur): if value is not None: data = psycopg2.BINARY(value, cur) @@ -148,20 +196,31 @@ raise def copy_to( - self, items, tblname, columns, cur=None, item_cb=None, default_values={} - ): - """Copy items' entries to table tblname with columns information. + self, + items: Iterable[Mapping[str, Any]], + tblname: str, + columns: Iterable[str], + cur: Optional[psycopg2.extensions.cursor] = None, + item_cb: Optional[Callable[[Any], Any]] = None, + default_values: Optional[Mapping[str, Any]] = None, + ) -> None: + """Run the COPY command to insert the `columns` of each element of `items` into + `tblname`. Args: - items (List[dict]): dictionaries of data to copy over tblname. - tblname (str): destination table's name. - columns ([str]): keys to access data in items and also the - column names in the destination table. - default_values (dict): dictionary of default values to use when - inserting entried int the tblname table. + items: dictionaries of data to copy into `tblname`. + tblname: name of the destination table. + columns: columns of the destination table; each is looked up as a key
+ in each element of `items`, falling back to `default_values`, then NULL. + default_values: dictionary of default values to use when inserting entries + in `tblname`. cur: a db cursor; if not given, a new cursor will be created. - item_cb (fn): optional function to apply to items's entry. + item_cb: optional callback, run on each element of `items`, when it is + copied. + """ + if default_values is None: + default_values = {} read_file, write_file = os.pipe() exc_info = None @@ -172,7 +231,7 @@ with open(read_file, "r") as f: try: cursor.copy_expert( - "COPY %s (%s) FROM STDIN CSV" % (tblname, ", ".join(columns)), f + "COPY %s (%s) FROM STDIN" % (tblname, ", ".join(columns)), f ) except Exception: # Tell the main thread about the exception @@ -183,6 +242,15 @@ try: with open(write_file, "w") as f: + # From https://www.postgresql.org/docs/11/sql-copy.html + # File Formats > Text Format + # "When the text format is used, the data read or written is a text file + # with one line per table row. Columns in a row are separated by the + # delimiter character." + # NULL + # "The default is \N (backslash-N) in text format." + # DELIMITER + # "The default is a tab character in text format."
for d in items: if item_cb is not None: item_cb(d) @@ -190,7 +258,10 @@ for k in columns: value = d.get(k, default_values.get(k)) try: - line.append(escape(value)) + if value is None: + line.append("\\N") + else: + line.append(escape_copy_column(value_as_pg_text(value))) except Exception as e: logger.error( "Could not escape value `%r` for column `%s`:" @@ -200,7 +271,7 @@ e, ) raise e from None - f.write(",".join(line)) + f.write("\t".join(line)) f.write("\n") finally: diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -3,13 +3,21 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from dataclasses import dataclass +import datetime +from enum import IntEnum import inspect import os.path +from string import printable import tempfile +from typing import Any +from typing_extensions import Protocol import unittest from unittest.mock import Mock, MagicMock +import uuid from hypothesis import strategies, given +from hypothesis.extra.pytz import timezones import psycopg2 import pytest @@ -23,31 +31,216 @@ ) -INIT_SQL = """ -create table test_table -( - i int, - txt text, - bytes bytea -); -""" - -db_rows = strategies.lists( - strategies.tuples( - strategies.integers(-2147483648, +2147483647), - strategies.text( - alphabet=strategies.characters( - blacklist_categories=["Cs"], # surrogates - blacklist_characters=[ - "\x00", # pgsql does not support the null codepoint - "\r", # pgsql normalizes those - ], +# workaround mypy bug https://github.com/python/mypy/issues/5485 +class Converter(Protocol): + def __call__(self, x: Any) -> Any: + ...
+ + +@dataclass +class Field: + name: str + """Column name""" + pg_type: str + """Type of the PostgreSQL column""" + example: Any + """Example value for the static tests""" + strategy: strategies.SearchStrategy + """Hypothesis strategy to generate these values""" + in_wrapper: Converter = lambda x: x + """Wrapper to convert this data type for the static tests""" + out_converter: Converter = lambda x: x + """Converter from the raw PostgreSQL column value to this data type""" + + +# Limit PostgreSQL integer values +pg_int = strategies.integers(-2147483648, +2147483647) + +pg_text = strategies.text( + alphabet=strategies.characters( + blacklist_categories=["Cs"], # surrogates + blacklist_characters=[ + "\x00", # pgsql does not support the null codepoint + "\r", # pgsql normalizes those + ], + ), +) + +pg_bytea = strategies.binary() + + +def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy: + """Generate a PostgreSQL bytea[]""" + return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size) + + +def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy: + """Generate a PostgreSQL bytea[][]. The inner lists must all have the same size.""" + return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap( + lambda n: strategies.lists( + pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size + ) + ) + + +def pg_tstz() -> strategies.SearchStrategy: + """Generate values that fit in a PostgreSQL timestamptz. + + Notes: + We're forbidding old datetimes, because until 1956, many timezones had + seconds in their "UTC offsets" (see + ), which is + not representable by PostgreSQL.
+ + """ + min_value = datetime.datetime(1960, 1, 1, 0, 0, 0) + return strategies.datetimes(min_value=min_value, timezones=timezones()) + + +def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy: + """Generate values representable as a PostgreSQL jsonb object (dict).""" + return strategies.dictionaries( + strategies.text(printable), + strategies.recursive( + # should use floats() instead of integers(), but PostgreSQL + # coerces large integers into floats, making the tests fail. We + # only store ints in our generated data anyway. + strategies.none() + | strategies.booleans() + | strategies.integers(-2147483648, +2147483647) + | strategies.text(printable), + lambda children: strategies.lists(children, max_size=max_size) + | strategies.dictionaries( + strategies.text(printable), children, max_size=max_size ), ), - strategies.binary(), + min_size=min_size, + max_size=max_size, ) + + +def tuple_2d_to_list_2d(v): + """Convert a 2D tuple to a 2D list""" + return [list(inner) for inner in v] + + +def list_2d_to_tuple_2d(v): + """Convert a 2D list to a 2D tuple""" + return tuple(tuple(inner) for inner in v) + + +class ExampleIntEnum(IntEnum): + foo = 1 + bar = 2 + + +def now(): + return datetime.datetime.now(tz=datetime.timezone.utc) + + +FIELDS = ( + Field("i", "int", 1, pg_int), + Field("txt", "text", "foo", pg_text), + Field("bytes", "bytea", b"bar", strategies.binary()), + Field( + "bytes_array", + "bytea[]", + [b"baz1", b"baz2"], + pg_bytea_a(min_size=0, max_size=5), + ), + Field( + "bytes_tuple", + "bytea[]", + (b"baz1", b"baz2"), + pg_bytea_a(min_size=0, max_size=5).map(tuple), + in_wrapper=list, + out_converter=tuple, + ), + Field( + "bytes_2d", + "bytea[][]", + [[b"quux1"], [b"quux2"]], + pg_bytea_a_a(min_size=0, max_size=5), + ), + Field( + "bytes_2d_tuple", + "bytea[][]", + ((b"quux1",), (b"quux2",)), + pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d), + in_wrapper=tuple_2d_to_list_2d, + out_converter=list_2d_to_tuple_2d, + ),
+ Field("ts", "timestamptz", now(), pg_tstz(),), + Field( + "dict", + "jsonb", + {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}}, + pg_jsonb(min_size=0, max_size=5), + in_wrapper=psycopg2.extras.Json, + ), + Field( + "intenum", + "int", + ExampleIntEnum.foo, + strategies.sampled_from(ExampleIntEnum), + in_wrapper=int, + out_converter=ExampleIntEnum, + ), + Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()), + Field( + "text_list", + "text[]", + # All the funky corner cases + ["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"], + strategies.lists(pg_text, min_size=0, max_size=5), + ), + Field( + "tstz_list", + "timestamptz[]", + [now(), now() + datetime.timedelta(days=1)], + strategies.lists(pg_tstz(), min_size=0, max_size=5), + ), + Field( + "tstz_range", + "tstzrange", + psycopg2.extras.DateTimeTZRange( + lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)", + ), + strategies.tuples( + # distinct sorted timestamptz bounds, as "[x,x)" would be an empty range + strategies.sets(pg_tstz(), min_size=2, max_size=2).map(sorted), + # and a set of bounds + strategies.sampled_from(["[]", "()", "[)", "(]"]), + ).map( + # and build the actual DateTimeTZRange object from these args + lambda args: psycopg2.extras.DateTimeTZRange( + lower=args[0][0], upper=args[0][1], bounds=args[1], + ) + ), + ), ) +INIT_SQL = "create table test_table (%s)" % ", ".join( + f"{field.name} {field.pg_type}" for field in FIELDS +) + +COLUMNS = tuple(field.name for field in FIELDS) +INSERT_SQL = "insert into test_table (%s) values (%s)" % ( + ", ".join(COLUMNS), + ", ".join("%s" for i in range(len(COLUMNS))), +) + +STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS) +EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS) + +db_rows = strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS))) + + +def convert_lines(cur): + return [ + tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur + ] + @pytest.mark.db def
test_connect(): @@ -55,10 +248,13 @@ try: db = BaseDb.connect("dbname=%s" % db_name) with db.cursor() as cur: + psycopg2.extras.register_default_jsonb(cur) cur.execute(INIT_SQL) - cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) + cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") - assert list(cur) == [(1, "foo", b"bar")] + output = convert_lines(cur) + assert len(output) == 1 + assert EXPECTED_ROW_OUT == output[0] finally: db_close(db.conn) db_destroy(db_name) @@ -84,35 +280,51 @@ def test_initialized(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) + psycopg2.extras.register_default_jsonb(cur) + cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") - self.assertEqual(list(cur), [(1, "foo", b"bar")]) + output = convert_lines(cur) + assert len(output) == 1 + assert EXPECTED_ROW_OUT == output[0] def test_reset_tables(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) + cur.execute(INSERT_SQL, STATIC_ROW_IN) self.reset_db_tables("test-db") cur.execute("select * from test_table;") - self.assertEqual(list(cur), []) + assert convert_lines(cur) == [] + + def test_copy_to_static(self): + items = [{field.name: field.example for field in FIELDS}] + self.db.copy_to(items, "test_table", COLUMNS) + + cur = self.db.cursor() + cur.execute("select * from test_table;") + output = convert_lines(cur) + assert len(output) == 1 + assert EXPECTED_ROW_OUT == output[0] @given(db_rows) def test_copy_to(self, data): - # the table is not reset between runs by hypothesis - self.reset_db_tables("test-db") + try: + # the table is not reset between runs by hypothesis + self.reset_db_tables("test-db") - items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] - self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) + items = [dict(zip(COLUMNS, item)) for item in data] +
self.db.copy_to(items, "test_table", COLUMNS) - cur = self.db.cursor() - cur.execute("select * from test_table;") - self.assertCountEqual(list(cur), data) + cur = self.db.cursor() + cur.execute("select * from test_table;") + assert convert_lines(cur) == data + finally: + self.db.conn.rollback() def test_copy_to_thread_exception(self): data = [(2 ** 65, "foo", b"bar")] - items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] + items = [dict(zip(COLUMNS, item)) for item in data] with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): - self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) + self.db.copy_to(items, "test_table", COLUMNS) def test_db_transaction(mocker):