diff --git a/swh/journal/replay.py b/swh/journal/replay.py
index c29f449..040001d 100644
--- a/swh/journal/replay.py
+++ b/swh/journal/replay.py
@@ -1,93 +1,96 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import logging

 from kafka import KafkaConsumer

 from swh.storage import HashCollision

 from .serializers import kafka_to_value

 logger = logging.getLogger(__name__)


 OBJECT_TYPES = frozenset([
     'origin', 'origin_visit', 'snapshot', 'release', 'revision',
     'directory', 'content',
 ])


 class StorageReplayer:
     def __init__(self, brokers, prefix, consumer_id,
                  object_types=OBJECT_TYPES):
         if not set(object_types).issubset(OBJECT_TYPES):
             raise ValueError('Unknown object types: %s' % ', '.join(
                 set(object_types) - OBJECT_TYPES))

         self._object_types = object_types
         self.consumer = KafkaConsumer(
             bootstrap_servers=brokers,
             value_deserializer=kafka_to_value,
             auto_offset_reset='earliest',
             enable_auto_commit=False,
             group_id=consumer_id,
         )
         self.consumer.subscribe(
             topics=['%s.%s' % (prefix, object_type)
                     for object_type in object_types],
         )

     def poll(self):
         return self.consumer.poll()

     def commit(self):
         self.consumer.commit()

     def fill(self, storage, max_messages=None):
         nb_messages = 0

         def done():
             nonlocal nb_messages
             return max_messages and nb_messages >= max_messages

         while not done():
             polled = self.poll()
             for (partition, messages) in polled.items():
-                assert messages
-                for message in messages:
-                    object_type = partition.topic.split('.')[-1]
+                object_type = partition.topic.split('.')[-1]
+                # Got a message from a topic we did not subscribe to.
+                assert object_type in self._object_types, object_type

-                    # Got a message from a topic we did not subscribe to.
-                    assert object_type in self._object_types, object_type
+                self.insert_objects(storage, object_type,
+                                    [msg.value for msg in messages])

-                    self.insert_object(storage, object_type, message.value)
-
-                    nb_messages += 1
-                    if done():
-                        break
+                nb_messages += len(messages)
                 if done():
                     break
             self.commit()
         logger.info('Processed %d messages.' % nb_messages)
         return nb_messages

-    def insert_object(self, storage, object_type, object_):
+    def insert_objects(self, storage, object_type, objects):
         if object_type in ('content', 'directory', 'revision', 'release',
                            'snapshot', 'origin'):
             if object_type == 'content':
-                try:
-                    storage.content_add_metadata([object_])
-                except HashCollision as e:
-                    logger.error('Hash collision: %s', e.args)
+                # TODO: insert 'content' in batches
+                for object_ in objects:
+                    try:
+                        storage.content_add_metadata([object_])
+                    except HashCollision as e:
+                        logger.error('Hash collision: %s', e.args)
             else:
+                # TODO: split batches that are too large for the storage
+                # to handle?
                 method = getattr(storage, object_type + '_add')
-                method([object_])
+                method(objects)
         elif object_type == 'origin_visit':
-            storage.origin_visit_upsert([{
-                **object_,
-                'origin': storage.origin_add_one(object_['origin'])}])
+            storage.origin_visit_upsert([
+                {
+                    **obj,
+                    'origin': storage.origin_add_one(obj['origin'])
+                }
+                for obj in objects])
         else:
             assert False
diff --git a/swh/journal/tests/test_write_replay.py b/swh/journal/tests/test_write_replay.py
index 720a642..1463960 100644
--- a/swh/journal/tests/test_write_replay.py
+++ b/swh/journal/tests/test_write_replay.py
@@ -1,88 +1,151 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 from collections import namedtuple

 from hypothesis import given, settings, HealthCheck
 from hypothesis.strategies import lists

 from swh.model.hypothesis_strategies import object_dicts
 from swh.storage.in_memory import Storage
 from swh.storage import HashCollision

 from swh.journal.serializers import (
     key_to_kafka, kafka_to_key, value_to_kafka, kafka_to_value)
 from swh.journal.direct_writer import DirectKafkaWriter
 from swh.journal.replay import StorageReplayer, OBJECT_TYPES


 FakeKafkaMessage = namedtuple('FakeKafkaMessage', 'key value')
 FakeKafkaPartition = namedtuple('FakeKafkaPartition', 'topic')


 class MockedDirectKafkaWriter(DirectKafkaWriter):
     def __init__(self):
         self._prefix = 'prefix'


 class MockedStorageReplayer(StorageReplayer):
     def __init__(self, object_types=OBJECT_TYPES):
         self._object_types = object_types


 @given(lists(object_dicts(), min_size=1))
 @settings(suppress_health_check=[HealthCheck.too_slow])
 def test_write_replay_same_order(objects):
     committed = False
     queue = []

     def send(topic, key, value):
         key = kafka_to_key(key_to_kafka(key))
         value = kafka_to_value(value_to_kafka(value))
         queue.append({
             FakeKafkaPartition(topic):
             [FakeKafkaMessage(key=key, value=value)]
         })

     def poll():
         return queue.pop(0)

     def commit():
         nonlocal committed
         if queue == []:
             committed = True

     storage1 = Storage()
     storage1.journal_writer = MockedDirectKafkaWriter()
     storage1.journal_writer.send = send

     for (obj_type, obj) in objects:
         obj = obj.copy()
         if obj_type == 'origin_visit':
             origin_id = storage1.origin_add_one(obj.pop('origin'))
             if 'visit' in obj:
                 del obj['visit']
             storage1.origin_visit_add(origin_id, **obj)
         else:
             method = getattr(storage1, obj_type + '_add')
             try:
                 method([obj])
             except HashCollision:
                 pass

     storage2 = Storage()
     replayer = MockedStorageReplayer()
     replayer.poll = poll
     replayer.commit = commit
     replayer.fill(storage2, max_messages=len(queue))

     assert committed

     for attr_name in ('_contents', '_directories', '_revisions', '_releases',
                       '_snapshots', '_origin_visits', '_origins'):
         assert getattr(storage1, attr_name) == getattr(storage2, attr_name), \
             attr_name


+@given(lists(object_dicts(), min_size=1))
+@settings(suppress_health_check=[HealthCheck.too_slow])
+def test_write_replay_same_order_batches(objects):
+    committed = False
+    queue = []
+
+    def send(topic, key, value):
+        key = kafka_to_key(key_to_kafka(key))
+        value = kafka_to_value(value_to_kafka(value))
+        partition = FakeKafkaPartition(topic)
+        msg = FakeKafkaMessage(key=key, value=value)
+        if queue and {partition} == set(queue[-1]):
+            # The last message is of the same object type; group them.
+            queue[-1][partition].append(msg)
+        else:
+            queue.append({
+                FakeKafkaPartition(topic): [msg]
+            })
+
+    def poll():
+        return queue.pop(0)
+
+    def commit():
+        nonlocal committed
+        if queue == []:
+            committed = True
+
+    storage1 = Storage()
+    storage1.journal_writer = MockedDirectKafkaWriter()
+    storage1.journal_writer.send = send
+
+    for (obj_type, obj) in objects:
+        obj = obj.copy()
+        if obj_type == 'origin_visit':
+            origin_id = storage1.origin_add_one(obj.pop('origin'))
+            if 'visit' in obj:
+                del obj['visit']
+            storage1.origin_visit_add(origin_id, **obj)
+        else:
+            method = getattr(storage1, obj_type + '_add')
+            try:
+                method([obj])
+            except HashCollision:
+                pass
+
+    queue_size = sum(len(partition)
+                     for batch in queue
+                     for partition in batch.values())
+
+    storage2 = Storage()
+    replayer = MockedStorageReplayer()
+    replayer.poll = poll
+    replayer.commit = commit
+    replayer.fill(storage2, max_messages=queue_size)
+
+    assert committed
+
+    for attr_name in ('_contents', '_directories', '_revisions', '_releases',
+                      '_snapshots', '_origin_visits', '_origins'):
+        assert getattr(storage1, attr_name) == getattr(storage2, attr_name), \
+            attr_name
+
+    # TODO: add test for hash collision