diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1697,14 +1697,16 @@ for visit in expected_visits: assert visit.to_dict() in actual_origin_visits - actual_objects = set(swh_storage.journal_writer.journal.objects) - # we write to the journal as many times as we call the endpoint - assert actual_objects == set( + actual_objects = list(swh_storage.journal_writer.journal.objects) + expected_objects = list( [("origin", origin1)] - + [("origin_visit", visit) for visit in expected_visits] * 2 - + [("origin_visit_status", ovs) for ovs in expected_visit_statuses] * 2 + + [("origin_visit", visit) for visit in expected_visits] + + [("origin_visit_status", ovs) for ovs in expected_visit_statuses] ) + for obj in expected_objects: + assert obj in actual_objects + def test_origin_visit_add_validation(self, swh_storage): """Unknown origin when adding visits should raise""" visit = OriginVisit( @@ -1818,7 +1820,6 @@ + [("origin_visit_status", ovs) for ovs in expected_visit_statuses] ) - assert len(actual_objects) == len(expected_objects) for obj in expected_objects: assert obj in actual_objects @@ -1872,15 +1873,13 @@ visit_status.pop("type") expected_visit_statuses.append(OriginVisitStatus.from_dict(visit_status)) - # write twice in the journal - expected_visit_statuses += [visit_status1] * 2 + expected_visit_statuses += [visit_status1] expected_objects = ( [("origin", o) for o in expected_origins] + [("origin_visit", v) for v in expected_visits] + [("origin_visit_status", ovs) for ovs in expected_visit_statuses] ) - assert len(actual_objects) == len(expected_objects) for obj in expected_objects: assert obj in actual_objects @@ -2423,7 +2422,7 @@ "snapshot": data.empty_snapshot["id"], } actual_objects = list(swh_storage.journal_writer.journal.objects) - assert actual_objects == [ + expected_objects = [ ("origin", Origin.from_dict(data.origin)), ( "origin_visit", @@ -2433,6 +2432,8 @@ ("snapshot", Snapshot.from_dict(data.empty_snapshot)), ("origin_visit_status", OriginVisitStatus.from_dict(data2),), ] + for obj in expected_objects: + assert obj in actual_objects def test_snapshot_add_get_complete(self, swh_storage): origin_url = data.origin["url"] @@ -2858,7 +2859,7 @@ "snapshot": data.snapshot["id"], } actual_objects = list(swh_storage.journal_writer.journal.objects) - assert actual_objects == [ + expected_objects = [ ("origin", Origin.from_dict(data.origin)), ( "origin_visit", @@ -2866,15 +2867,18 @@ ), ("origin_visit_status", OriginVisitStatus.from_dict(data1)), ("snapshot", Snapshot.from_dict(data.snapshot)), - ("origin_visit_status", OriginVisitStatus.from_dict(data2),), + ("origin_visit_status", OriginVisitStatus.from_dict(data2)), ( "origin_visit", OriginVisit.from_dict({**data3, "type": data.type_visit2}), ), ("origin_visit_status", OriginVisitStatus.from_dict(data3)), - ("origin_visit_status", OriginVisitStatus.from_dict(data4),), + ("origin_visit_status", OriginVisitStatus.from_dict(data4)), ] + for obj in expected_objects: + assert obj in actual_objects + def test_snapshot_get_random(self, swh_storage): swh_storage.snapshot_add( [data.snapshot, data.empty_snapshot, data.complete_snapshot]