D3394.diff: Improve test coverage and type coverage for copy_to
diff --git a/requirements-test-db.txt b/requirements-test-db.txt
--- a/requirements-test-db.txt
+++ b/requirements-test-db.txt
@@ -1 +1,2 @@
pytest-postgresql
+typing-extensions
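
The typing-extensions requirement backs the Protocol class used by the new test code further down ("workaround mypy bug" for typing callable dataclass fields); Protocol only joined the standard library's typing module in Python 3.8. A minimal sketch of the pattern, mirroring the Converter protocol defined later in this diff; the apply helper and its assertion are illustrative, not from the patch:

    from typing import Any

    from typing_extensions import Protocol


    # Structural type for "any callable taking a single argument", as used to
    # annotate the converter fields of the Field dataclass in the test file.
    class Converter(Protocol):
        def __call__(self, x: Any) -> Any:
            ...


    def apply(conv: Converter, value: Any) -> Any:
        # Illustrative helper (not in the diff): any one-argument callable,
        # including a lambda, satisfies the protocol.
        return conv(value)


    assert apply(lambda x: x + 1, 1) == 2
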
diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py
--- a/swh/core/db/__init__.py
+++ b/swh/core/db/__init__.py
@@ -3,7 +3,6 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-import binascii
import datetime
import enum
import json
@@ -11,6 +10,7 @@
import os
import sys
import threading
+from typing import Any, Callable, Iterable, Mapping, Optional
from contextlib import contextmanager
@@ -24,40 +24,88 @@
psycopg2.extras.register_uuid()
-def escape(data):
+def render_array(data) -> str:
+ """Render the data as a postgresql array"""
+ # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO
+ # "The external text representation of an array value consists of items that are
+ # interpreted according to the I/O conversion rules for the array's element type,
+ # plus decoration that indicates the array structure. The decoration consists of
+ # curly braces ({ and }) around the array value plus delimiter characters between
+ # adjacent items. The delimiter character is usually a comma (,)"
+ return "{%s}" % ",".join(render_array_element(e) for e in data)
+
+
+def render_array_element(element) -> str:
+ """Render an element from an array."""
+ if element is None:
+ # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO
+ # "If the value written for an element is NULL (in any case variant), the
+ # element is taken to be NULL."
+ return "NULL"
+ elif isinstance(element, (list, tuple)):
+ # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-INPUT
+ # "Each val is either a constant of the array element type, or a subarray."
+ return render_array(element)
+ else:
+ # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO
+ # "When writing an array value you can use double quotes around any individual
+ # array element. [...] Empty strings and strings matching the word NULL must be
+ # quoted, too. To put a double quote or backslash in a quoted array element
+ # value, precede it with a backslash."
+ ret = value_as_pg_text(element)
+ return '"%s"' % ret.replace("\\", "\\\\").replace('"', '\\"')
+
+
+def value_as_pg_text(data: Any) -> str:
+ """Render the given data in the postgresql text format.
+
+ NULL values are handled **outside** of this function (either by
+ :func:`render_array_element`, or by :meth:`BaseDb.copy_to`.)
+ """
+
if data is None:
- return ""
+ raise ValueError("value_as_pg_text doesn't handle NULLs")
+
if isinstance(data, bytes):
- return "\\x%s" % binascii.hexlify(data).decode("ascii")
- elif isinstance(data, str):
- return '"%s"' % data.replace('"', '""')
+ return "\\x%s" % data.hex()
elif isinstance(data, datetime.datetime):
- # We escape twice to make sure the string generated by
- # isoformat gets escaped
- return escape(data.isoformat())
+ return data.isoformat()
elif isinstance(data, dict):
- return escape(json.dumps(data))
- elif isinstance(data, list):
- return escape("{%s}" % ",".join(escape(d) for d in data))
+ return json.dumps(data)
+ elif isinstance(data, (list, tuple)):
+ return render_array(data)
elif isinstance(data, psycopg2.extras.Range):
- # We escape twice here too, so that we make sure
- # everything gets passed to copy properly
- return escape(
- "%s%s,%s%s"
- % (
- "[" if data.lower_inc else "(",
- "-infinity" if data.lower_inf else escape(data.lower),
- "infinity" if data.upper_inf else escape(data.upper),
- "]" if data.upper_inc else ")",
- )
+ return "%s%s,%s%s" % (
+ "[" if data.lower_inc else "(",
+ "-infinity" if data.lower_inf else value_as_pg_text(data.lower),
+ "infinity" if data.upper_inf else value_as_pg_text(data.upper),
+ "]" if data.upper_inc else ")",
)
elif isinstance(data, enum.IntEnum):
- return escape(int(data))
+ return str(int(data))
else:
- # We don't escape here to make sure we pass literals properly
return str(data)
+def escape_copy_column(column: str) -> str:
+ """Escape the text representation of a column for use by COPY."""
+ # From https://www.postgresql.org/docs/11/sql-copy.html
+ # File Formats > Text Format
+ # "Backslash characters (\) can be used in the COPY data to quote data characters
+ # that might otherwise be taken as row or column delimiters. In particular, the
+ # following characters must be preceded by a backslash if they appear as part of a
+ # column value: backslash itself, newline, carriage return, and the current
+ # delimiter character."
+ ret = (
+ column.replace("\\", "\\\\")
+ .replace("\n", "\\n")
+ .replace("\r", "\\r")
+ .replace("\t", "\\t")
+ )
+
+ return ret
+
+
def typecast_bytea(value, cur):
if value is not None:
data = psycopg2.BINARY(value, cur)
@@ -148,20 +196,31 @@
raise
def copy_to(
- self, items, tblname, columns, cur=None, item_cb=None, default_values={}
- ):
- """Copy items' entries to table tblname with columns information.
+ self,
+ items: Iterable[Mapping[str, Any]],
+ tblname: str,
+ columns: Iterable[str],
+ cur: Optional[psycopg2.extensions.cursor] = None,
+ item_cb: Optional[Callable[[Any], Any]] = None,
+ default_values: Optional[Mapping[str, Any]] = None,
+ ) -> None:
+ """Run the COPY command to insert the `columns` of each element of `items` into
+ `tblname`.
Args:
- items (List[dict]): dictionaries of data to copy over tblname.
- tblname (str): destination table's name.
- columns ([str]): keys to access data in items and also the
- column names in the destination table.
- default_values (dict): dictionary of default values to use when
- inserting entried int the tblname table.
+ items: dictionaries of data to copy into `tblname`.
+ tblname: name of the destination table.
+ columns: columns of the destination table. Elements of `items` must have
+ these set as keys.
+ default_values: dictionary of default values to use when inserting entries
+ in `tblname`.
cur: a db cursor; if not given, a new cursor will be created.
- item_cb (fn): optional function to apply to items's entry.
+ item_cb: optional callback, run on each element of `items`, when it is
+ copied.
+
"""
+ if default_values is None:
+ default_values = {}
read_file, write_file = os.pipe()
exc_info = None
@@ -172,7 +231,7 @@
with open(read_file, "r") as f:
try:
cursor.copy_expert(
- "COPY %s (%s) FROM STDIN CSV" % (tblname, ", ".join(columns)), f
+ "COPY %s (%s) FROM STDIN" % (tblname, ", ".join(columns)), f
)
except Exception:
# Tell the main thread about the exception
@@ -183,6 +242,15 @@
try:
with open(write_file, "w") as f:
+ # From https://www.postgresql.org/docs/11/sql-copy.html
+ # File Formats > Text Format
+ # "When the text format is used, the data read or written is a text file
+ # with one line per table row. Columns in a row are separated by the
+ # delimiter character."
+ # NULL
+ # "The default is \N (backslash-N) in text format."
+ # DELIMITER
+ # "The default is a tab character in text format."
for d in items:
if item_cb is not None:
item_cb(d)
@@ -190,7 +258,10 @@
for k in columns:
value = d.get(k, default_values.get(k))
try:
- line.append(escape(value))
+ if value is None:
+ line.append("\\N")
+ else:
+ line.append(escape_copy_column(value_as_pg_text(value)))
except Exception as e:
logger.error(
"Could not escape value `%r` for column `%s`:"
@@ -200,7 +271,7 @@
e,
)
raise e from None
- f.write(",".join(line))
+ f.write("\t".join(line))
f.write("\n")
finally:
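
For context, a minimal sketch of how the helpers introduced above compose into one COPY text-format row, the same way copy_to now builds each line; the sample row, its column names, and the printed output are illustrative assumptions, not part of the patch:

    # Illustrative sketch only: build one COPY text-format row from the helpers
    # added in this diff. Requires swh.core with this change applied.
    import datetime

    from swh.core.db import escape_copy_column, value_as_pg_text

    # Made-up sample row; the column names are not taken from the diff.
    row = {
        "txt": 'tab\there, quote "',
        "data": b"\x00\xff",
        "ts": datetime.datetime(2020, 7, 1, tzinfo=datetime.timezone.utc),
        "maybe": None,
        "tags": ["a", None, "needs,quoting"],
    }

    cells = []
    for column in ["txt", "data", "ts", "maybe", "tags"]:
        value = row[column]
        if value is None:
            # NULL columns are written as \N, the COPY text-format default
            cells.append("\\N")
        else:
            # render to PostgreSQL text format, then escape for COPY
            cells.append(escape_copy_column(value_as_pg_text(value)))

    # columns are tab-separated, one table row per line
    print("\t".join(cells))
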
diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -3,13 +3,21 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from dataclasses import dataclass
+import datetime
+from enum import IntEnum
import inspect
import os.path
+from string import printable
import tempfile
+from typing import Any
+from typing_extensions import Protocol
import unittest
from unittest.mock import Mock, MagicMock
+import uuid
from hypothesis import strategies, given
+from hypothesis.extra.pytz import timezones
import psycopg2
import pytest
@@ -23,31 +31,216 @@
)
-INIT_SQL = """
-create table test_table
-(
- i int,
- txt text,
- bytes bytea
-);
-"""
-
-db_rows = strategies.lists(
- strategies.tuples(
- strategies.integers(-2147483648, +2147483647),
- strategies.text(
- alphabet=strategies.characters(
- blacklist_categories=["Cs"], # surrogates
- blacklist_characters=[
- "\x00", # pgsql does not support the null codepoint
- "\r", # pgsql normalizes those
- ],
+# workaround mypy bug https://github.com/python/mypy/issues/5485
+class Converter(Protocol):
+ def __call__(self, x: Any) -> Any:
+ ...
+
+
+@dataclass
+class Field:
+ name: str
+ """Column name"""
+ pg_type: str
+ """Type of the PostgreSQL column"""
+ example: Any
+ """Example value for the static tests"""
+ strategy: strategies.SearchStrategy
+ """Hypothesis strategy to generate these values"""
+ in_wrapper: Converter = lambda x: x
+ """Wrapper to convert this data type for the static tests"""
+ out_converter: Converter = lambda x: x
+ """Converter from the raw PostgreSQL column value to this data type"""
+
+
+# Limit PostgreSQL integer values
+pg_int = strategies.integers(-2147483648, +2147483647)
+
+pg_text = strategies.text(
+ alphabet=strategies.characters(
+ blacklist_categories=["Cs"], # surrogates
+ blacklist_characters=[
+ "\x00", # pgsql does not support the null codepoint
+ "\r", # pgsql normalizes those
+ ],
+ ),
+)
+
+pg_bytea = strategies.binary()
+
+
+def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy:
+ """Generate a PostgreSQL bytea[]"""
+ return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size)
+
+
+def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy:
+ """Generate a PostgreSQL bytea[][]. The inner lists must all have the same size."""
+ return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap(
+ lambda n: strategies.lists(
+ pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size
+ )
+ )
+
+
+def pg_tstz() -> strategies.SearchStrategy:
+ """Generate values that fit in a PostgreSQL timestamptz.
+
+ Notes:
+ We're forbidding old datetimes, because until 1956, many timezones had
+ seconds in their "UTC offsets" (see
+        <https://en.wikipedia.org/wiki/Time_zone#Worldwide_time_zones>), which is
+ not representable by PostgreSQL.
+
+ """
+ min_value = datetime.datetime(1960, 1, 1, 0, 0, 0)
+ return strategies.datetimes(min_value=min_value, timezones=timezones())
+
+
+def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy:
+ """Generate values representable as a PostgreSQL jsonb object (dict)."""
+ return strategies.dictionaries(
+ strategies.text(printable),
+ strategies.recursive(
+ # should use floats() instead of integers(), but PostgreSQL
+ # coerces large integers into floats, making the tests fail. We
+ # only store ints in our generated data anyway.
+ strategies.none()
+ | strategies.booleans()
+ | strategies.integers(-2147483648, +2147483647)
+ | strategies.text(printable),
+ lambda children: strategies.lists(children, max_size=max_size)
+ | strategies.dictionaries(
+ strategies.text(printable), children, max_size=max_size
),
),
- strategies.binary(),
+ min_size=min_size,
+ max_size=max_size,
)
+
+
+def tuple_2d_to_list_2d(v):
+ """Convert a 2D tuple to a 2D list"""
+ return [list(inner) for inner in v]
+
+
+def list_2d_to_tuple_2d(v):
+ """Convert a 2D list to a 2D tuple"""
+ return tuple(tuple(inner) for inner in v)
+
+
+class TestIntEnum(IntEnum):
+ foo = 1
+ bar = 2
+
+
+def now():
+ return datetime.datetime.now(tz=datetime.timezone.utc)
+
+
+FIELDS = (
+ Field("i", "int", 1, pg_int),
+ Field("txt", "text", "foo", pg_text),
+ Field("bytes", "bytea", b"bar", strategies.binary()),
+ Field(
+ "bytes_array",
+ "bytea[]",
+ [b"baz1", b"baz2"],
+ pg_bytea_a(min_size=0, max_size=5),
+ ),
+ Field(
+ "bytes_tuple",
+ "bytea[]",
+ (b"baz1", b"baz2"),
+ pg_bytea_a(min_size=0, max_size=5).map(tuple),
+ in_wrapper=list,
+ out_converter=tuple,
+ ),
+ Field(
+ "bytes_2d",
+ "bytea[][]",
+ [[b"quux1"], [b"quux2"]],
+ pg_bytea_a_a(min_size=0, max_size=5),
+ ),
+ Field(
+ "bytes_2d_tuple",
+ "bytea[][]",
+ ((b"quux1",), (b"quux2",)),
+ pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d),
+ in_wrapper=tuple_2d_to_list_2d,
+ out_converter=list_2d_to_tuple_2d,
+ ),
+ Field("ts", "timestamptz", now(), pg_tstz(),),
+ Field(
+ "dict",
+ "jsonb",
+ {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}},
+ pg_jsonb(min_size=0, max_size=5),
+ in_wrapper=psycopg2.extras.Json,
+ ),
+ Field(
+ "intenum",
+ "int",
+ TestIntEnum.foo,
+ strategies.sampled_from(TestIntEnum),
+ in_wrapper=int,
+ out_converter=TestIntEnum,
+ ),
+ Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()),
+ Field(
+ "text_list",
+ "text[]",
+ # All the funky corner cases
+ ["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"],
+ strategies.lists(pg_text, min_size=0, max_size=5),
+ ),
+ Field(
+ "tstz_list",
+ "timestamptz[]",
+ [now(), now() + datetime.timedelta(days=1)],
+ strategies.lists(pg_tstz(), min_size=0, max_size=5),
+ ),
+ Field(
+ "tstz_range",
+ "tstzrange",
+ psycopg2.extras.DateTimeTZRange(
+ lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)",
+ ),
+ strategies.tuples(
+ # generate two sorted timestamptzs for use as bounds
+ strategies.tuples(pg_tstz(), pg_tstz()).map(sorted),
+ # and a set of bounds
+ strategies.sampled_from(["[]", "()", "[)", "(]"]),
+ ).map(
+ # and build the actual DateTimeTZRange object from these args
+ lambda args: psycopg2.extras.DateTimeTZRange(
+ lower=args[0][0], upper=args[0][1], bounds=args[1],
+ )
+ ),
+ ),
)
+INIT_SQL = "create table test_table (%s)" % ", ".join(
+ f"{field.name} {field.pg_type}" for field in FIELDS
+)
+
+COLUMNS = tuple(field.name for field in FIELDS)
+INSERT_SQL = "insert into test_table (%s) values (%s)" % (
+ ", ".join(COLUMNS),
+ ", ".join("%s" for i in range(len(COLUMNS))),
+)
+
+STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS)
+EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS)
+
+db_rows = strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS)))
+
+
+def convert_lines(cur):
+ return [
+ tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur
+ ]
+
@pytest.mark.db
def test_connect():
@@ -55,10 +248,13 @@
try:
db = BaseDb.connect("dbname=%s" % db_name)
with db.cursor() as cur:
+ psycopg2.extras.register_default_jsonb(cur)
cur.execute(INIT_SQL)
- cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar"))
+ cur.execute(INSERT_SQL, STATIC_ROW_IN)
cur.execute("select * from test_table;")
- assert list(cur) == [(1, "foo", b"bar")]
+ output = convert_lines(cur)
+ assert len(output) == 1
+ assert EXPECTED_ROW_OUT == output[0]
finally:
db_close(db.conn)
db_destroy(db_name)
@@ -84,35 +280,51 @@
def test_initialized(self):
cur = self.db.cursor()
- cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar"))
+ psycopg2.extras.register_default_jsonb(cur)
+ cur.execute(INSERT_SQL, STATIC_ROW_IN)
cur.execute("select * from test_table;")
- self.assertEqual(list(cur), [(1, "foo", b"bar")])
+ output = convert_lines(cur)
+ assert len(output) == 1
+ assert EXPECTED_ROW_OUT == output[0]
def test_reset_tables(self):
cur = self.db.cursor()
- cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar"))
+ cur.execute(INSERT_SQL, STATIC_ROW_IN)
self.reset_db_tables("test-db")
cur.execute("select * from test_table;")
- self.assertEqual(list(cur), [])
+ assert convert_lines(cur) == []
+
+ def test_copy_to_static(self):
+ items = [{field.name: field.example for field in FIELDS}]
+ self.db.copy_to(items, "test_table", COLUMNS)
+
+ cur = self.db.cursor()
+ cur.execute("select * from test_table;")
+ output = convert_lines(cur)
+ assert len(output) == 1
+ assert EXPECTED_ROW_OUT == output[0]
@given(db_rows)
def test_copy_to(self, data):
- # the table is not reset between runs by hypothesis
- self.reset_db_tables("test-db")
+ try:
+ # the table is not reset between runs by hypothesis
+ self.reset_db_tables("test-db")
- items = [dict(zip(["i", "txt", "bytes"], item)) for item in data]
- self.db.copy_to(items, "test_table", ["i", "txt", "bytes"])
+ items = [dict(zip(COLUMNS, item)) for item in data]
+ self.db.copy_to(items, "test_table", COLUMNS)
- cur = self.db.cursor()
- cur.execute("select * from test_table;")
- self.assertCountEqual(list(cur), data)
+ cur = self.db.cursor()
+ cur.execute("select * from test_table;")
+ assert convert_lines(cur) == data
+ finally:
+ self.db.conn.rollback()
def test_copy_to_thread_exception(self):
data = [(2 ** 65, "foo", b"bar")]
- items = [dict(zip(["i", "txt", "bytes"], item)) for item in data]
+ items = [dict(zip(COLUMNS, item)) for item in data]
with self.assertRaises(psycopg2.errors.NumericValueOutOfRange):
- self.db.copy_to(items, "test_table", ["i", "txt", "bytes"])
+ self.db.copy_to(items, "test_table", COLUMNS)
def test_db_transaction(mocker):
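
A short usage sketch of BaseDb.copy_to after this change, assuming a reachable database named test-db and the i/txt/bytes columns of test_table; the connection string and sample rows are placeholders, not taken from the diff:

    # Hypothetical usage sketch: the DSN and data below are placeholders.
    from swh.core.db import BaseDb

    db = BaseDb.connect("dbname=test-db")  # placeholder connection string
    items = [
        {"i": 1, "txt": "foo", "bytes": b"bar"},
        {"i": 2, "txt": None, "bytes": b"baz"},  # None is emitted as \N, i.e. SQL NULL
    ]
    # Only the named columns are copied; each item must have them as keys.
    db.copy_to(items, "test_table", ["i", "txt", "bytes"])
    db.conn.commit()
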