diff --git a/.gitignore b/.gitignore
index 7ee0a81..6c29f49 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,15 +1,16 @@
 *.pyc
 *.sw?
 *~
 .coverage
 .eggs/
 __pycache__
 *.egg-info/
 build/
 dist/
 version.txt
 .tox/
 kafka/
 kafka*.tgz*
 kafka*.tar.gz*
-swh/journal/tests/kafka*
\ No newline at end of file
+swh/journal/tests/kafka*
+.hypothesis/
diff --git a/requirements-test.txt b/requirements-test.txt
index 03262c1..6d644bb 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,3 +1,4 @@
 pytest
 swh.model
 pytest-kafka
+hypothesis
diff --git a/swh/journal/direct_writer.py b/swh/journal/direct_writer.py
index 6900a32..bff463c 100644
--- a/swh/journal/direct_writer.py
+++ b/swh/journal/direct_writer.py
@@ -1,64 +1,69 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import logging
 
 from kafka import KafkaProducer
 
 from .serializers import key_to_kafka, value_to_kafka
 
 logger = logging.getLogger(__name__)
 
 
 class DirectKafkaWriter:
     """This class is instantiated and used by swh-storage to write incoming
     new objects to Kafka before adding them to the storage backend
     (eg. postgresql) itself."""
     def __init__(self, brokers, prefix, client_id):
         self._prefix = prefix
 
         self.producer = KafkaProducer(
             bootstrap_servers=brokers,
             key_serializer=key_to_kafka,
             value_serializer=value_to_kafka,
             client_id=client_id,
         )
 
+    def send(self, topic, key, value):
+        self.producer.send(topic=topic, key=key, value=value)
+
     def _get_key(self, object_type, object_):
         if object_type in ('revision', 'release', 'directory', 'snapshot'):
             return object_['id']
         elif object_type == 'content':
             return object_['sha1']  # TODO: use a dict of hashes
         elif object_type == 'origin':
             return {'url': object_['url'], 'type': object_['type']}
         elif object_type == 'origin_visit':
             return {
                 'origin': object_['origin'],
                 'date': str(object_['date']),
             }
         else:
             raise ValueError('Unknown object type: %s.' % object_type)
 
     def _sanitize_object(self, object_type, object_):
         if object_type == 'origin_visit':
             # Compatibility with the publisher's format
             return {
                 **object_,
                 'date': str(object_['date']),
             }
+        elif object_type == 'origin':
+            assert 'id' not in object_
         return object_
 
     def write_addition(self, object_type, object_):
         topic = '%s.%s' % (self._prefix, object_type)
         key = self._get_key(object_type, object_)
         object_ = self._sanitize_object(object_type, object_)
         logger.debug('topic: %s, key: %s, value: %s' % (topic, key, object_))
-        self.producer.send(topic, key=key, value=object_)
+        self.send(topic, key=key, value=object_)
 
     write_update = write_addition
 
     def write_additions(self, object_type, objects):
         for object_ in objects:
             self.write_addition(object_type, object_)
diff --git a/swh/journal/replay.py b/swh/journal/replay.py
index ec84d5c..eb894b6 100644
--- a/swh/journal/replay.py
+++ b/swh/journal/replay.py
@@ -1,71 +1,74 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import logging
 
 from kafka import KafkaConsumer
 
 from .serializers import kafka_to_value
 
 logger = logging.getLogger(__name__)
 
 
 OBJECT_TYPES = frozenset([
     'origin', 'origin_visit', 'snapshot', 'release', 'revision',
     'directory', 'content',
 ])
 
 
 class StorageReplayer:
     def __init__(self, brokers, prefix, consumer_id,
                  object_types=OBJECT_TYPES):
         if not set(object_types).issubset(OBJECT_TYPES):
             raise ValueError('Unknown object types: %s' % ', '.join(
                 set(object_types) - OBJECT_TYPES))
 
         self._object_types = object_types
 
         self.consumer = KafkaConsumer(
             bootstrap_servers=brokers,
             value_deserializer=kafka_to_value,
             auto_offset_reset='earliest',
             enable_auto_commit=False,
             group_id=consumer_id,
         )
 
         self.consumer.subscribe(
             topics=['%s.%s' % (prefix, object_type)
                     for object_type in object_types],
         )
 
+    def poll(self):
+        yield from self.consumer
+
     def fill(self, storage, max_messages=None):
         num = 0
-        for message in self.consumer:
+        for message in self.poll():
             object_type = message.topic.split('.')[-1]
             # Got a message from a topic we did not subscribe to.
             assert object_type in self._object_types, object_type
 
             self.insert_object(storage, object_type, message.value)
             num += 1
             if max_messages and num >= max_messages:
                 break
         return num
 
     def insert_object(self, storage, object_type, object_):
         if object_type in ('content', 'directory', 'revision', 'release',
                            'snapshot', 'origin'):
             if object_type == 'content':
                 method = storage.content_add_metadata
             else:
                 method = getattr(storage, object_type + '_add')
             method([object_])
         elif object_type == 'origin_visit':
             origin_id = storage.origin_add_one(object_.pop('origin'))
             visit = storage.origin_visit_add(
                 origin=origin_id, date=object_.pop('date'))
             storage.origin_visit_update(
                 origin_id, visit['visit'], **object_)
         else:
             assert False
diff --git a/swh/journal/tests/test_write_replay.py b/swh/journal/tests/test_write_replay.py
new file mode 100644
index 0000000..aa3f79c
--- /dev/null
+++ b/swh/journal/tests/test_write_replay.py
@@ -0,0 +1,91 @@
+# Copyright (C) 2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from collections import namedtuple
+
+from hypothesis import given, settings, HealthCheck
+from hypothesis.strategies import lists, one_of, composite
+
+from swh.model.hashutil import MultiHash
+from swh.storage.in_memory import Storage
+from swh.storage.tests.algos.test_snapshot import snapshots, origins
+from swh.storage.tests.generate_data_test import gen_raw_content
+
+from swh.journal.serializers import (
+    key_to_kafka, kafka_to_key, value_to_kafka, kafka_to_value)
+from swh.journal.direct_writer import DirectKafkaWriter
+from swh.journal.replay import StorageReplayer, OBJECT_TYPES
+
+FakeKafkaMessage = namedtuple('FakeKafkaMessage', 'topic key value')
+
+
+class MockedDirectKafkaWriter(DirectKafkaWriter):
+    def __init__(self):
+        self._prefix = 'prefix'
+
+
+class MockedStorageReplayer(StorageReplayer):
+    def __init__(self, object_types=OBJECT_TYPES):
+        self._object_types = object_types
+
+
+@composite
+def contents(draw):
+    """Generate valid and consistent content.
+
+    Context: Test purposes
+
+    Args:
+        **draw**: Used by hypothesis to generate data
+
+    Returns:
+        dict representing a content.
+
+    """
+    raw_content = draw(gen_raw_content())
+    return {
+        'data': raw_content,
+        'length': len(raw_content),
+        'status': 'visible',
+        **MultiHash.from_data(raw_content).digest()
+    }
+
+
+objects = lists(one_of(
+    origins().map(lambda x: ('origin', x)),
+    snapshots().map(lambda x: ('snapshot', x)),
+    contents().map(lambda x: ('content', x)),
+))
+
+
+@given(objects)
+@settings(suppress_health_check=[HealthCheck.too_slow])
+def test_write_replay_same_order(objects):
+    queue = []
+
+    def send(topic, key, value):
+        key = kafka_to_key(key_to_kafka(key))
+        value = kafka_to_value(value_to_kafka(value))
+        queue.append(FakeKafkaMessage(topic=topic, key=key, value=value))
+
+    def poll():
+        yield from queue
+
+    storage1 = Storage()
+    storage1.journal_writer = MockedDirectKafkaWriter()
+    storage1.journal_writer.send = send
+
+    for (obj_type, obj) in objects:
+        method = getattr(storage1, obj_type + '_add')
+        method([obj])
+
+    storage2 = Storage()
+    replayer = MockedStorageReplayer()
+    replayer.poll = poll
+    replayer.fill(storage2)
+
+    for attr in ('_contents', '_directories', '_revisions', '_releases',
+                 '_snapshots', '_origin_visits', '_origins'):
+        assert getattr(storage1, attr) == getattr(storage2, attr), attr