diff --git a/swh/journal/tests/test_kafka_writer.py b/swh/journal/tests/test_kafka_writer.py
index 07a8b3a..bab6639 100644
--- a/swh/journal/tests/test_kafka_writer.py
+++ b/swh/journal/tests/test_kafka_writer.py
@@ -1,138 +1,144 @@
 # Copyright (C) 2018-2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from collections import defaultdict
 import datetime
 
 from confluent_kafka import Consumer, KafkaException
 
 from subprocess import Popen
 from typing import Tuple
 
 from swh.storage import get_storage
 
 from swh.journal.writer.kafka import KafkaJournalWriter
 from swh.journal.serializers import (
     kafka_to_key, kafka_to_value
 )
 
 from .conftest import OBJECT_TYPE_KEYS
 
 
 def assert_written(consumer, kafka_prefix, expected_messages):
     consumed_objects = defaultdict(list)
 
     fetched_messages = 0
     retries_left = 1000
 
     while fetched_messages < expected_messages:
         if retries_left == 0:
             raise ValueError('Timed out fetching messages from kafka')
 
         msg = consumer.poll(timeout=0.01)
 
         if not msg:
             retries_left -= 1
             continue
 
         error = msg.error()
         if error is not None:
             if error.fatal():
                 raise KafkaException(error)
             retries_left -= 1
             continue
 
         fetched_messages += 1
         consumed_objects[msg.topic()].append(
             (kafka_to_key(msg.key()), kafka_to_value(msg.value()))
         )
 
     for (object_type, (key_name, objects)) in OBJECT_TYPE_KEYS.items():
         topic = kafka_prefix + '.' + object_type
         (keys, values) = zip(*consumed_objects[topic])
         if key_name:
             assert list(keys) == [object_[key_name] for object_ in objects]
         else:
             pass  # TODO
 
         if object_type == 'origin_visit':
             for value in values:
                 del value['visit']
         elif object_type == 'content':
             for value in values:
                 del value['ctime']
 
         for object_ in objects:
             assert object_ in values
 
 
 def test_kafka_writer(
         kafka_prefix: str, kafka_server: Tuple[Popen, int],
         consumer: Consumer):
     kafka_prefix += '.swh.journal.objects'
 
     config = {
         'brokers': ['localhost:%d' % kafka_server[1]],
         'client_id': 'kafka_writer',
         'prefix': kafka_prefix,
+        'producer_config': {
+            'message.max.bytes': 100000000,
+        }
     }
 
     writer = KafkaJournalWriter(**config)
 
     expected_messages = 0
 
     for (object_type, (_, objects)) in OBJECT_TYPE_KEYS.items():
         for (num, object_) in enumerate(objects):
             if object_type == 'origin_visit':
                 object_ = {**object_, 'visit': num}
             if object_type == 'content':
                 object_ = {**object_, 'ctime': datetime.datetime.now()}
             writer.write_addition(object_type, object_)
             expected_messages += 1
 
     assert_written(consumer, kafka_prefix, expected_messages)
 
 
 def test_storage_direct_writer(
         kafka_prefix: str, kafka_server: Tuple[Popen, int],
         consumer: Consumer):
     kafka_prefix += '.swh.journal.objects'
 
     config = {
         'brokers': ['localhost:%d' % kafka_server[1]],
         'client_id': 'kafka_writer',
         'prefix': kafka_prefix,
+        'producer_config': {
+            'message.max.bytes': 100000000,
+        }
     }
 
     storage = get_storage('memory', journal_writer={
-        'cls': 'kafka', 'args': config,
+        'cls': 'kafka', **config,
     })
 
     expected_messages = 0
 
     for (object_type, (_, objects)) in OBJECT_TYPE_KEYS.items():
         method = getattr(storage, object_type + '_add')
         if object_type in ('content', 'directory', 'revision', 'release',
                            'snapshot', 'origin'):
             if object_type == 'content':
                 objects = [{**obj, 'data': b''} for obj in objects]
             method(objects)
             expected_messages += len(objects)
         elif object_type in ('origin_visit',):
             for object_ in objects:
                 object_ = object_.copy()
                 origin_url = object_.pop('origin')
                 storage.origin_add_one({'url': origin_url})
                 visit = method(origin=origin_url, date=object_.pop('date'),
                                type=object_.pop('type'))
                 expected_messages += 1
 
                 visit_id = visit['visit']
                 storage.origin_visit_update(origin_url, visit_id, **object_)
                 expected_messages += 1
         else:
             assert False, object_type
 
     assert_written(consumer, kafka_prefix, expected_messages)
diff --git a/swh/journal/writer/kafka.py b/swh/journal/writer/kafka.py
index 70ed938..d4728fb 100644
--- a/swh/journal/writer/kafka.py
+++ b/swh/journal/writer/kafka.py
@@ -1,110 +1,114 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import logging
 
 from confluent_kafka import Producer, KafkaException
 
 from swh.model.hashutil import DEFAULT_ALGORITHMS
 from swh.model.model import BaseModel
 
 from swh.journal.serializers import key_to_kafka, value_to_kafka
 
 logger = logging.getLogger(__name__)
 
 
 class KafkaJournalWriter:
     """This class is instantiated and used by swh-storage to write incoming
    new objects to Kafka before adding them to the storage backend
    (eg. postgresql) itself."""
 
-    def __init__(self, brokers, prefix, client_id):
+    def __init__(self, brokers, prefix, client_id, producer_config=None):
         self._prefix = prefix
 
         if isinstance(brokers, str):
             brokers = [brokers]
 
+        if not producer_config:
+            producer_config = {}
+
         self.producer = Producer({
             'bootstrap.servers': ','.join(brokers),
             'client.id': client_id,
             'on_delivery': self._on_delivery,
             'error_cb': self._error_cb,
             'logger': logger,
             'enable.idempotence': 'true',
+            **producer_config,
         })
 
     def _error_cb(self, error):
         if error.fatal():
             raise KafkaException(error)
         logger.info('Received non-fatal kafka error: %s', error)
 
     def _on_delivery(self, error, message):
         if error is not None:
             self._error_cb(error)
 
     def send(self, topic, key, value):
         self.producer.produce(
             topic=topic,
             key=key_to_kafka(key),
             value=value_to_kafka(value),
         )
 
         # Need to service the callbacks regularly by calling poll
         self.producer.poll(0)
 
     def flush(self):
         self.producer.flush()
 
     def _get_key(self, object_type, object_):
         if object_type in ('revision', 'release', 'directory', 'snapshot'):
             return object_['id']
         elif object_type == 'content':
             return object_['sha1']  # TODO: use a dict of hashes
         elif object_type == 'skipped_content':
             return {
                 hash: object_[hash]
                 for hash in DEFAULT_ALGORITHMS
             }
         elif object_type == 'origin':
             return {'url': object_['url']}
         elif object_type == 'origin_visit':
             return {
                 'origin': object_['origin'],
                 'date': str(object_['date']),
             }
         else:
             raise ValueError('Unknown object type: %s.' % object_type)
 
     def _sanitize_object(self, object_type, object_):
         if object_type == 'origin_visit':
             return {
                 **object_,
                 'date': str(object_['date']),
             }
         elif object_type == 'origin':
             assert 'id' not in object_
         return object_
 
     def write_addition(self, object_type, object_, flush=True):
         """Write a single object to the journal"""
         if isinstance(object_, BaseModel):
             object_ = object_.to_dict()
         topic = '%s.%s' % (self._prefix, object_type)
         key = self._get_key(object_type, object_)
         object_ = self._sanitize_object(object_type, object_)
         logger.debug('topic: %s, key: %s, value: %s', topic, key, object_)
         self.send(topic, key=key, value=object_)
         if flush:
             self.flush()
 
     write_update = write_addition
 
     def write_additions(self, object_type, objects, flush=True):
         """Write a set of objects to the journal"""
         for object_ in objects:
             self.write_addition(object_type, object_, flush=False)
 
         if flush:
             self.flush()
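
A minimal usage sketch of the new producer_config parameter. The broker
address, client id, and example origin URL below are illustrative, not taken
from the patch. Since producer_config is spread into the Producer dict after
the writer's built-in settings, it can both extend and override the defaults
(e.g. enable.idempotence):

    from swh.journal.writer.kafka import KafkaJournalWriter

    # Assumes a Kafka broker reachable at localhost:9092.
    writer = KafkaJournalWriter(
        brokers=['localhost:9092'],
        prefix='swh.journal.objects',
        client_id='example_writer',
        producer_config={
            # Any librdkafka producer option can be passed through here;
            # this raises the message size limit, as in the tests above.
            'message.max.bytes': 100000000,
        },
    )

    # 'origin' objects are keyed by their URL (see _get_key above).
    writer.write_addition('origin', {'url': 'https://example.com/repo'})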