diff --git a/swh/journal/replay.py b/swh/journal/replay.py
--- a/swh/journal/replay.py
+++ b/swh/journal/replay.py
@@ -56,38 +56,41 @@
         while not done():
             polled = self.poll()
             for (partition, messages) in polled.items():
-                assert messages
-                for message in messages:
-                    object_type = partition.topic.split('.')[-1]
+                object_type = partition.topic.split('.')[-1]
+                # Got a message from a topic we did not subscribe to.
+                assert object_type in self._object_types, object_type
 
-                    # Got a message from a topic we did not subscribe to.
-                    assert object_type in self._object_types, object_type
+                self.insert_objects(storage, object_type,
+                                    [msg.value for msg in messages])
 
-                    self.insert_object(storage, object_type, message.value)
-
-                    nb_messages += 1
-                    if done():
-                        break
+                nb_messages += len(messages)
                 if done():
                     break
 
         self.commit()
         logger.info('Processed %d messages.' % nb_messages)
         return nb_messages
 
-    def insert_object(self, storage, object_type, object_):
+    def insert_objects(self, storage, object_type, objects):
         if object_type in ('content', 'directory', 'revision', 'release',
                            'snapshot', 'origin'):
             if object_type == 'content':
-                try:
-                    storage.content_add_metadata([object_])
-                except HashCollision as e:
-                    logger.error('Hash collision: %s', e.args)
+                # TODO: insert 'content' in batches
+                for object_ in objects:
+                    try:
+                        storage.content_add_metadata([object_])
+                    except HashCollision as e:
+                        logger.error('Hash collision: %s', e.args)
             else:
+                # TODO: split batches that are too large for the storage
+                # to handle?
                 method = getattr(storage, object_type + '_add')
-                method([object_])
+                method(objects)
         elif object_type == 'origin_visit':
-            storage.origin_visit_upsert([{
-                **object_,
-                'origin': storage.origin_add_one(object_['origin'])}])
+            storage.origin_visit_upsert([
+                {
+                    **obj,
+                    'origin': storage.origin_add_one(obj['origin'])
+                }
+                for obj in objects])
         else:
             assert False
diff --git a/swh/journal/tests/test_write_replay.py b/swh/journal/tests/test_write_replay.py
--- a/swh/journal/tests/test_write_replay.py
+++ b/swh/journal/tests/test_write_replay.py
@@ -85,4 +85,67 @@
             attr_name
 
 
+@given(lists(object_dicts(), min_size=1))
+@settings(suppress_health_check=[HealthCheck.too_slow])
+def test_write_replay_same_order_batches(objects):
+    committed = False
+    queue = []
+
+    def send(topic, key, value):
+        key = kafka_to_key(key_to_kafka(key))
+        value = kafka_to_value(value_to_kafka(value))
+        partition = FakeKafkaPartition(topic)
+        msg = FakeKafkaMessage(key=key, value=value)
+        if queue and {partition} == set(queue[-1]):
+            # The last message is of the same object type; group them together.
+            queue[-1][partition].append(msg)
+        else:
+            queue.append({
+                FakeKafkaPartition(topic): [msg]
+            })
+
+    def poll():
+        return queue.pop(0)
+
+    def commit():
+        nonlocal committed
+        if queue == []:
+            committed = True
+
+    storage1 = Storage()
+    storage1.journal_writer = MockedDirectKafkaWriter()
+    storage1.journal_writer.send = send
+
+    for (obj_type, obj) in objects:
+        obj = obj.copy()
+        if obj_type == 'origin_visit':
+            origin_id = storage1.origin_add_one(obj.pop('origin'))
+            if 'visit' in obj:
+                del obj['visit']
+            storage1.origin_visit_add(origin_id, **obj)
+        else:
+            method = getattr(storage1, obj_type + '_add')
+            try:
+                method([obj])
+            except HashCollision:
+                pass
+
+    queue_size = sum(len(partition)
+                     for batch in queue
+                     for partition in batch.values())
+
+    storage2 = Storage()
+    replayer = MockedStorageReplayer()
+    replayer.poll = poll
+    replayer.commit = commit
+    replayer.fill(storage2, max_messages=queue_size)
+
+    assert committed
+
+    for attr_name in ('_contents', '_directories', '_revisions', '_releases',
+                      '_snapshots', '_origin_visits', '_origins'):
+        assert getattr(storage1, attr_name) == getattr(storage2, attr_name), \
+            attr_name
+
+
 # TODO: add test for hash collision
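
Note (not part of the patch): the test's `send` helper groups consecutive messages for the same topic into one batch, so `poll()` hands the replayer whole per-type batches in the original write order. A minimal, self-contained sketch of that grouping rule, using plain topic strings in place of the FakeKafkaPartition/FakeKafkaMessage helpers and a hypothetical function name:

def group_by_consecutive_topic(messages):
    """messages: iterable of (topic, value) pairs, in production order.

    Returns a list of {topic: [values]} batches: consecutive messages on
    the same topic share a batch; a topic change starts a new one.
    """
    batches = []
    for topic, value in messages:
        if batches and set(batches[-1]) == {topic}:
            # Same topic as the previous batch: append to it.
            batches[-1][topic].append(value)
        else:
            # Topic changed: start a new batch.
            batches.append({topic: [value]})
    return batches


# Example: two revisions, then a release, then another revision yield
# three batches, preserving the original order.
assert group_by_consecutive_topic([
    ('revision', 'r1'), ('revision', 'r2'),
    ('release', 'rel1'),
    ('revision', 'r3'),
]) == [
    {'revision': ['r1', 'r2']},
    {'release': ['rel1']},
    {'revision': ['r3']},
]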