diff --git a/swh/journal/cli.py b/swh/journal/cli.py
--- a/swh/journal/cli.py
+++ b/swh/journal/cli.py
@@ -4,13 +4,15 @@
 # See top-level LICENSE file for more information
 
 import click
+import functools
 import logging
 import os
 
 from swh.core import config
 from swh.storage import get_storage
 
-from swh.journal.replay import StorageReplayer
+from swh.journal.client import JournalClient
+from swh.journal.replay import process_replay_objects
 from swh.journal.backfill import JournalBackfiller
 
 CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
@@ -71,10 +73,16 @@
 
     """
     conf = ctx.obj['config']
+    logger = logging.getLogger(__name__)
+    logger.setLevel(ctx.obj['loglevel'])
     storage = get_storage(**conf.pop('storage'))
-    replayer = StorageReplayer(brokers, prefix, consumer_id)
+    client = JournalClient(brokers, prefix, consumer_id)
+    worker_fn = functools.partial(process_replay_objects, storage=storage)
     try:
-        replayer.fill(storage, max_messages=max_messages)
+        nb_messages = 0
+        while not max_messages or nb_messages < max_messages:
+            nb_messages += client.process(worker_fn)
+            logger.info('Processed %d messages.' % nb_messages)
     except KeyboardInterrupt:
         ctx.exit(0)
     else:
diff --git a/swh/journal/client.py b/swh/journal/client.py
--- a/swh/journal/client.py
+++ b/swh/journal/client.py
@@ -3,28 +3,31 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from abc import ABCMeta, abstractmethod
-from collections import defaultdict
 from kafka import KafkaConsumer
+import logging
 
-from swh.core.config import SWHConfig
 from .serializers import kafka_to_key, kafka_to_value
 
+
+logger = logging.getLogger(__name__)
+
+
 # Only accepted offset reset policy accepted
 ACCEPTED_OFFSET_RESET = ['earliest', 'latest']
 
 # Only accepted object types
 ACCEPTED_OBJECT_TYPES = [
     'content',
+    'directory',
     'revision',
     'release',
-    'occurrence',
+    'snapshot',
     'origin',
     'origin_visit'
 ]
 
 
-class JournalClient(SWHConfig, metaclass=ABCMeta):
+class JournalClient:
     """A base client for the Software Heritage journal.
 
     The current implementation of the journal uses Apache Kafka
@@ -42,41 +45,16 @@
     of maximum `max_messages`.
 
     """
-    DEFAULT_CONFIG = {
-        # Broker to connect to
-        'brokers': ('list[str]', ['localhost']),
-        # Prefix topic to receive notification from
-        'topic_prefix': ('str', 'swh.journal.objects'),
-        # Consumer identifier
-        'consumer_id': ('str', 'swh.journal.client'),
-        # Object types to deal with (in a subscription manner)
-        'object_types': ('list[str]', [
-            'content', 'revision',
-            'release', 'occurrence',
-            'origin', 'origin_visit'
-        ]),
-        # Number of messages to batch process
-        'max_messages': ('int', 100),
-        'auto_offset_reset': ('str', 'earliest')
-    }
-
-    CONFIG_BASE_FILENAME = 'journal/client'
-
-    ADDITIONAL_CONFIG = {}
-
-    def __init__(self, extra_configuration={}):
-        self.config = self.parse_config_file(
-            additional_configs=[self.ADDITIONAL_CONFIG])
-        if extra_configuration:
-            self.config.update(extra_configuration)
-
-        auto_offset_reset = self.config['auto_offset_reset']
+    def __init__(
+            self, brokers, topic_prefix, consumer_id,
+            object_types=ACCEPTED_OBJECT_TYPES,
+            max_messages=0, auto_offset_reset='earliest'):
+
         if auto_offset_reset not in ACCEPTED_OFFSET_RESET:
             raise ValueError(
                 'Option \'auto_offset_reset\' only accept %s.'
                % ACCEPTED_OFFSET_RESET)
-        object_types = self.config['object_types']
         for object_type in object_types:
             if object_type not in ACCEPTED_OBJECT_TYPES:
                 raise ValueError(
@@ -84,46 +62,47 @@
                     ACCEPTED_OFFSET_RESET)
 
         self.consumer = KafkaConsumer(
-            bootstrap_servers=self.config['brokers'],
+            bootstrap_servers=brokers,
             key_deserializer=kafka_to_key,
             value_deserializer=kafka_to_value,
             auto_offset_reset=auto_offset_reset,
             enable_auto_commit=False,
-            group_id=self.config['consumer_id'],
+            group_id=consumer_id,
         )
 
         self.consumer.subscribe(
-            topics=['%s.%s' % (self.config['topic_prefix'], object_type)
+            topics=['%s.%s' % (topic_prefix, object_type)
                     for object_type in object_types],
         )
 
-        self.max_messages = self.config['max_messages']
-
-    def process(self):
-        """Main entry point to process event message reception.
+        self.max_messages = max_messages
+        self._object_types = object_types
 
-        """
-        while True:
-            messages = defaultdict(list)
+    def poll(self):
+        return self.consumer.poll()
 
-            for num, message in enumerate(self.consumer):
-                object_type = message.topic.split('.')[-1]
-                messages[object_type].append(message.value)
-                if num + 1 >= self.max_messages:
-                    break
+    def commit(self):
+        self.consumer.commit()
 
-            self.process_objects(messages)
-            self.consumer.commit()
+    def process(self, worker_fn):
+        """Polls Kafka for a batch of messages, and calls the worker_fn
+        with these messages.
 
-    # Override the following method in the sub-classes
+        Args:
+            worker_fn Callable[Dict[str, List[dict]]]: Function called with
+                                                       the messages as
+                                                       argument.
+        """
+        nb_messages = 0
+        polled = self.poll()
+        for (partition, messages) in polled.items():
+            object_type = partition.topic.split('.')[-1]
+            # Got a message from a topic we did not subscribe to.
+            assert object_type in self._object_types, object_type
 
-    @abstractmethod
-    def process_objects(self, messages):
-        """Process the objects (store, compute, etc...)
+            worker_fn({object_type: [msg.value for msg in messages]})
 
-        Args:
-            messages (dict): Dict of key object_type (as per
-                             configuration) and their associated values.
+            nb_messages += len(messages)
 
-        """
-        pass
+        self.commit()
+        return nb_messages
diff --git a/swh/journal/replay.py b/swh/journal/replay.py
--- a/swh/journal/replay.py
+++ b/swh/journal/replay.py
@@ -5,92 +5,37 @@
 
 import logging
 
-from kafka import KafkaConsumer
-
 from swh.storage import HashCollision
 
-from .serializers import kafka_to_value
-
 logger = logging.getLogger(__name__)
 
 
-OBJECT_TYPES = frozenset([
-    'origin', 'origin_visit', 'snapshot', 'release', 'revision',
-    'directory', 'content',
-])
-
-
-class StorageReplayer:
-    def __init__(self, brokers, prefix, consumer_id,
-                 object_types=OBJECT_TYPES):
-        if not set(object_types).issubset(OBJECT_TYPES):
-            raise ValueError('Unknown object types: %s' % ', '.join(
-                set(object_types) - OBJECT_TYPES))
-
-        self._object_types = object_types
-        self.consumer = KafkaConsumer(
-            bootstrap_servers=brokers,
-            value_deserializer=kafka_to_value,
-            auto_offset_reset='earliest',
-            enable_auto_commit=False,
-            group_id=consumer_id,
-        )
-        self.consumer.subscribe(
-            topics=['%s.%s' % (prefix, object_type)
-                    for object_type in object_types],
-        )
-
-    def poll(self):
-        return self.consumer.poll()
-
-    def commit(self):
-        self.consumer.commit()
-
-    def fill(self, storage, max_messages=None):
-        nb_messages = 0
-
-        def done():
-            nonlocal nb_messages
-            return max_messages and nb_messages >= max_messages
-
-        while not done():
-            polled = self.poll()
-            for (partition, messages) in polled.items():
-                object_type = partition.topic.split('.')[-1]
-                # Got a message from a topic we did not subscribe to.
-                assert object_type in self._object_types, object_type
-
-                self.insert_objects(storage, object_type,
-                                    [msg.value for msg in messages])
-
-                nb_messages += len(messages)
-                if done():
-                    break
-            self.commit()
-            logger.info('Processed %d messages.' % nb_messages)
-        return nb_messages
-
-    def insert_objects(self, storage, object_type, objects):
-        if object_type in ('content', 'directory', 'revision', 'release',
-                           'snapshot', 'origin'):
-            if object_type == 'content':
-                # TODO: insert 'content' in batches
-                for object_ in objects:
-                    try:
-                        storage.content_add_metadata([object_])
-                    except HashCollision as e:
-                        logger.error('Hash collision: %s', e.args)
-            else:
-                # TODO: split batches that are too large for the storage
-                # to handle?
-                method = getattr(storage, object_type + '_add')
-                method(objects)
-        elif object_type == 'origin_visit':
-            storage.origin_visit_upsert([
-                {
-                    **obj,
-                    'origin': storage.origin_add_one(obj['origin'])
-                }
-                for obj in objects])
-        else:
-            assert False
+def process_replay_objects(all_objects, *, storage):
+    for (object_type, objects) in all_objects.items():
+        _insert_objects(object_type, objects, storage)
+
+
+def _insert_objects(object_type, objects, storage):
+    if object_type == 'content':
+        # TODO: insert 'content' in batches
+        for object_ in objects:
+            try:
+                storage.content_add_metadata([object_])
+            except HashCollision as e:
+                logger.error('Hash collision: %s', e.args)
+    elif object_type in ('directory', 'revision', 'release',
+                         'snapshot', 'origin'):
+        # TODO: split batches that are too large for the storage
+        # to handle?
+        method = getattr(storage, object_type + '_add')
+        method(objects)
+    elif object_type == 'origin_visit':
+        storage.origin_visit_upsert([
+            {
+                **obj,
+                'origin': storage.origin_add_one(obj['origin'])
+            }
+            for obj in objects])
+    else:
+        assert False
diff --git a/swh/journal/tests/test_replay.py b/swh/journal/tests/test_replay.py
--- a/swh/journal/tests/test_replay.py
+++ b/swh/journal/tests/test_replay.py
@@ -4,6 +4,7 @@
 # See top-level LICENSE file for more information
 
 import datetime
+import functools
 import random
 from subprocess import Popen
 from typing import Tuple
@@ -13,8 +14,9 @@
 
 from swh.storage import get_storage
 
+from swh.journal.client import JournalClient
 from swh.journal.serializers import key_to_kafka, value_to_kafka
-from swh.journal.replay import StorageReplayer
+from swh.journal.replay import process_replay_objects
 
 from .conftest import OBJECT_TYPE_KEYS
 
@@ -56,10 +58,14 @@
     config = {
        'brokers': 'localhost:%d' % kafka_server[1],
        'consumer_id': 'replayer',
-       'prefix': kafka_prefix,
+       'topic_prefix': kafka_prefix,
+       'max_messages': nb_sent,
     }
-    replayer = StorageReplayer(**config)
-    nb_inserted = replayer.fill(storage, max_messages=nb_sent)
+    replayer = JournalClient(**config)
+    worker_fn = functools.partial(process_replay_objects, storage=storage)
+    nb_inserted = 0
+    while nb_inserted < nb_sent:
+        nb_inserted += replayer.process(worker_fn)
     assert nb_sent == nb_inserted
 
     # Check the objects were actually inserted in the storage
diff --git a/swh/journal/tests/test_write_replay.py b/swh/journal/tests/test_write_replay.py
--- a/swh/journal/tests/test_write_replay.py
+++ b/swh/journal/tests/test_write_replay.py
@@ -4,6 +4,7 @@
 # See top-level LICENSE file for more information
 
 from collections import namedtuple
+import functools
 
 from hypothesis import given, settings, HealthCheck
 from hypothesis.strategies import lists
@@ -12,10 +13,11 @@
 from swh.storage.in_memory import Storage
 from swh.storage import HashCollision
 
+from swh.journal.client import JournalClient, ACCEPTED_OBJECT_TYPES
+from swh.journal.direct_writer import DirectKafkaWriter
+from swh.journal.replay import process_replay_objects
 from swh.journal.serializers import (
     key_to_kafka, kafka_to_key, value_to_kafka, kafka_to_value)
-from swh.journal.direct_writer import DirectKafkaWriter
-from swh.journal.replay import StorageReplayer, OBJECT_TYPES
 
 FakeKafkaMessage = namedtuple('FakeKafkaMessage', 'key value')
 FakeKafkaPartition = namedtuple('FakeKafkaPartition', 'topic')
@@ -26,8 +28,8 @@
         self._prefix = 'prefix'
 
 
-class MockedStorageReplayer(StorageReplayer):
-    def __init__(self, object_types=OBJECT_TYPES):
+class MockedJournalClient(JournalClient):
+    def __init__(self, object_types=ACCEPTED_OBJECT_TYPES):
         self._object_types = object_types
 
 
@@ -72,11 +74,16 @@
             pass
 
     storage2 = Storage()
-    replayer = MockedStorageReplayer()
+    worker_fn = functools.partial(process_replay_objects, storage=storage2)
+    replayer = MockedJournalClient()
     replayer.poll = poll
     replayer.commit = commit
-    replayer.fill(storage2, max_messages=len(queue))
+    queue_size = len(queue)
+    nb_messages = 0
+    while nb_messages < queue_size:
+        nb_messages += replayer.process(worker_fn)
+    assert nb_messages == queue_size
 
     assert committed
 
     for attr_name in ('_contents', '_directories', '_revisions', '_releases',
@@ -135,10 +142,13 @@
                    for partition in batch.values())
 
     storage2 = Storage()
-    replayer = MockedStorageReplayer()
+    worker_fn = functools.partial(process_replay_objects, storage=storage2)
+    replayer = MockedJournalClient()
     replayer.poll = poll
     replayer.commit = commit
-    replayer.fill(storage2, max_messages=queue_size)
+    nb_messages = 0
+    while nb_messages < queue_size:
+        nb_messages += replayer.process(worker_fn)
 
     assert committed
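
For reference, here is a minimal usage sketch (not part of the patch) of how the new pieces fit together: a JournalClient polling Kafka and a worker function built from process_replay_objects, mirroring the reworked replay CLI command. The broker address is an illustrative placeholder; the topic prefix and consumer group id reuse the old defaults, and the in-memory Storage from the tests stands in for a real backend.

# Usage sketch only -- not part of the patch above. The broker address is
# assumed; prefix and group id are the defaults from the removed
# DEFAULT_CONFIG.
import functools

from swh.storage.in_memory import Storage

from swh.journal.client import JournalClient
from swh.journal.replay import process_replay_objects

storage = Storage()
client = JournalClient(
    brokers=['localhost:9092'],
    topic_prefix='swh.journal.objects',
    consumer_id='swh.journal.client',
)

# process_replay_objects takes the batch dict positionally and the storage as
# a keyword-only argument, so bind the storage once and hand the resulting
# callable to JournalClient.process().
worker_fn = functools.partial(process_replay_objects, storage=storage)

max_messages = 1000  # arbitrary stop condition for this example
nb_messages = 0
while nb_messages < max_messages:
    # Each call polls one batch, passes it to worker_fn grouped by object
    # type, commits the consumed offsets and returns the message count.
    nb_messages += client.process(worker_fn)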