Changeset View
Changeset View
Standalone View
Standalone View
swh/journal/tests/test_kafka_writer.py
# Copyright (C) 2018-2020 The Software Heritage developers | # Copyright (C) 2018-2020 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from collections import defaultdict | |||||
import pytest | import pytest | ||||
from confluent_kafka import Consumer, Producer, KafkaException | from confluent_kafka import Consumer, Producer | ||||
from swh.storage import get_storage | from swh.storage import get_storage | ||||
from swh.journal.serializers import object_key, kafka_to_key, kafka_to_value | |||||
from swh.journal.writer.kafka import KafkaJournalWriter, KafkaDeliveryError | |||||
from swh.model.model import Directory, Origin, OriginVisit | from swh.model.model import Directory, Origin, OriginVisit | ||||
from .conftest import TEST_OBJECTS, TEST_OBJECT_DICTS | from swh.journal.tests.journal_data import TEST_OBJECTS | ||||
from swh.journal.pytest_plugin import consume_messages, assert_all_objects_consumed | |||||
from swh.journal.writer.kafka import KafkaJournalWriter, KafkaDeliveryError | |||||
def consume_messages(consumer, kafka_prefix, expected_messages):
    """Consume expected_messages from the consumer;
    Sort them all into a consumed_objects dict.

    Args:
        consumer: object whose ``poll(timeout=...)`` returns the next message,
            or None when nothing arrived within the timeout.
        kafka_prefix: topic prefix; every received topic must be named
            ``<kafka_prefix>.<object_type>``.
        expected_messages: number of error-free messages to fetch.

    Returns:
        A defaultdict mapping each object type to the list of ``(key, value)``
        pairs decoded from the messages of the matching topic, in the order
        they were received.

    Raises:
        ValueError: if 1000 polls came back empty or with a non-fatal error
            before enough messages were fetched.
        KafkaException: if a received message carries a fatal error.
    """
    topic_prefix = kafka_prefix + "."
    consumed_messages = defaultdict(list)

    fetched_messages = 0
    retries_left = 1000

    while fetched_messages < expected_messages:
        if retries_left == 0:
            raise ValueError("Timed out fetching messages from kafka")

        msg = consumer.poll(timeout=0.01)

        # Compare against None explicitly: the truthiness of a message object
        # may follow the length of its value, in which case a real message
        # with an empty payload would be mistaken for "nothing received".
        if msg is None:
            retries_left -= 1
            continue

        error = msg.error()
        if error is not None:
            if error.fatal():
                raise KafkaException(error)
            # Non-fatal errors only consume one retry.
            retries_left -= 1
            continue

        fetched_messages += 1
        topic = msg.topic()
        assert topic.startswith(topic_prefix), "Unexpected topic"
        object_type = topic[len(topic_prefix) :]

        consumed_messages[object_type].append(
            (kafka_to_key(msg.key()), kafka_to_value(msg.value()))
        )

    return consumed_messages
def assert_all_objects_consumed(consumed_messages):
    """Check whether all objects from TEST_OBJECT_DICTS have been consumed.

    Args:
        consumed_messages: mapping from object type to a list of
            ``(key, value)`` pairs, where each value is a dict. It is left
            unmodified.

    Raises:
        AssertionError: if an object type has no consumed message, or if an
            expected key or value is absent from the consumed messages.
    """
    # Field removed from the received values of a given object type before
    # they are compared with the reference dicts.
    ignored_fields = {"origin_visit": "visit", "content": "ctime"}

    for object_type, known_values in TEST_OBJECT_DICTS.items():
        known_keys = [object_key(object_type, obj) for obj in TEST_OBJECTS[object_type]]

        # Fail with an explicit message rather than an opaque unpacking error
        # when nothing was received for this object type.
        messages = consumed_messages.get(object_type)
        assert messages, f"No message consumed for object type {object_type}"

        received_keys = [key for key, _ in messages]

        # Compare against stripped copies so the caller's messages are not
        # mutated and the check can safely be run more than once.
        ignored = ignored_fields.get(object_type)
        received_values = [
            {name: item for name, item in value.items() if name != ignored}
            for _, value in messages
        ]

        for key in known_keys:
            assert key in received_keys

        for value in known_values:
            assert value in received_values
def test_kafka_writer(kafka_prefix: str, kafka_server: str, consumer: Consumer): | def test_kafka_writer(kafka_prefix: str, kafka_server: str, consumer: Consumer): | ||||
kafka_prefix += ".swh.journal.objects" | kafka_prefix += ".swh.journal.objects" | ||||
writer = KafkaJournalWriter( | writer = KafkaJournalWriter( | ||||
brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, | brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, | ||||
) | ) | ||||
▲ Show 20 Lines • Show All 128 Lines • Show Last 20 Lines |