diff --git a/swh/journal/tests/test_kafka_writer.py b/swh/journal/tests/test_kafka_writer.py
index e931b15..e6c7025 100644
--- a/swh/journal/tests/test_kafka_writer.py
+++ b/swh/journal/tests/test_kafka_writer.py
@@ -1,149 +1,314 @@
 # Copyright (C) 2018-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from collections import defaultdict
 
-from confluent_kafka import Consumer, KafkaException
+from confluent_kafka import Consumer, Producer, KafkaException
+
+import pytest
 
 from subprocess import Popen
-from typing import Tuple
+from typing import List, Tuple
 
 from swh.storage import get_storage
 
 from swh.journal.serializers import object_key, kafka_to_key, kafka_to_value
-from swh.journal.writer.kafka import KafkaJournalWriter
+from swh.journal.writer.kafka import KafkaJournalWriter, KafkaDeliveryError
 
-from swh.model.model import Origin, OriginVisit
+from swh.model.model import Directory, DirectoryEntry, Origin, OriginVisit
 
 from .conftest import TEST_OBJECTS, TEST_OBJECT_DICTS
 
 
 def consume_messages(consumer, kafka_prefix, expected_messages):
     """Consume expected_messages from the consumer;
     Sort them all into a consumed_objects dict"""
     consumed_messages = defaultdict(list)
 
     fetched_messages = 0
     retries_left = 1000
 
     while fetched_messages < expected_messages:
         if retries_left == 0:
             raise ValueError("Timed out fetching messages from kafka")
 
         msg = consumer.poll(timeout=0.01)
 
         if not msg:
             retries_left -= 1
             continue
 
         error = msg.error()
         if error is not None:
             if error.fatal():
                 raise KafkaException(error)
             retries_left -= 1
             continue
 
         fetched_messages += 1
         topic = msg.topic()
         assert topic.startswith(kafka_prefix + "."), "Unexpected topic"
         object_type = topic[len(kafka_prefix + ".") :]
 
         consumed_messages[object_type].append(
             (kafka_to_key(msg.key()), kafka_to_value(msg.value()))
         )
 
     return consumed_messages
 
 
 def assert_all_objects_consumed(consumed_messages):
     """Check whether all objects from TEST_OBJECT_DICTS have been consumed"""
     for object_type, known_values in TEST_OBJECT_DICTS.items():
         known_keys = [object_key(object_type, obj) for obj in TEST_OBJECTS[object_type]]
 
         (received_keys, received_values) = zip(*consumed_messages[object_type])
 
         if object_type == "origin_visit":
             for value in received_values:
                 del value["visit"]
         elif object_type == "content":
             for value in received_values:
                 del value["ctime"]
 
         for key in known_keys:
             assert key in received_keys
 
         for value in known_values:
             assert value in received_values
 
 
 def test_kafka_writer(
     kafka_prefix: str, kafka_server: Tuple[Popen, int], consumer: Consumer
 ):
     kafka_prefix += ".swh.journal.objects"
     writer = KafkaJournalWriter(
         brokers=[f"localhost:{kafka_server[1]}"],
         client_id="kafka_writer",
         prefix=kafka_prefix,
     )
 
     expected_messages = 0
 
     for object_type, objects in TEST_OBJECTS.items():
         writer.write_additions(object_type, objects)
         expected_messages += len(objects)
 
     consumed_messages = consume_messages(consumer, kafka_prefix, expected_messages)
     assert_all_objects_consumed(consumed_messages)
 
 
 def test_storage_direct_writer(
     kafka_prefix: str, kafka_server: Tuple[Popen, int], consumer: Consumer
 ):
     kafka_prefix += ".swh.journal.objects"
     writer_config = {
         "cls": "kafka",
         "brokers": ["localhost:%d" % kafka_server[1]],
         "client_id": "kafka_writer",
         "prefix": kafka_prefix,
     }
     storage_config = {
         "cls": "pipeline",
         "steps": [{"cls": "memory", "journal_writer": writer_config},],
     }
 
     storage = get_storage(**storage_config)
 
     expected_messages = 0
 
     for object_type, objects in TEST_OBJECTS.items():
         method = getattr(storage, object_type + "_add")
         if object_type in (
             "content",
             "directory",
             "revision",
             "release",
             "snapshot",
             "origin",
         ):
             method(objects)
             expected_messages += len(objects)
         elif object_type in ("origin_visit",):
             for obj in objects:
                 assert isinstance(obj, OriginVisit)
                 storage.origin_add_one(Origin(url=obj.origin))
                 visit = method(obj.origin, date=obj.date, type=obj.type)
                 expected_messages += 1
 
                 obj_d = obj.to_dict()
                 for k in ("visit", "origin", "date", "type"):
                     del obj_d[k]
                 storage.origin_visit_update(obj.origin, visit.visit, **obj_d)
                 expected_messages += 1
         else:
             assert False, object_type
 
     consumed_messages = consume_messages(consumer, kafka_prefix, expected_messages)
     assert_all_objects_consumed(consumed_messages)
+
+
+@pytest.fixture(scope="session")
+def large_directories() -> List[Directory]:
+    dir_sizes = [1 << n for n in range(21)]  # 2**0 = 1 to 2**20 = 1024 * 1024
+
+    dir_entries = [
+        DirectoryEntry(
+            name=("%09d" % i).encode(),
+            type="file",
+            perms=0o100644,
+            target=b"\x00" * 20,
+        )
+        for i in range(max(dir_sizes))
+    ]
+
+    return [Directory(entries=dir_entries[:size]) for size in dir_sizes]
+
+
+def test_write_large_objects(
+    kafka_prefix: str,
+    kafka_server: Tuple[Popen, int],
+    consumer: Consumer,
+    large_directories: List[Directory],
+):
+    kafka_prefix += ".swh.journal.objects"
+
+    # Needed as there is no directories in TEST_OBJECT_DICTS, the consumer
+    # isn't autosubscribed to directories.
+    consumer.subscribe([kafka_prefix + ".directory"])
+
+    writer = KafkaJournalWriter(
+        brokers=["localhost:%d" % kafka_server[1]],
+        client_id="kafka_writer",
+        prefix=kafka_prefix,
+    )
+
+    writer.write_additions("directory", large_directories)
+
+    consumed_messages = consume_messages(consumer, kafka_prefix, len(large_directories))
+
+    for dir, message in zip(large_directories, consumed_messages["directory"]):
+        (dir_id, consumed_dir) = message
+        assert dir_id == dir.id
+        assert consumed_dir == dir.to_dict()
+
+
+def dir_message_size(directory: Directory) -> int:
+    """Estimate the size of a directory kafka message.
+
+    We could just do it with `len(value_to_kafka(directory.to_dict()))`,
+    but the serialization is a substantial chunk of the test time here.
+
+    """
+    n_entries = len(directory.entries)
+    return (
+        # fmt: off
+        0
+        + 1  # header of a 2-element fixmap
+        + 1 + 2  # fixstr("id")
+        + 2 + 20  # bin8(directory.id of length 20)
+        + 1 + 7  # fixstr("entries")
+        + 4  # array header
+        + n_entries
+        * (
+            0
+            + 1  # header of a 4-element fixmap
+            + 1 + 6  # fixstr("target")
+            + 2 + 20  # bin8(target of length 20)
+            + 1 + 4  # fixstr("name")
+            + 2 + 9  # bin8(name of length 9)
+            + 1 + 5  # fixstr("perms")
+            + 5  # uint32(perms)
+            + 1 + 4  # fixstr("type")
+            + 1 + 3  # fixstr(type)
+        )
+        # fmt: on
+    )
+
+
+SMALL_MESSAGE_SIZE = 1024 * 1024
+
+
+@pytest.mark.parametrize(
+    "kafka_server_config_overrides", [{"message.max.bytes": str(SMALL_MESSAGE_SIZE)}]
+)
+def test_fail_write_large_objects(
+    kafka_prefix: str,
+    kafka_server: Tuple[Popen, int],
+    consumer: Consumer,
+    large_directories: List[Directory],
+):
+    kafka_prefix += ".swh.journal.objects"
+
+    # Needed as there is no directories in TEST_OBJECT_DICTS, the consumer
+    # isn't autosubscribed to directories.
+    consumer.subscribe([kafka_prefix + ".directory"])
+
+    writer = KafkaJournalWriter(
+        brokers=["localhost:%d" % kafka_server[1]],
+        client_id="kafka_writer",
+        prefix=kafka_prefix,
+    )
+
+    expected_dirs = []
+
+    for directory in large_directories:
+        if dir_message_size(directory) < SMALL_MESSAGE_SIZE:
+            # No error; write anyway, but continue
+            writer.write_addition("directory", directory)
+            expected_dirs.append(directory)
+            continue
+
+        with pytest.raises(KafkaDeliveryError) as exc:
+            writer.write_addition("directory", directory)
+
+        assert "Failed deliveries" in exc.value.message
+        assert len(exc.value.delivery_failures) == 1
+
+        object_type, key, msg, code = exc.value.delivery_failures[0]
+
+        assert object_type == "directory"
+        assert key == directory.id
+        assert code == "MSG_SIZE_TOO_LARGE"
+
+    consumed_messages = consume_messages(consumer, kafka_prefix, len(expected_dirs))
+
+    for dir, message in zip(expected_dirs, consumed_messages["directory"]):
+        (dir_id, consumed_dir) = message
+        assert dir_id == dir.id
+        assert consumed_dir == dir.to_dict()
+
+
+def test_write_delivery_timeout(
+    kafka_prefix: str, kafka_server: Tuple[Popen, int], consumer: Consumer
+):
+
+    produced = []
+
+    class MockProducer(Producer):
+        def produce(self, **kwargs):
+            produced.append(kwargs)
+
+    kafka_prefix += ".swh.journal.objects"
+    writer = KafkaJournalWriter(
+        brokers=["localhost:%d" % kafka_server[1]],
+        client_id="kafka_writer",
+        prefix=kafka_prefix,
+        flush_timeout=1,
+        producer_class=MockProducer,
+    )
+
+    empty_dir = Directory(entries=[])
+    with pytest.raises(KafkaDeliveryError) as exc:
+        writer.write_addition("directory", empty_dir)
+
+    assert len(produced) == 1
+
+    assert "timeout" in exc.value.message
+    assert len(exc.value.delivery_failures) == 1
+    delivery_failure = exc.value.delivery_failures[0]
+    assert delivery_failure.key == empty_dir.id
+    assert delivery_failure.code == "SWH_FLUSH_TIMEOUT"
diff --git a/swh/journal/writer/kafka.py b/swh/journal/writer/kafka.py
index 78d26c2..2a26e53 100644
--- a/swh/journal/writer/kafka.py
+++ b/swh/journal/writer/kafka.py
@@ -1,142 +1,234 @@
 # Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import logging
-from typing import Dict, Iterable, Optional, Type
+import time
+from typing import Dict, Iterable, List, NamedTuple, Optional, Type
 
 from confluent_kafka import Producer, KafkaException
 
 from swh.model.model import (
     BaseModel,
     Content,
     Directory,
     Origin,
     OriginVisit,
     Release,
     Revision,
     SkippedContent,
     Snapshot,
 )
 
 from swh.journal.serializers import (
     KeyType,
     ModelObject,
     object_key,
+    pprint_key,
     key_to_kafka,
     value_to_kafka,
 )
 
 logger = logging.getLogger(__name__)
 
 OBJECT_TYPES: Dict[Type[BaseModel], str] = {
     Content: "content",
     Directory: "directory",
     Origin: "origin",
     OriginVisit: "origin_visit",
     Release: "release",
     Revision: "revision",
     SkippedContent: "skipped_content",
     Snapshot: "snapshot",
 }
 
 
+class DeliveryTag(NamedTuple):
+    """Unique tag allowing us to check for a message delivery"""
+
+    topic: str
+    kafka_key: bytes
+
+
+class DeliveryFailureInfo(NamedTuple):
+    """Verbose information for failed deliveries"""
+
+    object_type: str
+    key: KeyType
+    message: str
+    code: str
+
+
+def get_object_type(topic: str) -> str:
+    """Get the object type from a topic string"""
+    return topic.rsplit(".", 1)[-1]
+
+
+class KafkaDeliveryError(Exception):
+    """Delivery failed on some kafka messages."""
+
+    def __init__(self, message: str, delivery_failures: Iterable[DeliveryFailureInfo]):
+        self.message = message
+        self.delivery_failures = list(delivery_failures)
+
+    def pretty_failures(self) -> str:
+        return ", ".join(
+            f"{f.object_type} {pprint_key(f.key)} ({f.message})"
+            for f in self.delivery_failures
+        )
+
+    def __str__(self):
+        return f"KafkaDeliveryError({self.message}, [{self.pretty_failures()}])"
+
+
 class KafkaJournalWriter:
     """This class is instantiated and used by swh-storage to write incoming
     new objects to Kafka before adding them to the storage backend
     (eg. postgresql) itself.
 
     Args:
         brokers: list of broker addresses and ports
         prefix: the prefix used to build the topic names for objects
         client_id: the id of the writer sent to kafka
         producer_config: extra configuration keys passed to the `Producer`
+        flush_timeout: timeout, in seconds, after which the `flush` operation
+            will fail if some message deliveries are still pending.
        producer_class: override for the kafka producer class
 
     """
 
     def __init__(
         self,
         brokers: Iterable[str],
         prefix: str,
         client_id: str,
         producer_config: Optional[Dict] = None,
+        flush_timeout: float = 120,
         producer_class: Type[Producer] = Producer,
     ):
         self._prefix = prefix
 
         if not producer_config:
             producer_config = {}
 
         if "message.max.bytes" not in producer_config:
             producer_config = {
                 "message.max.bytes": 100 * 1024 * 1024,
                 **producer_config,
             }
 
         self.producer = producer_class(
             {
                 "bootstrap.servers": ",".join(brokers),
                 "client.id": client_id,
                 "on_delivery": self._on_delivery,
                 "error_cb": self._error_cb,
                 "logger": logger,
                 "acks": "all",
                 **producer_config,
             }
         )
 
+        # Delivery management
+        self.flush_timeout = flush_timeout
+
+        # delivery tag -> original object "key" mapping
+        self.deliveries_pending: Dict[DeliveryTag, KeyType] = {}
+
+        # List of (object_type, key, error_msg, error_name) for failed deliveries
+        self.delivery_failures: List[DeliveryFailureInfo] = []
+
     def _error_cb(self, error):
         if error.fatal():
             raise KafkaException(error)
         logger.info("Received non-fatal kafka error: %s", error)
 
     def _on_delivery(self, error, message):
+        (topic, key) = delivery_tag = DeliveryTag(message.topic(), message.key())
+        sent_key = self.deliveries_pending.pop(delivery_tag, None)
+
         if error is not None:
-            self._error_cb(error)
+            self.delivery_failures.append(
+                DeliveryFailureInfo(
+                    get_object_type(topic), sent_key, error.str(), error.name()
+                )
+            )
 
     def send(self, topic: str, key: KeyType, value):
         kafka_key = key_to_kafka(key)
         self.producer.produce(
             topic=topic, key=kafka_key, value=value_to_kafka(value),
         )
 
-        # Need to service the callbacks regularly by calling poll
-        self.producer.poll(0)
+        self.deliveries_pending[DeliveryTag(topic, kafka_key)] = key
+
+    def delivery_error(self, message) -> KafkaDeliveryError:
+        """Get all failed deliveries, and clear them"""
+        ret = self.delivery_failures
+        self.delivery_failures = []
+
+        while self.deliveries_pending:
+            delivery_tag, orig_key = self.deliveries_pending.popitem()
+            (topic, kafka_key) = delivery_tag
+            ret.append(
+                DeliveryFailureInfo(
+                    get_object_type(topic),
+                    orig_key,
+                    "No delivery before flush() timeout",
+                    "SWH_FLUSH_TIMEOUT",
+                )
+            )
+
+        return KafkaDeliveryError(message, ret)
 
     def flush(self):
-        self.producer.flush()
+        start = time.monotonic()
+
+        self.producer.flush(self.flush_timeout)
+
+        while self.deliveries_pending:
+            if time.monotonic() - start > self.flush_timeout:
+                break
+            self.producer.poll(0.1)
+
+        if self.deliveries_pending:
+            # Delivery timeout
+            raise self.delivery_error(
+                "flush() exceeded timeout (%ss)" % self.flush_timeout,
+            )
+        elif self.delivery_failures:
+            raise self.delivery_error("Failed deliveries after flush()")
 
     def _sanitize_object(
         self, object_type: str, object_: ModelObject
     ) -> Dict[str, str]:
         dict_ = object_.to_dict()
         if object_type == "origin_visit":
             # :(
             dict_["date"] = str(dict_["date"])
         if object_type == "content":
             dict_.pop("data", None)
         return dict_
 
     def _write_addition(self, object_type: str, object_: ModelObject) -> None:
         """Write a single object to the journal"""
         topic = f"{self._prefix}.{object_type}"
         key = object_key(object_type, object_)
         dict_ = self._sanitize_object(object_type, object_)
         logger.debug("topic: %s, key: %s, value: %s", topic, key, dict_)
         self.send(topic, key=key, value=dict_)
 
     def write_addition(self, object_type: str, object_: ModelObject) -> None:
         """Write a single object to the journal"""
         self._write_addition(object_type, object_)
         self.flush()
 
     write_update = write_addition
 
     def write_additions(self, object_type: str, objects: Iterable[ModelObject]) -> None:
         """Write a set of objects to the journal"""
         for object_ in objects:
             self._write_addition(object_type, object_)
 
         self.flush()
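
Usage note (not part of the patch): a minimal sketch, assuming a locally reachable broker, of how a caller might handle the KafkaDeliveryError introduced above. The broker address, client id, topic prefix and flush timeout values below are placeholders; only KafkaJournalWriter, write_addition() and the DeliveryFailureInfo fields (object_type, key, message, code) are taken from the patch itself.

    from swh.journal.writer.kafka import KafkaDeliveryError, KafkaJournalWriter
    from swh.model.model import Directory

    # Placeholder connection settings; flush_timeout bounds how long flush()
    # waits for pending deliveries before raising KafkaDeliveryError.
    writer = KafkaJournalWriter(
        brokers=["localhost:9092"],
        client_id="example_journal_writer",
        prefix="swh.journal.objects",
        flush_timeout=30,
    )

    try:
        writer.write_addition("directory", Directory(entries=[]))
    except KafkaDeliveryError as exc:
        # delivery_failures is a list of DeliveryFailureInfo named tuples; each
        # carries the object type, object key, broker error message and error
        # code (e.g. MSG_SIZE_TOO_LARGE or SWH_FLUSH_TIMEOUT).
        for failure in exc.delivery_failures:
            print(f"{failure.object_type} {failure.code}: {failure.message}")
        raise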