diff --git a/swh/journal/tests/test_replay.py b/swh/journal/tests/test_replay.py --- a/swh/journal/tests/test_replay.py +++ b/swh/journal/tests/test_replay.py @@ -108,8 +108,17 @@ assert contents == OBJECT_TYPE_KEYS['content'][1] -def test_write_replay_legacy_origin_visit1(): - """Test origin_visit when the 'origin' is just a string.""" +def _test_write_replay_origin_visit(visits): + """Helper function to write tests for origin_visit. + + Each visit (a dict) given in the 'visits' argument will be sent to + a (mocked) kafka queue, which a in-memory-storage backed replayer is + listening to. + + Check that corresponding origin visits entities are present in the storage + and have correct values. + + """ queue = [] replayer = MockedJournalClient(queue) writer = MockedKafkaWriter(queue) @@ -117,16 +126,12 @@ # Note that flipping the order of these two insertions will crash # the test, because the legacy origin_format does not allow to create # the origin when needed (type is missing) - now = datetime.datetime.now() writer.send('origin', 'foo', { 'url': 'http://example.com/', 'type': 'git', }) - writer.send('origin_visit', 'foo', { - 'visit': 1, - 'origin': 'http://example.com/', - 'date': now, - }) + for visit in visits: + writer.send('origin_visit', 'foo', visit) queue_size = sum(len(partition) for batch in queue @@ -138,68 +143,46 @@ while nb_messages < queue_size: nb_messages += replayer.process(worker_fn) - visits = list(storage.origin_visit_get('http://example.com/')) + actual_visits = list(storage.origin_visit_get('http://example.com/')) + + assert len(actual_visits) == len(visits), actual_visits - if ENABLE_ORIGIN_IDS: - assert visits == [{ - 'visit': 1, - 'origin': 1, - 'date': now, - }] - else: - assert visits == [{ - 'visit': 1, - 'origin': 'http://example.com/', - 'date': now, - }] + for vin, vout in zip(visits, actual_visits): + vin = vin.copy() + vout = vout.copy() + if ENABLE_ORIGIN_IDS: + assert vout.pop('origin') == 1 + else: + assert vout.pop('origin') == 'http://example.com/' + vin.pop('origin') + vin.setdefault('type', 'git') + assert vin == vout + + +def test_write_replay_legacy_origin_visit1(): + """Test origin_visit when the 'origin' is just a string.""" + now = datetime.datetime.now() + visits = [{ + 'visit': 1, + 'origin': 'http://example.com/', + 'date': now, + 'type': 'hg' + }] + _test_write_replay_origin_visit(visits) def test_write_replay_legacy_origin_visit2(): """Test origin_visit when 'type' is missing.""" - queue = [] - replayer = MockedJournalClient(queue) - writer = MockedKafkaWriter(queue) - now = datetime.datetime.now() - writer.send('origin', 'foo', { - 'url': 'http://example.com/', - 'type': 'git', - }) - writer.send('origin_visit', 'foo', { + visits = [{ 'visit': 1, 'origin': { 'url': 'http://example.com/', 'type': 'git', }, 'date': now, - }) - - queue_size = sum(len(partition) - for batch in queue - for partition in batch.values()) - - storage = get_storage('memory', {}) - worker_fn = functools.partial(process_replay_objects, storage=storage) - nb_messages = 0 - while nb_messages < queue_size: - nb_messages += replayer.process(worker_fn) - - visits = list(storage.origin_visit_get('http://example.com/')) - - if ENABLE_ORIGIN_IDS: - assert visits == [{ - 'visit': 1, - 'origin': 1, - 'date': now, - 'type': 'git', - }] - else: - assert visits == [{ - 'visit': 1, - 'origin': 'http://example.com/', - 'date': now, - 'type': 'git', - }] + }] + _test_write_replay_origin_visit(visits) hash_strategy = strategies.binary(min_size=20, max_size=20)