D3394.diff

diff --git a/requirements-test-db.txt b/requirements-test-db.txt
--- a/requirements-test-db.txt
+++ b/requirements-test-db.txt
@@ -1 +1,2 @@
pytest-postgresql
+typing-extensions
diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py
--- a/swh/core/db/__init__.py
+++ b/swh/core/db/__init__.py
@@ -3,7 +3,6 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-import binascii
import datetime
import enum
import json
@@ -11,6 +10,7 @@
import os
import sys
import threading
+from typing import Any, Callable, Iterable, Mapping, Optional
from contextlib import contextmanager
@@ -24,40 +24,88 @@
psycopg2.extras.register_uuid()
-def escape(data):
+def render_array(data) -> str:
+ """Render the data as a postgresql array"""
+ # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO
+ # "The external text representation of an array value consists of items that are
+ # interpreted according to the I/O conversion rules for the array's element type,
+ # plus decoration that indicates the array structure. The decoration consists of
+ # curly braces ({ and }) around the array value plus delimiter characters between
+ # adjacent items. The delimiter character is usually a comma (,)"
+ return "{%s}" % ",".join(render_array_element(e) for e in data)
+
+
+def render_array_element(element) -> str:
+ """Render an element from an array."""
+ if element is None:
+ # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO
+ # "If the value written for an element is NULL (in any case variant), the
+ # element is taken to be NULL."
+ return "NULL"
+ elif isinstance(element, (list, tuple)):
+ # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-INPUT
+ # "Each val is either a constant of the array element type, or a subarray."
+ return render_array(element)
+ else:
+ # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO
+ # "When writing an array value you can use double quotes around any individual
+ # array element. [...] Empty strings and strings matching the word NULL must be
+ # quoted, too. To put a double quote or backslash in a quoted array element
+ # value, precede it with a backslash."
+ ret = value_as_pg_text(element)
+ return '"%s"' % ret.replace("\\", "\\\\").replace('"', '\\"')
+
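A short illustrative sketch (not part of the patch) of what the array rendering above produces once the patch is applied, worked out from the code:

    from swh.core.db import render_array

    # each element is rendered in PostgreSQL text format, double-quoted, with
    # backslashes doubled; NULL elements stay unquoted
    assert render_array([b"a", None]) == '{"\\\\x61",NULL}'
    # sub-lists are rendered recursively as subarrays
    assert render_array([["x"], ["y"]]) == '{{"x"},{"y"}}'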
+
+def value_as_pg_text(data: Any) -> str:
+ """Render the given data in the postgresql text format.
+
+ NULL values are handled **outside** of this function (either by
+ :func:`render_array_element`, or by :meth:`BaseDb.copy_to`.)
+ """
+
if data is None:
- return ""
+ raise ValueError("value_as_pg_text doesn't handle NULLs")
+
if isinstance(data, bytes):
- return "\\x%s" % binascii.hexlify(data).decode("ascii")
- elif isinstance(data, str):
- return '"%s"' % data.replace('"', '""')
+ return "\\x%s" % data.hex()
elif isinstance(data, datetime.datetime):
- # We escape twice to make sure the string generated by
- # isoformat gets escaped
- return escape(data.isoformat())
+ return data.isoformat()
elif isinstance(data, dict):
- return escape(json.dumps(data))
- elif isinstance(data, list):
- return escape("{%s}" % ",".join(escape(d) for d in data))
+ return json.dumps(data)
+ elif isinstance(data, (list, tuple)):
+ return render_array(data)
elif isinstance(data, psycopg2.extras.Range):
- # We escape twice here too, so that we make sure
- # everything gets passed to copy properly
- return escape(
- "%s%s,%s%s"
- % (
- "[" if data.lower_inc else "(",
- "-infinity" if data.lower_inf else escape(data.lower),
- "infinity" if data.upper_inf else escape(data.upper),
- "]" if data.upper_inc else ")",
- )
+ return "%s%s,%s%s" % (
+ "[" if data.lower_inc else "(",
+ "-infinity" if data.lower_inf else value_as_pg_text(data.lower),
+ "infinity" if data.upper_inf else value_as_pg_text(data.upper),
+ "]" if data.upper_inc else ")",
)
elif isinstance(data, enum.IntEnum):
- return escape(int(data))
+ return str(int(data))
else:
- # We don't escape here to make sure we pass literals properly
return str(data)
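A few worked examples of value_as_pg_text, derived from the branches above (illustrative only, not part of the patch):

    import datetime
    from swh.core.db import value_as_pg_text

    assert value_as_pg_text(b"bar") == "\\x626172"   # bytes -> \x + hex digits
    assert value_as_pg_text({"a": 1}) == '{"a": 1}'  # dict -> JSON text
    assert (
        value_as_pg_text(datetime.datetime(2020, 7, 1, tzinfo=datetime.timezone.utc))
        == "2020-07-01T00:00:00+00:00"               # datetime -> isoformat
    )
    # NULLs are rejected here; callers render them themselves
    try:
        value_as_pg_text(None)
    except ValueError:
        pass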
+def escape_copy_column(column: str) -> str:
+ """Escape the text representation of a column for use by COPY."""
+ # From https://www.postgresql.org/docs/11/sql-copy.html
+ # File Formats > Text Format
+ # "Backslash characters (\) can be used in the COPY data to quote data characters
+ # that might otherwise be taken as row or column delimiters. In particular, the
+ # following characters must be preceded by a backslash if they appear as part of a
+ # column value: backslash itself, newline, carriage return, and the current
+ # delimiter character."
+ ret = (
+ column.replace("\\", "\\\\")
+ .replace("\n", "\\n")
+ .replace("\r", "\\r")
+ .replace("\t", "\\t")
+ )
+
+ return ret
+
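For example (a sketch, not part of the patch), a value containing a tab, a newline and a backslash comes out with each of them backslash-escaped, so COPY will not mistake them for delimiters:

    from swh.core.db import escape_copy_column

    assert escape_copy_column("a\tb\nc\\d") == "a\\tb\\nc\\\\d"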
+
def typecast_bytea(value, cur):
if value is not None:
data = psycopg2.BINARY(value, cur)
@@ -148,20 +196,31 @@
raise
def copy_to(
- self, items, tblname, columns, cur=None, item_cb=None, default_values={}
- ):
- """Copy items' entries to table tblname with columns information.
+ self,
+ items: Iterable[Mapping[str, Any]],
+ tblname: str,
+ columns: Iterable[str],
+ cur: Optional[psycopg2.extensions.cursor] = None,
+ item_cb: Optional[Callable[[Any], Any]] = None,
+ default_values: Optional[Mapping[str, Any]] = None,
+ ) -> None:
+ """Run the COPY command to insert the `columns` of each element of `items` into
+ `tblname`.
Args:
- items (List[dict]): dictionaries of data to copy over tblname.
- tblname (str): destination table's name.
- columns ([str]): keys to access data in items and also the
- column names in the destination table.
- default_values (dict): dictionary of default values to use when
- inserting entried int the tblname table.
+ items: dictionaries of data to copy into `tblname`.
+ tblname: name of the destination table.
+ columns: columns of the destination table. Elements of `items` must have
+ these set as keys.
+ default_values: dictionary of default values to use when inserting entries
+ in `tblname`.
cur: a db cursor; if not given, a new cursor will be created.
- item_cb (fn): optional function to apply to items's entry.
+            item_cb: optional callback, run on each element of `items` as it is
+                copied.
+
"""
+ if default_values is None:
+ default_values = {}
read_file, write_file = os.pipe()
exc_info = None
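A minimal usage sketch of the new signature (a hypothetical call, assuming an open BaseDb connection `db` and a table with columns i, txt and bytes as in the test suite below); keys missing from an item fall back to `default_values`:

    db.copy_to(
        items=[{"i": 1, "txt": "foo"}],        # "bytes" is missing here...
        tblname="test_table",
        columns=["i", "txt", "bytes"],
        default_values={"bytes": b"bar"},      # ...so this default value is used
    )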
@@ -172,7 +231,7 @@
with open(read_file, "r") as f:
try:
cursor.copy_expert(
- "COPY %s (%s) FROM STDIN CSV" % (tblname, ", ".join(columns)), f
+ "COPY %s (%s) FROM STDIN" % (tblname, ", ".join(columns)), f
)
except Exception:
# Tell the main thread about the exception
@@ -183,6 +242,15 @@
try:
with open(write_file, "w") as f:
+ # From https://www.postgresql.org/docs/11/sql-copy.html
+ # File Formats > Text Format
+ # "When the text format is used, the data read or written is a text file
+ # with one line per table row. Columns in a row are separated by the
+ # delimiter character."
+ # NULL
+ # "The default is \N (backslash-N) in text format."
+ # DELIMITER
+ # "The default is a tab character in text format."
for d in items:
if item_cb is not None:
item_cb(d)
@@ -190,7 +258,10 @@
for k in columns:
value = d.get(k, default_values.get(k))
try:
- line.append(escape(value))
+ if value is None:
+ line.append("\\N")
+ else:
+ line.append(escape_copy_column(value_as_pg_text(value)))
except Exception as e:
logger.error(
"Could not escape value `%r` for column `%s`:"
@@ -200,7 +271,7 @@
e,
)
raise e from None
- f.write(",".join(line))
+ f.write("\t".join(line))
f.write("\n")
finally:
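To make the serialization concrete, here is roughly what the loop above writes for one item (an illustrative sketch mirroring the loop body, not part of the patch): columns are tab-delimited, the NULL column becomes \N, and the embedded tab is escaped.

    from swh.core.db import escape_copy_column, value_as_pg_text

    d = {"i": 1, "txt": "a\tb", "bytes": None}
    line = [
        "\\N" if d[k] is None else escape_copy_column(value_as_pg_text(d[k]))
        for k in ("i", "txt", "bytes")
    ]
    assert "\t".join(line) + "\n" == "1\ta\\tb\t\\N\n"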
diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -3,13 +3,21 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from dataclasses import dataclass
+import datetime
+from enum import IntEnum
import inspect
import os.path
+from string import printable
import tempfile
+from typing import Any
+from typing_extensions import Protocol
import unittest
from unittest.mock import Mock, MagicMock
+import uuid
from hypothesis import strategies, given
+from hypothesis.extra.pytz import timezones
import psycopg2
import pytest
@@ -23,31 +31,216 @@
)
-INIT_SQL = """
-create table test_table
-(
- i int,
- txt text,
- bytes bytea
-);
-"""
-
-db_rows = strategies.lists(
- strategies.tuples(
- strategies.integers(-2147483648, +2147483647),
- strategies.text(
- alphabet=strategies.characters(
- blacklist_categories=["Cs"], # surrogates
- blacklist_characters=[
- "\x00", # pgsql does not support the null codepoint
- "\r", # pgsql normalizes those
- ],
+# workaround mypy bug https://github.com/python/mypy/issues/5485
+class Converter(Protocol):
+ def __call__(self, x: Any) -> Any:
+ ...
+
+
+@dataclass
+class Field:
+ name: str
+ """Column name"""
+ pg_type: str
+ """Type of the PostgreSQL column"""
+ example: Any
+ """Example value for the static tests"""
+ strategy: strategies.SearchStrategy
+ """Hypothesis strategy to generate these values"""
+ in_wrapper: Converter = lambda x: x
+ """Wrapper to convert this data type for the static tests"""
+ out_converter: Converter = lambda x: x
+ """Converter from the raw PostgreSQL column value to this data type"""
+
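For instance, the bytea[] column backed by Python tuples (taken from the FIELDS tuple defined further down) needs both converters, since psycopg2 adapts lists, not tuples, to arrays:

    Field(
        "bytes_tuple",
        "bytea[]",
        (b"baz1", b"baz2"),
        pg_bytea_a(min_size=0, max_size=5).map(tuple),
        in_wrapper=list,      # hand psycopg2 a list for the static insert
        out_converter=tuple,  # turn the list read back into a tuple again
    )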
+
+# Limit PostgreSQL integer values
+pg_int = strategies.integers(-2147483648, +2147483647)
+
+pg_text = strategies.text(
+ alphabet=strategies.characters(
+ blacklist_categories=["Cs"], # surrogates
+ blacklist_characters=[
+ "\x00", # pgsql does not support the null codepoint
+ "\r", # pgsql normalizes those
+ ],
+ ),
+)
+
+pg_bytea = strategies.binary()
+
+
+def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy:
+ """Generate a PostgreSQL bytea[]"""
+ return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size)
+
+
+def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy:
+ """Generate a PostgreSQL bytea[][]. The inner lists must all have the same size."""
+ return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap(
+ lambda n: strategies.lists(
+ pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size
+ )
+ )
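An illustrative draw (the values themselves are random; only the shape matters): every inner list shares the same length, so the generated value maps onto a rectangular bytea[][]:

    value = pg_bytea_a_a(min_size=0, max_size=5).example()
    # e.g. [[b"a", b""], [b"\x00\xff", b"bc"]] -- both inner lists have length 2
    assert len({len(inner) for inner in value}) <= 1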
+
+
+def pg_tstz() -> strategies.SearchStrategy:
+ """Generate values that fit in a PostgreSQL timestamptz.
+
+ Notes:
+ We're forbidding old datetimes, because until 1956, many timezones had
+ seconds in their "UTC offsets" (see
+        <https://en.wikipedia.org/wiki/Time_zone#Worldwide_time_zones>), which is
+ not representable by PostgreSQL.
+
+ """
+ min_value = datetime.datetime(1960, 1, 1, 0, 0, 0)
+ return strategies.datetimes(min_value=min_value, timezones=timezones())
+
+
+def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy:
+ """Generate values representable as a PostgreSQL jsonb object (dict)."""
+ return strategies.dictionaries(
+ strategies.text(printable),
+ strategies.recursive(
+ # should use floats() instead of integers(), but PostgreSQL
+ # coerces large integers into floats, making the tests fail. We
+ # only store ints in our generated data anyway.
+ strategies.none()
+ | strategies.booleans()
+ | strategies.integers(-2147483648, +2147483647)
+ | strategies.text(printable),
+ lambda children: strategies.lists(children, max_size=max_size)
+ | strategies.dictionaries(
+ strategies.text(printable), children, max_size=max_size
),
),
- strategies.binary(),
+ min_size=min_size,
+ max_size=max_size,
)
+
+
+def tuple_2d_to_list_2d(v):
+ """Convert a 2D tuple to a 2D list"""
+ return [list(inner) for inner in v]
+
+
+def list_2d_to_tuple_2d(v):
+ """Convert a 2D list to a 2D tuple"""
+ return tuple(tuple(inner) for inner in v)
+
+
+class TestIntEnum(IntEnum):
+ foo = 1
+ bar = 2
+
+
+def now():
+ return datetime.datetime.now(tz=datetime.timezone.utc)
+
+
+FIELDS = (
+ Field("i", "int", 1, pg_int),
+ Field("txt", "text", "foo", pg_text),
+ Field("bytes", "bytea", b"bar", strategies.binary()),
+ Field(
+ "bytes_array",
+ "bytea[]",
+ [b"baz1", b"baz2"],
+ pg_bytea_a(min_size=0, max_size=5),
+ ),
+ Field(
+ "bytes_tuple",
+ "bytea[]",
+ (b"baz1", b"baz2"),
+ pg_bytea_a(min_size=0, max_size=5).map(tuple),
+ in_wrapper=list,
+ out_converter=tuple,
+ ),
+ Field(
+ "bytes_2d",
+ "bytea[][]",
+ [[b"quux1"], [b"quux2"]],
+ pg_bytea_a_a(min_size=0, max_size=5),
+ ),
+ Field(
+ "bytes_2d_tuple",
+ "bytea[][]",
+ ((b"quux1",), (b"quux2",)),
+ pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d),
+ in_wrapper=tuple_2d_to_list_2d,
+ out_converter=list_2d_to_tuple_2d,
+ ),
+ Field("ts", "timestamptz", now(), pg_tstz(),),
+ Field(
+ "dict",
+ "jsonb",
+ {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}},
+ pg_jsonb(min_size=0, max_size=5),
+ in_wrapper=psycopg2.extras.Json,
+ ),
+ Field(
+ "intenum",
+ "int",
+ TestIntEnum.foo,
+ strategies.sampled_from(TestIntEnum),
+ in_wrapper=int,
+ out_converter=TestIntEnum,
+ ),
+ Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()),
+ Field(
+ "text_list",
+ "text[]",
+ # All the funky corner cases
+ ["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"],
+ strategies.lists(pg_text, min_size=0, max_size=5),
+ ),
+ Field(
+ "tstz_list",
+ "timestamptz[]",
+ [now(), now() + datetime.timedelta(days=1)],
+ strategies.lists(pg_tstz(), min_size=0, max_size=5),
+ ),
+ Field(
+ "tstz_range",
+ "tstzrange",
+ psycopg2.extras.DateTimeTZRange(
+ lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)",
+ ),
+ strategies.tuples(
+ # generate two sorted timestamptzs for use as bounds
+ strategies.tuples(pg_tstz(), pg_tstz()).map(sorted),
+ # and a set of bounds
+ strategies.sampled_from(["[]", "()", "[)", "(]"]),
+ ).map(
+ # and build the actual DateTimeTZRange object from these args
+ lambda args: psycopg2.extras.DateTimeTZRange(
+ lower=args[0][0], upper=args[0][1], bounds=args[1],
+ )
+ ),
+ ),
)
+INIT_SQL = "create table test_table (%s)" % ", ".join(
+ f"{field.name} {field.pg_type}" for field in FIELDS
+)
+
+COLUMNS = tuple(field.name for field in FIELDS)
+INSERT_SQL = "insert into test_table (%s) values (%s)" % (
+ ", ".join(COLUMNS),
+ ", ".join("%s" for i in range(len(COLUMNS))),
+)
+
+STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS)
+EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS)
+
+db_rows = strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS)))
+
+
+def convert_lines(cur):
+ return [
+ tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur
+ ]
+
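For reference, with the FIELDS above these evaluate to statements of the following shape (only the first three columns are spelled out here; the rest follow the same pattern):

    # INIT_SQL   == "create table test_table (i int, txt text, bytes bytea, ...)"
    # INSERT_SQL == "insert into test_table (i, txt, bytes, ...) values (%s, %s, %s, ...)"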
@pytest.mark.db
def test_connect():
@@ -55,10 +248,13 @@
try:
db = BaseDb.connect("dbname=%s" % db_name)
with db.cursor() as cur:
+ psycopg2.extras.register_default_jsonb(cur)
cur.execute(INIT_SQL)
- cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar"))
+ cur.execute(INSERT_SQL, STATIC_ROW_IN)
cur.execute("select * from test_table;")
- assert list(cur) == [(1, "foo", b"bar")]
+ output = convert_lines(cur)
+ assert len(output) == 1
+ assert EXPECTED_ROW_OUT == output[0]
finally:
db_close(db.conn)
db_destroy(db_name)
@@ -84,35 +280,51 @@
def test_initialized(self):
cur = self.db.cursor()
- cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar"))
+ psycopg2.extras.register_default_jsonb(cur)
+ cur.execute(INSERT_SQL, STATIC_ROW_IN)
cur.execute("select * from test_table;")
- self.assertEqual(list(cur), [(1, "foo", b"bar")])
+ output = convert_lines(cur)
+ assert len(output) == 1
+ assert EXPECTED_ROW_OUT == output[0]
def test_reset_tables(self):
cur = self.db.cursor()
- cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar"))
+ cur.execute(INSERT_SQL, STATIC_ROW_IN)
self.reset_db_tables("test-db")
cur.execute("select * from test_table;")
- self.assertEqual(list(cur), [])
+ assert convert_lines(cur) == []
+
+ def test_copy_to_static(self):
+ items = [{field.name: field.example for field in FIELDS}]
+ self.db.copy_to(items, "test_table", COLUMNS)
+
+ cur = self.db.cursor()
+ cur.execute("select * from test_table;")
+ output = convert_lines(cur)
+ assert len(output) == 1
+ assert EXPECTED_ROW_OUT == output[0]
@given(db_rows)
def test_copy_to(self, data):
- # the table is not reset between runs by hypothesis
- self.reset_db_tables("test-db")
+ try:
+ # the table is not reset between runs by hypothesis
+ self.reset_db_tables("test-db")
- items = [dict(zip(["i", "txt", "bytes"], item)) for item in data]
- self.db.copy_to(items, "test_table", ["i", "txt", "bytes"])
+ items = [dict(zip(COLUMNS, item)) for item in data]
+ self.db.copy_to(items, "test_table", COLUMNS)
- cur = self.db.cursor()
- cur.execute("select * from test_table;")
- self.assertCountEqual(list(cur), data)
+ cur = self.db.cursor()
+ cur.execute("select * from test_table;")
+ assert convert_lines(cur) == data
+ finally:
+ self.db.conn.rollback()
def test_copy_to_thread_exception(self):
data = [(2 ** 65, "foo", b"bar")]
- items = [dict(zip(["i", "txt", "bytes"], item)) for item in data]
+ items = [dict(zip(COLUMNS, item)) for item in data]
with self.assertRaises(psycopg2.errors.NumericValueOutOfRange):
- self.db.copy_to(items, "test_table", ["i", "txt", "bytes"])
+ self.db.copy_to(items, "test_table", COLUMNS)
def test_db_transaction(mocker):
