diff --git a/swh/journal/tests/test_kafka_writer.py b/swh/journal/tests/test_kafka_writer.py --- a/swh/journal/tests/test_kafka_writer.py +++ b/swh/journal/tests/test_kafka_writer.py @@ -25,8 +25,10 @@ MODEL_OBJECTS = {v: k for (k, v) in OBJECT_TYPES.items()} -def assert_written(consumer, kafka_prefix, expected_messages): - consumed_objects = defaultdict(list) +def consume_messages(consumer, kafka_prefix, expected_messages): + """Consume expected_messages from the consumer; + Sort them all into a consumed_objects dict""" + consumed_messages = defaultdict(list) fetched_messages = 0 retries_left = 1000 @@ -49,13 +51,21 @@ continue fetched_messages += 1 - consumed_objects[msg.topic()].append( + topic = msg.topic() + assert topic.startswith(kafka_prefix + '.'), "Unexpected topic" + object_type = topic[len(kafka_prefix + '.'):] + + consumed_messages[object_type].append( (kafka_to_key(msg.key()), kafka_to_value(msg.value())) ) + return consumed_messages + + +def assert_all_objects_consumed(consumed_messages): + """Check whether all objects from OBJECT_TYPE_KEYS have been consumed""" for (object_type, (key_name, objects)) in OBJECT_TYPE_KEYS.items(): - topic = kafka_prefix + '.' + object_type - (keys, values) = zip(*consumed_objects[topic]) + (keys, values) = zip(*consumed_messages[object_type]) if key_name: assert list(keys) == [object_[key_name] for object_ in objects] else: @@ -99,7 +109,10 @@ writer.write_addition(object_type, object_) expected_messages += 1 - assert_written(consumer, kafka_prefix, expected_messages) + consumed_messages = consume_messages( + consumer, kafka_prefix, expected_messages + ) + assert_all_objects_consumed(consumed_messages) def test_storage_direct_writer( @@ -159,4 +172,7 @@ else: assert False, object_type - assert_written(consumer, kafka_prefix, expected_messages) + consumed_messages = consume_messages( + consumer, kafka_prefix, expected_messages + ) + assert_all_objects_consumed(consumed_messages)