diff --git a/requirements-test.txt b/requirements-test.txt --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,4 +1,4 @@ pytest -swh.model +swh.model >= 0.0.32 pytest-kafka hypothesis diff --git a/swh/journal/replay.py b/swh/journal/replay.py --- a/swh/journal/replay.py +++ b/swh/journal/replay.py @@ -7,6 +7,8 @@ from kafka import KafkaConsumer +from swh.storage import HashCollision + from .serializers import kafka_to_value logger = logging.getLogger(__name__) @@ -60,15 +62,16 @@ if object_type in ('content', 'directory', 'revision', 'release', 'snapshot', 'origin'): if object_type == 'content': - method = storage.content_add_metadata + try: + storage.content_add_metadata([object_]) + except HashCollision as e: + logger.error('Hash collision: %s', e.args) else: method = getattr(storage, object_type + '_add') - method([object_]) + method([object_]) elif object_type == 'origin_visit': - origin_id = storage.origin_add_one(object_.pop('origin')) - visit = storage.origin_visit_add( - origin=origin_id, date=object_.pop('date')) - storage.origin_visit_update( - origin_id, visit['visit'], **object_) + storage.origin_visit_upsert([{ + **object_, + 'origin': storage.origin_add_one(object_['origin'])}]) else: assert False diff --git a/swh/journal/tests/test_replay.py b/swh/journal/tests/test_replay.py --- a/swh/journal/tests/test_replay.py +++ b/swh/journal/tests/test_replay.py @@ -38,6 +38,7 @@ # Fill Kafka nb_sent = 0 + nb_visits = 0 for (object_type, (_, objects)) in OBJECT_TYPE_KEYS.items(): topic = kafka_prefix + '.' + object_type for object_ in objects: @@ -45,6 +46,9 @@ object_ = object_.copy() if object_type == 'content': object_['ctime'] = now + elif object_type == 'origin_visit': + nb_visits += 1 + object_['visit'] = nb_visits producer.send(topic, key=key, value=object_) nb_sent += 1 diff --git a/swh/journal/tests/test_write_replay.py b/swh/journal/tests/test_write_replay.py --- a/swh/journal/tests/test_write_replay.py +++ b/swh/journal/tests/test_write_replay.py @@ -6,12 +6,11 @@ from collections import namedtuple from hypothesis import given, settings, HealthCheck -from hypothesis.strategies import lists, one_of, composite +from hypothesis.strategies import lists -from swh.model.hashutil import MultiHash +from swh.model.hypothesis_strategies import object_dicts from swh.storage.in_memory import Storage -from swh.storage.tests.algos.test_snapshot import snapshots, origins -from swh.storage.tests.generate_data_test import gen_raw_content +from swh.storage import HashCollision from swh.journal.serializers import ( key_to_kafka, kafka_to_key, value_to_kafka, kafka_to_value) @@ -31,36 +30,7 @@ self._object_types = object_types -@composite -def contents(draw): - """Generate valid and consistent content. - - Context: Test purposes - - Args: - **draw**: Used by hypothesis to generate data - - Returns: - dict representing a content. - - """ - raw_content = draw(gen_raw_content()) - return { - 'data': raw_content, - 'length': len(raw_content), - 'status': 'visible', - **MultiHash.from_data(raw_content).digest() - } - - -objects = lists(one_of( - origins().map(lambda x: ('origin', x)), - snapshots().map(lambda x: ('snapshot', x)), - contents().map(lambda x: ('content', x)), -)) - - -@given(objects) +@given(lists(object_dicts())) @settings(suppress_health_check=[HealthCheck.too_slow]) def test_write_replay_same_order(objects): queue = [] @@ -78,14 +48,28 @@ storage1.journal_writer.send = send for (obj_type, obj) in objects: - method = getattr(storage1, obj_type + '_add') - method([obj]) + obj = obj.copy() + if obj_type == 'origin_visit': + origin_id = storage1.origin_add_one(obj.pop('origin')) + if 'visit' in obj: + del obj['visit'] + storage1.origin_visit_add(origin_id, **obj) + else: + method = getattr(storage1, obj_type + '_add') + try: + method([obj]) + except HashCollision: + pass storage2 = Storage() replayer = MockedStorageReplayer() replayer.poll = poll replayer.fill(storage2) - for attr in ('_contents', '_directories', '_revisions', '_releases', - '_snapshots', '_origin_visits', '_origins'): - assert getattr(storage1, attr) == getattr(storage2, attr), attr + for attr_name in ('_contents', '_directories', '_revisions', '_releases', + '_snapshots', '_origin_visits', '_origins'): + assert getattr(storage1, attr_name) == getattr(storage2, attr_name), \ + attr_name + + +# TODO: add test for hash collision