diff --git a/.gitignore b/.gitignore --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,5 @@ kafka/ kafka*.tgz* kafka*.tar.gz* -swh/journal/tests/kafka* \ No newline at end of file +swh/journal/tests/kafka* +.hypothesis/ diff --git a/requirements-test.txt b/requirements-test.txt --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,3 +1,4 @@ pytest swh.model pytest-kafka +hypothesis diff --git a/swh/journal/direct_writer.py b/swh/journal/direct_writer.py --- a/swh/journal/direct_writer.py +++ b/swh/journal/direct_writer.py @@ -26,6 +26,9 @@ client_id=client_id, ) + def send(self, topic, key, value): + self.producer.send(topic=topic, key=key, value=value) + def _get_key(self, object_type, object_): if object_type in ('revision', 'release', 'directory', 'snapshot'): return object_['id'] @@ -48,6 +51,8 @@ **object_, 'date': str(object_['date']), } + elif object_type == 'origin': + assert 'id' not in object_ return object_ def write_addition(self, object_type, object_): @@ -55,7 +60,7 @@ key = self._get_key(object_type, object_) object_ = self._sanitize_object(object_type, object_) logger.debug('topic: %s, key: %s, value: %s' % (topic, key, object_)) - self.producer.send(topic, key=key, value=object_) + self.send(topic, key=key, value=object_) write_update = write_addition diff --git a/swh/journal/replay.py b/swh/journal/replay.py --- a/swh/journal/replay.py +++ b/swh/journal/replay.py @@ -38,9 +38,12 @@ for object_type in object_types], ) + def poll(self): + yield from self.consumer + def fill(self, storage, max_messages=None): num = 0 - for message in self.consumer: + for message in self.poll(): object_type = message.topic.split('.')[-1] # Got a message from a topic we did not subscribe to. 
diff --git a/swh/journal/tests/test_write_replay.py b/swh/journal/tests/test_write_replay.py
new file mode 100644
--- /dev/null
+++ b/swh/journal/tests/test_write_replay.py
@@ -0,0 +1,108 @@
+# Copyright (C) 2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from collections import namedtuple
+
+from hypothesis import given, settings, HealthCheck
+from hypothesis.strategies import lists, one_of, composite
+
+from swh.model.hashutil import MultiHash
+from swh.storage.in_memory import Storage
+from swh.storage.tests.algos.test_snapshot import snapshots, origins
+from swh.storage.tests.generate_data_test import gen_raw_content
+
+from swh.journal.serializers import (
+    key_to_kafka, kafka_to_key, value_to_kafka, kafka_to_value)
+from swh.journal.direct_writer import DirectKafkaWriter
+from swh.journal.replay import StorageReplayer, OBJECT_TYPES
+
+# In-memory stand-in for a journal message, built by the fake `send` below.
+FakeKafkaMessage = namedtuple('FakeKafkaMessage', 'topic key value')
+
+
+class MockedDirectKafkaWriter(DirectKafkaWriter):
+    """Journal writer that skips the parent constructor.
+
+    Only `_prefix` is set; the test replaces `send` on the instance.
+    """
+    def __init__(self):
+        self._prefix = 'prefix'
+
+
+class MockedStorageReplayer(StorageReplayer):
+    """Replayer that skips the parent constructor.
+
+    Only `_object_types` is set; the test replaces `poll` on the instance.
+    """
+    def __init__(self, object_types=OBJECT_TYPES):
+        self._object_types = object_types
+
+
+@composite
+def contents(draw):
+    """Hypothesis strategy building a content dict from random raw data.
+
+    Args:
+        draw: function provided by hypothesis to draw from strategies
+
+    Returns:
+        dict with the raw data, its length, a 'visible' status and the
+        entries of MultiHash.from_data(raw_content).digest().
+
+    """
+    raw_content = draw(gen_raw_content())
+    content = {
+        'data': raw_content,
+        'length': len(raw_content),
+        'status': 'visible',
+    }
+    content.update(MultiHash.from_data(raw_content).digest())
+    return content
+
+
+# Lists of (object type, object) pairs, in the order they will be added.
+objects = lists(one_of(
+    origins().map(lambda origin: ('origin', origin)),
+    snapshots().map(lambda snapshot: ('snapshot', snapshot)),
+    contents().map(lambda content: ('content', content)),
+))
+
+
+@given(objects)
+@settings(suppress_health_check=[HealthCheck.too_slow])
+def test_write_replay_same_order(objects):
+    """Replaying the journal of a storage must rebuild the same storage."""
+    messages = []
+
+    def send(topic, key, value):
+        # Round-trip key and value through the journal serializers.
+        messages.append(FakeKafkaMessage(
+            topic=topic,
+            key=kafka_to_key(key_to_kafka(key)),
+            value=kafka_to_value(value_to_kafka(value)),
+        ))
+
+    def poll():
+        yield from messages
+
+    storage1 = Storage()
+    storage1.journal_writer = MockedDirectKafkaWriter()
+    storage1.journal_writer.send = send
+
+    for obj_type, obj in objects:
+        add_objects = getattr(storage1, obj_type + '_add')
+        add_objects([obj])
+
+    storage2 = Storage()
+    replayer = MockedStorageReplayer()
+    replayer.poll = poll
+    replayer.fill(storage2)
+
+    compared_attrs = (
+        '_contents', '_directories', '_revisions', '_releases',
+        '_snapshots', '_origin_visits', '_origins',
+    )
+    for attr in compared_attrs:
+        assert getattr(storage1, attr) == getattr(storage2, attr), attr