diff --git a/swh/journal/direct_writer.py b/swh/journal/direct_writer.py
new file mode 100644
--- /dev/null
+++ b/swh/journal/direct_writer.py
@@ -0,0 +1,71 @@
+# Copyright (C) 2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import logging
+
+from kafka import KafkaProducer
+
+from .serializers import key_to_kafka, value_to_kafka
+
+logger = logging.getLogger(__name__)
+
+
+class DirectKafkaWriter:
+    """This class is instantiated and used by swh-storage to write incoming
+    new objects to Kafka before adding them to the storage backend
+    (eg. postgresql) itself.
+
+    Each object is sent to the topic ``<prefix>.<object_type>``.
+    """
+
+    def __init__(self, brokers, prefix, client_id):
+        """Create the underlying Kafka producer.
+
+        Args:
+            brokers: passed to KafkaProducer as ``bootstrap_servers``
+            prefix: prefix of the names of the topics written to
+            client_id: passed to KafkaProducer as ``client_id``
+        """
+        self._prefix = prefix
+
+        self.producer = KafkaProducer(
+            bootstrap_servers=brokers,
+            key_serializer=key_to_kafka,
+            # Values get the value serializer, not the key one.
+            value_serializer=value_to_kafka,
+            client_id=client_id,
+        )
+
+    def _get_key(self, object_type, object_):
+        """Return the Kafka message key for `object_`.
+
+        Raises:
+            ValueError: if `object_type` is not a supported object type
+        """
+        if object_type in ('revision', 'release', 'directory', 'snapshot'):
+            return object_['id']
+        elif object_type == 'content':
+            return object_['sha1']  # TODO: use a dict of hashes
+        elif object_type == 'origin':
+            return {'url': object_['url'], 'type': object_['type']}
+        else:
+            raise ValueError('Unknown object type: %s.' % object_type)
+
+    def write_addition(self, object_type, object_):
+        """Send `object_` to the topic associated with `object_type`."""
+        topic = '%s.%s' % (self._prefix, object_type)
+        key = self._get_key(object_type, object_)
+        # Hand the arguments to the logger instead of formatting eagerly,
+        # so objects are only rendered when debug logging is enabled.
+        logger.debug('topic: %s, key: %s, value: %s', topic, key, object_)
+        self.producer.send(topic, key=key, value=object_)
+
+    # Updates are written exactly like additions.
+    write_update = write_addition
+
+    def write_additions(self, object_type, objects):
+        """Send every object of `objects` to the `object_type` topic."""
+        for object_ in objects:
+            self.write_addition(object_type, object_)
diff --git a/swh/journal/tests/test_direct_writer.py b/swh/journal/tests/test_direct_writer.py
new file mode 100644
--- /dev/null
+++ b/swh/journal/tests/test_direct_writer.py
@@ -0,0 +1,94 @@
+# Copyright (C) 2018-2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from collections import defaultdict
+from subprocess import Popen
+import time
+from typing import Tuple
+
+from kafka import KafkaConsumer
+
+from swh.storage import get_storage
+
+from swh.journal.direct_writer import DirectKafkaWriter
+from swh.journal.serializers import value_to_kafka, kafka_to_value
+
+from .conftest import OBJECT_TYPE_KEYS
+
+
+def _writer_config(kafka_server):
+    """Return the DirectKafkaWriter arguments targeting the test broker.
+
+    `kafka_server` is the fixture tuple; its second item is the port.
+    """
+    return {
+        'brokers': 'localhost:%d' % kafka_server[1],
+        'client_id': 'direct_writer',
+        'prefix': 'swh.journal.objects',
+    }
+
+
+def assert_written(consumer):
+    """Check that `consumer` yields, per topic and in order, the objects
+    listed in OBJECT_TYPE_KEYS."""
+    time.sleep(0.1)  # Without this, some messages are missing
+
+    consumed_objects = defaultdict(list)
+    for msg in consumer:
+        consumed_objects[msg.topic].append((msg.key, msg.value))
+
+    expected_objects = {
+        'swh.journal.objects.%s' % object_type: [
+            (
+                object_[key_name],
+                kafka_to_value(value_to_kafka(object_)),  # str -> bytes
+            )
+            for object_ in objects
+        ]
+        for (object_type, (key_name, objects)) in OBJECT_TYPE_KEYS.items()
+    }
+
+    assert dict(consumed_objects) == expected_objects
+
+
+def test_direct_writer(
+        kafka_server: Tuple[Popen, int],
+        consumer_from_publisher: KafkaConsumer):
+    """Objects written through DirectKafkaWriter reach the consumer."""
+    writer = DirectKafkaWriter(**_writer_config(kafka_server))
+
+    for (object_type, (_, objects)) in OBJECT_TYPE_KEYS.items():
+        for object_ in objects:
+            writer.write_addition(object_type, object_)
+
+    assert_written(consumer_from_publisher)
+
+
+def test_storage_direct_writer(
+        kafka_server: Tuple[Popen, int],
+        consumer_from_publisher: KafkaConsumer):
+    """Objects added to a storage configured with a kafka journal writer
+    reach the consumer."""
+    storage = get_storage('memory', {'journal_writer': {
+        'cls': 'kafka', 'args': _writer_config(kafka_server)}})
+
+    for (object_type, (_, objects)) in OBJECT_TYPE_KEYS.items():
+        method = getattr(storage, object_type + '_add')
+        if object_type in ('content', 'directory', 'revision', 'release',
+                           'origin'):
+            if object_type == 'content':
+                objects = [{**obj, 'data': b''} for obj in objects]
+            method(objects)
+        elif object_type in ('snapshot',):
+            for object_ in objects:
+                method(object_)
+        elif object_type in ('origin_visit',):
+            for object_ in objects:
+                method(**object_)
+        else:
+            # Raise explicitly: `assert False` is removed when run with -O.
+            raise AssertionError(object_type)
+
+    assert_written(consumer_from_publisher)