diff --git a/swh/journal/replay.py b/swh/journal/replay.py
index ae446ce..c29f449 100644
--- a/swh/journal/replay.py
+++ b/swh/journal/replay.py
@@ -1,77 +1,93 @@
 # Copyright (C) 2019  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import logging
 
 from kafka import KafkaConsumer
 
 from swh.storage import HashCollision
 
 from .serializers import kafka_to_value
 
 logger = logging.getLogger(__name__)
 
 
 OBJECT_TYPES = frozenset([
     'origin', 'origin_visit', 'snapshot', 'release', 'revision',
     'directory', 'content',
 ])
 
 
 class StorageReplayer:
     def __init__(self, brokers, prefix, consumer_id,
                  object_types=OBJECT_TYPES):
         if not set(object_types).issubset(OBJECT_TYPES):
             raise ValueError('Unknown object types: %s' % ', '.join(
                 set(object_types) - OBJECT_TYPES))
 
         self._object_types = object_types
         self.consumer = KafkaConsumer(
             bootstrap_servers=brokers,
             value_deserializer=kafka_to_value,
             auto_offset_reset='earliest',
             enable_auto_commit=False,
             group_id=consumer_id,
         )
         self.consumer.subscribe(
             topics=['%s.%s' % (prefix, object_type)
                     for object_type in object_types],
         )
 
     def poll(self):
-        yield from self.consumer
+        return self.consumer.poll()
 
-    def fill(self, storage, max_messages=None):
-        num = 0
-        for message in self.poll():
-            object_type = message.topic.split('.')[-1]
-
-            # Got a message from a topic we did not subscribe to.
-            assert object_type in self._object_types, object_type
+    def commit(self):
+        self.consumer.commit()
 
-            self.insert_object(storage, object_type, message.value)
-
-            num += 1
-            if max_messages and num >= max_messages:
-                break
-        return num
+    def fill(self, storage, max_messages=None):
+        nb_messages = 0
+
+        def done():
+            nonlocal nb_messages
+            return max_messages and nb_messages >= max_messages
+
+        while not done():
+            polled = self.poll()
+            for (partition, messages) in polled.items():
+                assert messages
+                for message in messages:
+                    object_type = partition.topic.split('.')[-1]
+
+                    # Got a message from a topic we did not subscribe to.
+                    assert object_type in self._object_types, object_type
+
+                    self.insert_object(storage, object_type, message.value)
+
+                    nb_messages += 1
+                    if done():
+                        break
+                if done():
+                    break
+            self.commit()
+        logger.info('Processed %d messages.' % nb_messages)
+        return nb_messages
 
     def insert_object(self, storage, object_type, object_):
         if object_type in ('content', 'directory', 'revision', 'release',
                            'snapshot', 'origin'):
             if object_type == 'content':
                 try:
                     storage.content_add_metadata([object_])
                 except HashCollision as e:
                     logger.error('Hash collision: %s', e.args)
             else:
                 method = getattr(storage, object_type + '_add')
                 method([object_])
         elif object_type == 'origin_visit':
             storage.origin_visit_upsert([{
                 **object_,
                 'origin': storage.origin_add_one(object_['origin'])}])
         else:
             assert False
diff --git a/swh/journal/tests/test_write_replay.py b/swh/journal/tests/test_write_replay.py
index 03d2bc6..720a642 100644
--- a/swh/journal/tests/test_write_replay.py
+++ b/swh/journal/tests/test_write_replay.py
@@ -1,75 +1,88 @@
 # Copyright (C) 2019  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from collections import namedtuple
 
 from hypothesis import given, settings, HealthCheck
 from hypothesis.strategies import lists
 
 from swh.model.hypothesis_strategies import object_dicts
 from swh.storage.in_memory import Storage
 from swh.storage import HashCollision
 
 from swh.journal.serializers import (
     key_to_kafka, kafka_to_key, value_to_kafka, kafka_to_value)
 from swh.journal.direct_writer import DirectKafkaWriter
 
 from swh.journal.replay import StorageReplayer, OBJECT_TYPES
 
 
-FakeKafkaMessage = namedtuple('FakeKafkaMessage', 'topic key value')
+FakeKafkaMessage = namedtuple('FakeKafkaMessage', 'key value')
+FakeKafkaPartition = namedtuple('FakeKafkaPartition', 'topic')
 
 
 class MockedDirectKafkaWriter(DirectKafkaWriter):
     def __init__(self):
         self._prefix = 'prefix'
 
 
 class MockedStorageReplayer(StorageReplayer):
     def __init__(self, object_types=OBJECT_TYPES):
         self._object_types = object_types
 
 
-@given(lists(object_dicts()))
+@given(lists(object_dicts(), min_size=1))
 @settings(suppress_health_check=[HealthCheck.too_slow])
 def test_write_replay_same_order(objects):
+    committed = False
     queue = []
 
     def send(topic, key, value):
         key = kafka_to_key(key_to_kafka(key))
         value = kafka_to_value(value_to_kafka(value))
-        queue.append(FakeKafkaMessage(topic=topic, key=key, value=value))
+        queue.append({
+            FakeKafkaPartition(topic):
+            [FakeKafkaMessage(key=key, value=value)]
+        })
 
     def poll():
-        yield from queue
+        return queue.pop(0)
+
+    def commit():
+        nonlocal committed
+        if queue == []:
+            committed = True
 
     storage1 = Storage()
     storage1.journal_writer = MockedDirectKafkaWriter()
     storage1.journal_writer.send = send
 
     for (obj_type, obj) in objects:
         obj = obj.copy()
         if obj_type == 'origin_visit':
             origin_id = storage1.origin_add_one(obj.pop('origin'))
             if 'visit' in obj:
                 del obj['visit']
             storage1.origin_visit_add(origin_id, **obj)
         else:
             method = getattr(storage1, obj_type + '_add')
             try:
                 method([obj])
             except HashCollision:
                 pass
 
     storage2 = Storage()
     replayer = MockedStorageReplayer()
     replayer.poll = poll
-    replayer.fill(storage2)
+    replayer.commit = commit
+    replayer.fill(storage2, max_messages=len(queue))
+
+    assert committed
 
     for attr_name in ('_contents', '_directories', '_revisions',
                       '_releases', '_snapshots', '_origin_visits',
                       '_origins'):
         assert getattr(storage1, attr_name) == getattr(storage2, attr_name), \
             attr_name
 
 
 # TODO: add test for hash collision
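
Note on the new consumption model: kafka-python's KafkaConsumer.poll() returns a dict mapping a TopicPartition to the batch of records fetched from it, rather than the one-message-at-a-time stream that iterating the consumer (the old `yield from self.consumer`) produced, which is why fill() now needs the nested partition/message loop. A minimal sketch of the shape fill() walks, using stand-in partition/message types like the test doubles above (the topic name and payloads are illustrative, not values mandated by this change):

    from collections import namedtuple

    # Stand-ins for kafka-python's TopicPartition and ConsumerRecord;
    # only the attributes fill() touches (.topic, .key, .value) are modeled.
    FakePartition = namedtuple('FakePartition', 'topic')
    FakeMessage = namedtuple('FakeMessage', 'key value')

    # What a single poll() is assumed to return: partition -> record batch.
    polled = {
        FakePartition('prefix.revision'): [
            FakeMessage(key=b'k1', value={'id': b'\x01' * 20}),
            FakeMessage(key=b'k2', value={'id': b'\x02' * 20}),
        ],
    }

    for partition, messages in polled.items():
        # The object type is the last dot-separated component of the topic
        # name, mirroring partition.topic.split('.')[-1] in fill().
        object_type = partition.topic.split('.')[-1]
        for message in messages:
            print(object_type, message.key, message.value)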
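
For context, a hedged usage sketch of the reworked replayer API: because the consumer is created with enable_auto_commit=False and offsets are now committed explicitly after each polled batch via the new commit() method, a crash between poll() and commit() replays the uncommitted batch on restart instead of losing it. The broker address, topic prefix, and consumer group id below are placeholders, not values prescribed by this diff:

    from swh.storage.in_memory import Storage
    from swh.journal.replay import StorageReplayer

    storage = Storage()  # in-memory storage, as in the test above
    replayer = StorageReplayer(
        brokers=['localhost:9092'],    # placeholder broker address
        prefix='prefix',               # placeholder topic prefix
        consumer_id='replayer-test',   # placeholder consumer group id
    )
    # Stop after a bounded number of messages; with max_messages=None,
    # fill() keeps polling and committing indefinitely.
    replayer.fill(storage, max_messages=1000)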