diff --git a/swh/journal/pytest_plugin.py b/swh/journal/pytest_plugin.py
--- a/swh/journal/pytest_plugin.py
+++ b/swh/journal/pytest_plugin.py
@@ -6,7 +6,7 @@
 from collections import defaultdict
 import random
 import string
-from typing import Collection, Dict, Iterator, Optional
+from typing import Any, Collection, Dict, Iterator, Optional
 
 import attr
 from confluent_kafka import Consumer, KafkaException, Producer
@@ -17,6 +17,31 @@
 from swh.journal.tests.journal_data import TEST_OBJECTS
 
 
+def ensure_lists(value: Any) -> Any:
+    """Recursively convert every tuple in ``value`` into a list.
+
+    Dict values are converted too; dict keys are kept as-is, since a list
+    key would be unhashable.
+
+    >>> ensure_lists(["foo", 42])
+    ['foo', 42]
+    >>> ensure_lists(("foo", 42))
+    ['foo', 42]
+    >>> ensure_lists({"a": ["foo", 42]})
+    {'a': ['foo', 42]}
+    >>> ensure_lists({"a": ("foo", 42)})
+    {'a': ['foo', 42]}
+    >>> ensure_lists({("a", 1): ("foo", 42)})
+    {('a', 1): ['foo', 42]}
+    """
+    if isinstance(value, (tuple, list)):
+        return list(map(ensure_lists, value))
+    elif isinstance(value, dict):
+        return {key: ensure_lists(val) for key, val in value.items()}
+    else:
+        return value
+
+
 def consume_messages(consumer, kafka_prefix, expected_messages):
     """Consume expected_messages from the consumer;
     Sort them all into a consumed_objects dict"""
@@ -95,7 +120,7 @@
             expected_value = value.to_dict()
             if value.object_type in ("content", "skipped_content"):
                 expected_value.pop("ctime", None)
-            assert expected_value in received_values, (
+            assert ensure_lists(expected_value) in received_values, (
                 f"expected {object_type} value {value!r} is "
                 "absent from consumed messages"
             )
diff --git a/swh/journal/serializers.py b/swh/journal/serializers.py
--- a/swh/journal/serializers.py
+++ b/swh/journal/serializers.py
@@ -103,7 +103,7 @@
 
 def kafka_to_value(kafka_value: bytes) -> Any:
     """Deserialize some data stored in kafka"""
-    value = msgpack.unpackb(
+    return msgpack.unpackb(
         kafka_value,
         raw=False,
         object_hook=decode_types_bw,
@@ -111,13 +111,3 @@
         strict_map_key=False,
         timestamp=3,  # convert Timestamp in datetime objects (tz UTC)
     )
-    return ensure_tuples(value)
-
-
-def ensure_tuples(value: Any) -> Any:
-    if isinstance(value, (tuple, list)):
-        return tuple(map(ensure_tuples, value))
-    elif isinstance(value, dict):
-        return dict(ensure_tuples(list(value.items())))
-    else:
-        return value
diff --git a/swh/journal/tests/test_client.py b/swh/journal/tests/test_client.py
--- a/swh/journal/tests/test_client.py
+++ b/swh/journal/tests/test_client.py
@@ -34,7 +34,7 @@
     ),
     "synthetic": False,
     "metadata": None,
-    "parents": (),
+    "parents": [],
     "id": b"\x8b\xeb\xd1\x9d\x07\xe2\x1e0\xe2 \x91X\x8d\xbd\x1c\xa8\x86\xdeB\x0c",
 }
 