diff --git a/swh/journal/tests/conftest.py b/swh/journal/tests/conftest.py --- a/swh/journal/tests/conftest.py +++ b/swh/journal/tests/conftest.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import datetime import os import pytest import logging @@ -28,6 +29,9 @@ from swh.model.hashutil import MultiHash, hash_to_bytes +from swh.journal.writer.kafka import OBJECT_TYPES, ModelObject + + logger = logging.getLogger(__name__) CONTENTS = [ @@ -149,6 +153,24 @@ "origin_visit": (None, ORIGIN_VISITS), } +MODEL_OBJECTS = {v: k for (k, v) in OBJECT_TYPES.items()} + +TEST_OBJECTS: Dict[str, List[ModelObject]] = {} + +for object_type, (_, objects) in TEST_OBJECT_DICTS.items(): + converted_objects: List[ModelObject] = [] + model = MODEL_OBJECTS[object_type] + + for (num, obj_d) in enumerate(objects): + if object_type == "origin_visit": + obj_d = {**obj_d, "visit": num} + elif object_type == "content": + obj_d = {**obj_d, "data": b"", "ctime": datetime.datetime.now()} + + converted_objects.append(model.from_dict(obj_d)) + + TEST_OBJECTS[object_type] = converted_objects + KAFKA_ROOT = os.environ.get("SWH_KAFKA_ROOT") KAFKA_ROOT = KAFKA_ROOT if KAFKA_ROOT else os.path.dirname(__file__) + "/kafka" diff --git a/swh/journal/tests/test_kafka_writer.py b/swh/journal/tests/test_kafka_writer.py --- a/swh/journal/tests/test_kafka_writer.py +++ b/swh/journal/tests/test_kafka_writer.py @@ -4,23 +4,19 @@ # See top-level LICENSE file for more information from collections import defaultdict -import datetime from confluent_kafka import Consumer, KafkaException from subprocess import Popen -from typing import List, Tuple +from typing import Tuple from swh.storage import get_storage -from swh.journal.replay import object_converter_fn from swh.journal.serializers import kafka_to_key, kafka_to_value -from swh.journal.writer.kafka import KafkaJournalWriter, OBJECT_TYPES +from swh.journal.writer.kafka import KafkaJournalWriter -from swh.model.model import Content, Origin, BaseModel +from swh.model.model import Origin, OriginVisit -from .conftest import TEST_OBJECT_DICTS - -MODEL_OBJECTS = {v: k for (k, v) in OBJECT_TYPES.items()} +from .conftest import TEST_OBJECTS, TEST_OBJECT_DICTS def consume_messages(consumer, kafka_prefix, expected_messages): @@ -93,16 +89,9 @@ expected_messages = 0 - for (object_type, (_, objects)) in TEST_OBJECT_DICTS.items(): - for (num, object_d) in enumerate(objects): - if object_type == "origin_visit": - object_d = {**object_d, "visit": num} - if object_type == "content": - object_d = {**object_d, "ctime": datetime.datetime.now()} - object_ = MODEL_OBJECTS[object_type].from_dict(object_d) - - writer.write_addition(object_type, object_) - expected_messages += 1 + for object_type, objects in TEST_OBJECTS.items(): + writer.write_additions(object_type, objects) + expected_messages += len(objects) consumed_messages = consume_messages(consumer, kafka_prefix, expected_messages) assert_all_objects_consumed(consumed_messages) @@ -128,7 +117,7 @@ expected_messages = 0 - for (object_type, (_, objects)) in TEST_OBJECT_DICTS.items(): + for object_type, objects in TEST_OBJECTS.items(): method = getattr(storage, object_type + "_add") if object_type in ( "content", @@ -138,23 +127,19 @@ "snapshot", "origin", ): - objects_: List[BaseModel] - if object_type == "content": - objects_ = [Content.from_dict({**obj, "data": b""}) for obj in objects] - else: - objects_ = [object_converter_fn[object_type](obj) for obj in objects] - method(objects_) + method(objects) expected_messages += len(objects) elif object_type in ("origin_visit",): - for object_ in objects: - object_ = object_.copy() - origin_url = object_.pop("origin") - storage.origin_add_one(Origin(url=origin_url)) - visit = method( - origin_url, date=object_.pop("date"), type=object_.pop("type") - ) + for obj in objects: + assert isinstance(obj, OriginVisit) + storage.origin_add_one(Origin(url=obj.origin)) + visit = method(obj.origin, date=obj.date, type=obj.type) expected_messages += 1 - storage.origin_visit_update(origin_url, visit.visit, **object_) + + obj_d = obj.to_dict() + for k in ("visit", "origin", "date", "type"): + del obj_d[k] + storage.origin_visit_update(obj.origin, visit.visit, **obj_d) expected_messages += 1 else: assert False, object_type diff --git a/swh/journal/writer/kafka.py b/swh/journal/writer/kafka.py --- a/swh/journal/writer/kafka.py +++ b/swh/journal/writer/kafka.py @@ -153,6 +153,8 @@ if object_type == "origin_visit": # :( dict_["date"] = str(dict_["date"]) + if object_type == "content": + dict_.pop("data", None) return dict_ def _write_addition(self, object_type: str, object_: ModelObject) -> None: