diff --git a/swh/journal/pytest_plugin.py b/swh/journal/pytest_plugin.py --- a/swh/journal/pytest_plugin.py +++ b/swh/journal/pytest_plugin.py @@ -9,13 +9,18 @@ from typing import Collection, Dict, Iterator, Optional from collections import defaultdict +import attr import pytest from confluent_kafka import Consumer, KafkaException, Producer from confluent_kafka.admin import AdminClient from swh.journal.serializers import object_key, kafka_to_key, kafka_to_value, pprint_key -from swh.journal.tests.journal_data import TEST_OBJECTS, TEST_OBJECT_DICTS +from swh.journal.tests.journal_data import ( + TEST_OBJECTS, + TEST_OBJECT_DICTS, + MODEL_OBJECTS, +) def consume_messages(consumer, kafka_prefix, expected_messages): @@ -63,14 +68,14 @@ def assert_all_objects_consumed( consumed_messages: Dict, exclude: Optional[Collection] = None ): - """Check whether all objects from TEST_OBJECT_DICTS have been consumed + """Check whether all objects from TEST_OBJECTS have been consumed `exclude` can be a list of object types for which we do not want to compare the values (eg. for anonymized object). """ - for object_type, known_values in TEST_OBJECT_DICTS.items(): - known_keys = [object_key(object_type, obj) for obj in TEST_OBJECTS[object_type]] + for object_type, known_objects in TEST_OBJECTS.items(): + known_keys = [object_key(object_type, obj) for obj in known_objects] if not consumed_messages[object_type]: return @@ -80,6 +85,8 @@ if object_type in ("content", "skipped_content"): for value in received_values: del value["ctime"] + if object_type == "content": + known_objects = [attr.evolve(o, data=None) for o in known_objects] for key in known_keys: assert key in received_keys, ( @@ -90,8 +97,12 @@ if exclude and object_type in exclude: continue - for value in known_values: - assert value in received_values, ( + received_objects = [ + MODEL_OBJECTS[object_type].from_dict(d) for d in received_values + ] + + for value in known_objects: + assert value in received_objects, ( f"expected {object_type} value {value!r} is " "absent from consumed messages" ) diff --git a/swh/journal/tests/journal_data.py b/swh/journal/tests/journal_data.py --- a/swh/journal/tests/journal_data.py +++ b/swh/journal/tests/journal_data.py @@ -23,17 +23,23 @@ Snapshot, ) +MODEL_CLASSES = ( + Content, + Directory, + Origin, + OriginVisit, + OriginVisitStatus, + Release, + Revision, + SkippedContent, + Snapshot, +) OBJECT_TYPES: Dict[Type[BaseModel], str] = { - Content: "content", - Directory: "directory", - Origin: "origin", - OriginVisit: "origin_visit", - OriginVisitStatus: "origin_visit_status", - Release: "release", - Revision: "revision", - SkippedContent: "skipped_content", - Snapshot: "snapshot", + cls: cls.object_type for cls in MODEL_CLASSES # type: ignore +} +MODEL_OBJECTS: Dict[str, Type[BaseModel]] = { + cls.object_type: cls for cls in MODEL_CLASSES # type: ignore } UTC = datetime.timezone.utc @@ -42,6 +48,7 @@ { **MultiHash.from_data(f"foo{i}".encode()).digest(), "length": 4, + "data": f"foo{i}".encode(), "status": "visible", } for i in range(10) @@ -49,6 +56,7 @@ { **MultiHash.from_data(f"forbidden foo{i}".encode()).digest(), "length": 14, + "data": f"forbidden foo{i}".encode(), "status": "hidden", } for i in range(10) @@ -302,8 +310,6 @@ "skipped_content": SKIPPED_CONTENTS, } -MODEL_OBJECTS = {v: k for (k, v) in OBJECT_TYPES.items()} - TEST_OBJECTS: Dict[str, List[ModelObject]] = {} for object_type, objects in TEST_OBJECT_DICTS.items(): @@ -312,7 +318,7 @@ for (num, obj_d) in enumerate(objects): if object_type == "content": - obj_d = {**obj_d, "data": b"", "ctime": datetime.datetime.now(tz=UTC)} + obj_d = {**obj_d, "ctime": datetime.datetime.now(tz=UTC)} converted_objects.append(model.from_dict(obj_d))