diff --git a/swh/journal/cli.py b/swh/journal/cli.py
--- a/swh/journal/cli.py
+++ b/swh/journal/cli.py
@@ -72,9 +72,10 @@
     """
     conf = ctx.obj['config']
     storage = get_storage(**conf.pop('storage'))
-    replayer = StorageReplayer(brokers, prefix, consumer_id)
+    replayer = StorageReplayer(brokers, prefix, consumer_id,
+                               storage=storage, max_messages=max_messages)
     try:
-        replayer.fill(storage, max_messages=max_messages)
+        replayer.process()
     except KeyboardInterrupt:
         ctx.exit(0)
     else:
diff --git a/swh/journal/client.py b/swh/journal/client.py
--- a/swh/journal/client.py
+++ b/swh/journal/client.py
@@ -4,27 +4,31 @@
 # See top-level LICENSE file for more information
 
 from abc import ABCMeta, abstractmethod
-from collections import defaultdict
 from kafka import KafkaConsumer
+import logging
 
-from swh.core.config import SWHConfig
 from .serializers import kafka_to_key, kafka_to_value
 
+
+logger = logging.getLogger(__name__)
+
+
 # Only accepted offset reset policy accepted
 ACCEPTED_OFFSET_RESET = ['earliest', 'latest']
 
 # Only accepted object types
 ACCEPTED_OBJECT_TYPES = [
     'content',
+    'directory',
     'revision',
     'release',
-    'occurrence',
+    'snapshot',
     'origin',
     'origin_visit'
 ]
 
 
-class JournalClient(SWHConfig, metaclass=ABCMeta):
+class JournalClient(metaclass=ABCMeta):
     """A base client for the Software Heritage journal.
 
     The current implementation of the journal uses Apache Kafka
@@ -42,41 +46,16 @@
     of maximum `max_messages`.
 
     """
-    DEFAULT_CONFIG = {
-        # Broker to connect to
-        'brokers': ('list[str]', ['localhost']),
-        # Prefix topic to receive notification from
-        'topic_prefix': ('str', 'swh.journal.objects'),
-        # Consumer identifier
-        'consumer_id': ('str', 'swh.journal.client'),
-        # Object types to deal with (in a subscription manner)
-        'object_types': ('list[str]', [
-            'content', 'revision',
-            'release', 'occurrence',
-            'origin', 'origin_visit'
-        ]),
-        # Number of messages to batch process
-        'max_messages': ('int', 100),
-        'auto_offset_reset': ('str', 'earliest')
-    }
-
-    CONFIG_BASE_FILENAME = 'journal/client'
-
-    ADDITIONAL_CONFIG = {}
-
-    def __init__(self, extra_configuration={}):
-        self.config = self.parse_config_file(
-            additional_configs=[self.ADDITIONAL_CONFIG])
-        if extra_configuration:
-            self.config.update(extra_configuration)
-
-        auto_offset_reset = self.config['auto_offset_reset']
+    def __init__(
+            self, brokers, topic_prefix, consumer_id,
+            object_types=ACCEPTED_OBJECT_TYPES,
+            max_messages=0, auto_offset_reset='earliest'):
+
         if auto_offset_reset not in ACCEPTED_OFFSET_RESET:
             raise ValueError(
                 'Option \'auto_offset_reset\' only accept %s.' %
                 ACCEPTED_OFFSET_RESET)
 
-        object_types = self.config['object_types']
         for object_type in object_types:
             if object_type not in ACCEPTED_OBJECT_TYPES:
                 raise ValueError(
@@ -84,36 +63,51 @@
                     ACCEPTED_OFFSET_RESET)
 
         self.consumer = KafkaConsumer(
-            bootstrap_servers=self.config['brokers'],
+            bootstrap_servers=brokers,
             key_deserializer=kafka_to_key,
             value_deserializer=kafka_to_value,
             auto_offset_reset=auto_offset_reset,
             enable_auto_commit=False,
-            group_id=self.config['consumer_id'],
+            group_id=consumer_id,
         )
 
         self.consumer.subscribe(
-            topics=['%s.%s' % (self.config['topic_prefix'], object_type)
+            topics=['%s.%s' % (topic_prefix, object_type)
                     for object_type in object_types],
         )
 
-        self.max_messages = self.config['max_messages']
+        self.max_messages = max_messages
+        self._object_types = object_types
 
-    def process(self):
-        """Main entry point to process event message reception.
+    def poll(self):
+        return self.consumer.poll()
 
-        """
-        while True:
-            messages = defaultdict(list)
+    def commit(self):
+        self.consumer.commit()
 
-            for num, message in enumerate(self.consumer):
-                object_type = message.topic.split('.')[-1]
-                messages[object_type].append(message.value)
-                if num + 1 >= self.max_messages:
-                    break
+    def process(self, max_messages=None):
+        nb_messages = 0
+
+        def done():
+            nonlocal nb_messages
+            return self.max_messages and nb_messages >= self.max_messages
+
+        while not done():
+            polled = self.poll()
+            for (partition, messages) in polled.items():
+                object_type = partition.topic.split('.')[-1]
+                # Got a message from a topic we did not subscribe to.
+                assert object_type in self._object_types, object_type
 
-            self.process_objects(messages)
-            self.consumer.commit()
+                self.process_objects(
+                    {object_type: [msg.value for msg in messages]})
+
+                nb_messages += len(messages)
+                if done():
+                    break
+            self.commit()
+        logger.info('Processed %d messages.' % nb_messages)
+        return nb_messages
 
     # Override the following method in the sub-classes
diff --git a/swh/journal/replay.py b/swh/journal/replay.py
--- a/swh/journal/replay.py
+++ b/swh/journal/replay.py
@@ -5,11 +5,10 @@
 
 import logging
 
-from kafka import KafkaConsumer
-
 from swh.storage import HashCollision
 
-from .serializers import kafka_to_value
+from .client import JournalClient
+
 
 logger = logging.getLogger(__name__)
 
@@ -20,76 +19,35 @@
 ])
 
 
-class StorageReplayer:
-    def __init__(self, brokers, prefix, consumer_id,
-                 object_types=OBJECT_TYPES):
-        if not set(object_types).issubset(OBJECT_TYPES):
-            raise ValueError('Unknown object types: %s' % ', '.join(
-                set(object_types) - OBJECT_TYPES))
-
-        self._object_types = object_types
-        self.consumer = KafkaConsumer(
-            bootstrap_servers=brokers,
-            value_deserializer=kafka_to_value,
-            auto_offset_reset='earliest',
-            enable_auto_commit=False,
-            group_id=consumer_id,
-        )
-        self.consumer.subscribe(
-            topics=['%s.%s' % (prefix, object_type)
-                    for object_type in object_types],
-        )
-
-    def poll(self):
-        return self.consumer.poll()
-
-    def commit(self):
-        self.consumer.commit()
-
-    def fill(self, storage, max_messages=None):
-        nb_messages = 0
-
-        def done():
-            nonlocal nb_messages
-            return max_messages and nb_messages >= max_messages
-
-        while not done():
-            polled = self.poll()
-            for (partition, messages) in polled.items():
-                object_type = partition.topic.split('.')[-1]
-                # Got a message from a topic we did not subscribe to.
-                assert object_type in self._object_types, object_type
-
-                self.insert_objects(storage, object_type,
-                                    [msg.value for msg in messages])
+class StorageReplayer(JournalClient):
+    def __init__(self, *args, storage, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.storage = storage
 
-                nb_messages += len(messages)
-                if done():
-                    break
-            self.commit()
-        logger.info('Processed %d messages.' % nb_messages)
-        return nb_messages
+    def process_objects(self, all_objects):
+        for (object_type, objects) in all_objects.items():
+            self.insert_objects(object_type, objects)
 
-    def insert_objects(self, storage, object_type, objects):
+    def insert_objects(self, object_type, objects):
         if object_type in ('content', 'directory', 'revision', 'release',
                            'snapshot', 'origin'):
             if object_type == 'content':
                 # TODO: insert 'content' in batches
                 for object_ in objects:
                     try:
-                        storage.content_add_metadata([object_])
+                        self.storage.content_add_metadata([object_])
                     except HashCollision as e:
                         logger.error('Hash collision: %s', e.args)
             else:
                 # TODO: split batches that are too large for the storage
                 # to handle?
-                method = getattr(storage, object_type + '_add')
+                method = getattr(self.storage, object_type + '_add')
                 method(objects)
         elif object_type == 'origin_visit':
-            storage.origin_visit_upsert([
+            self.storage.origin_visit_upsert([
                 {
                     **obj,
-                    'origin': storage.origin_add_one(obj['origin'])
+                    'origin': self.storage.origin_add_one(obj['origin'])
                 }
                 for obj in objects])
         else:
diff --git a/swh/journal/tests/test_replay.py b/swh/journal/tests/test_replay.py
--- a/swh/journal/tests/test_replay.py
+++ b/swh/journal/tests/test_replay.py
@@ -56,10 +56,11 @@
     config = {
        'brokers': 'localhost:%d' % kafka_server[1],
        'consumer_id': 'replayer',
-       'prefix': kafka_prefix,
+       'topic_prefix': kafka_prefix,
+       'max_messages': nb_sent,
    }
-    replayer = StorageReplayer(**config)
-    nb_inserted = replayer.fill(storage, max_messages=nb_sent)
+    replayer = StorageReplayer(**config, storage=storage)
+    nb_inserted = replayer.process()
     assert nb_sent == nb_inserted
 
     # Check the objects were actually inserted in the storage
diff --git a/swh/journal/tests/test_write_replay.py b/swh/journal/tests/test_write_replay.py
--- a/swh/journal/tests/test_write_replay.py
+++ b/swh/journal/tests/test_write_replay.py
@@ -27,7 +27,9 @@
 class MockedStorageReplayer(StorageReplayer):
-    def __init__(self, object_types=OBJECT_TYPES):
+    def __init__(self, storage, max_messages, object_types=OBJECT_TYPES):
+        self.storage = storage
+        self.max_messages = max_messages
         self._object_types = object_types
@@ -72,10 +74,10 @@
             pass
 
     storage2 = Storage()
-    replayer = MockedStorageReplayer()
+    replayer = MockedStorageReplayer(storage2, max_messages=len(queue))
     replayer.poll = poll
     replayer.commit = commit
-    replayer.fill(storage2, max_messages=len(queue))
+    replayer.process()
 
     assert committed
@@ -135,10 +137,10 @@
                      for partition in batch.values())
 
     storage2 = Storage()
-    replayer = MockedStorageReplayer()
+    replayer = MockedStorageReplayer(storage2, max_messages=queue_size)
     replayer.poll = poll
     replayer.commit = commit
-    replayer.fill(storage2, max_messages=queue_size)
+    replayer.process()
 
     assert committed
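
After this change, a replayer is wired up roughly as follows. This is a minimal sketch based on the updated cli.py and test_replay.py above, not part of the patch: the broker address, consumer id, and message limit are placeholder values, and `storage` stands for an already-configured swh.storage instance such as the one the CLI builds via get_storage.

    from swh.journal.replay import StorageReplayer

    # `storage` is assumed to be an existing swh.storage instance.
    # All keyword arguments other than `storage` are forwarded to
    # JournalClient.__init__ through StorageReplayer's *args/**kwargs.
    replayer = StorageReplayer(
        brokers=['localhost:9092'],          # placeholder Kafka bootstrap servers
        topic_prefix='swh.journal.objects',  # topics are named '<prefix>.<object_type>'
        consumer_id='replayer',              # Kafka consumer group id
        storage=storage,                     # where insert_objects() writes
        max_messages=1000,                   # stop after this many messages; 0 means no limit
    )
    # process() polls Kafka, hands each batch to process_objects() /
    # insert_objects(), commits offsets, and returns the message count.
    nb_processed = replayer.process()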