Replace the `exporter` pytest fixture in test_orc.py with plain helpers.

`orc_export` is a context manager that feeds the given objects to an
ORCExporter writing into a temporary directory, then yields that
directory's path. `orc_load` walks the sub-directories of such a
directory and collects the rows of every ORC file into a mapping of
sub-directory name to a set of rows. `exporter` chains the two and
returns the mapping. The tests call `exporter` directly instead of
receiving it as a fixture argument, so the `pytest` import is removed.

diff --git a/swh/dataset/test/test_orc.py b/swh/dataset/test/test_orc.py
--- a/swh/dataset/test/test_orc.py
+++ b/swh/dataset/test/test_orc.py
@@ -1,9 +1,9 @@
 import collections
+from contextlib import contextmanager
 from pathlib import Path
 import tempfile
 
 import pyorc
-import pytest
 
 from swh.dataset.exporters.orc import (
     ORCExporter,
@@ -13,42 +13,49 @@
 from swh.model.tests.swh_model_data import TEST_OBJECTS
 
 
-@pytest.fixture
-def exporter():
-    def wrapped(messages, config=None):
-        with tempfile.TemporaryDirectory() as tmpname:
-            tmppath = Path(tmpname)
-            if config is None:
-                config = {}
-            with ORCExporter(config, tmppath) as exporter:
-                for object_type, objects in messages.items():
-                    for obj in objects:
-                        exporter.process_object(object_type, obj.to_dict())
-            res = collections.defaultdict(set)
-            for obj_type_dir in tmppath.iterdir():
-                for orc_file in obj_type_dir.iterdir():
-                    with orc_file.open("rb") as orc_obj:
-                        res[obj_type_dir.name] |= set(pyorc.Reader(orc_obj))
-            return res
-
-    return wrapped
-
-
-def test_export_origin(exporter):
+@contextmanager
+def orc_export(messages, config=None):
+    with tempfile.TemporaryDirectory() as tmpname:
+        tmppath = Path(tmpname)
+        if config is None:
+            config = {}
+        with ORCExporter(config, tmppath) as exporter:
+            for object_type, objects in messages.items():
+                for obj in objects:
+                    exporter.process_object(object_type, obj.to_dict())
+        yield tmppath
+
+
+def orc_load(rootdir):
+    res = collections.defaultdict(set)
+    for obj_type_dir in rootdir.iterdir():
+        for orc_file in obj_type_dir.iterdir():
+            with orc_file.open("rb") as orc_obj:
+                reader = pyorc.Reader(orc_obj)
+                res[obj_type_dir.name] |= set(reader)
+    return res
+
+
+def exporter(messages, config=None):
+    with orc_export(messages, config) as exportdir:
+        return orc_load(exportdir)
+
+
+def test_export_origin():
     obj_type = "origin"
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (obj.url,) in output[obj_type]
 
 
-def test_export_origin_visit(exporter):
+def test_export_origin_visit():
     obj_type = "origin_visit"
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (obj.origin, obj.visit, obj.date, obj.type) in output[obj_type]
 
 
-def test_export_origin_visit_status(exporter):
+def test_export_origin_visit_status():
     obj_type = "origin_visit_status"
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
@@ -61,7 +68,7 @@
         ) in output[obj_type]
 
 
-def test_export_snapshot(exporter):
+def test_export_snapshot():
     obj_type = "snapshot"
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
@@ -77,7 +84,7 @@
             ) in output["snapshot_branch"]
 
 
-def test_export_release(exporter):
+def test_export_release():
     obj_type = "release"
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
@@ -92,7 +99,7 @@
         ) in output[obj_type]
 
 
-def test_export_revision(exporter):
+def test_export_revision():
     obj_type = "revision"
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
@@ -115,7 +122,7 @@
             ) in output["revision_history"]
 
 
-def test_export_directory(exporter):
+def test_export_directory():
     obj_type = "directory"
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
@@ -130,7 +137,7 @@
             ) in output["directory_entry"]
 
 
-def test_export_content(exporter):
+def test_export_content():
     obj_type = "content"
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
@@ -144,7 +151,7 @@
         ) in output[obj_type]
 
 
-def test_export_skipped_content(exporter):
+def test_export_skipped_content():
     obj_type = "skipped_content"
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]: