diff --git a/swh/dataset/exporters/orc.py b/swh/dataset/exporters/orc.py --- a/swh/dataset/exporters/orc.py +++ b/swh/dataset/exporters/orc.py @@ -117,6 +117,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) config = self.config.get("orc", {}) + self.group_tables = kwargs.get( + "group_tables", config.get("group_tables", False) + ) self.max_rows = config.get("max_rows", {}) invalid_tables = [ table_name for table_name in self.max_rows @@ -179,7 +182,7 @@ def get_writer_for(self, table_name: str, directory_name=None, unique_id=None): self.maybe_close_writer_for(table_name) if table_name not in self.writers: - if directory_name is None: + if not self.group_tables or directory_name is None: directory_name = table_name object_type_dir = self.export_path / directory_name object_type_dir.mkdir(exist_ok=True) diff --git a/swh/dataset/test/test_orc.py b/swh/dataset/test/test_orc.py --- a/swh/dataset/test/test_orc.py +++ b/swh/dataset/test/test_orc.py @@ -237,9 +237,9 @@ ) @pytest.mark.parametrize("max_rows", (None, 1, 2, 10000)) def test_export_related_files(max_rows, obj_type, tmpdir): - config = {} + config = {"orc": {"group_tables": True}} if max_rows is not None: - config["orc"] = {"max_rows": {obj_type: max_rows}} + config["orc"]["max_rows"] = {obj_type: max_rows} exporter({obj_type: TEST_OBJECTS[obj_type]}, config=config, tmpdir=tmpdir) # check there are as many ORC files as objects orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")] @@ -281,6 +281,23 @@ assert row[0] in obj_ids +@pytest.mark.parametrize( + "obj_type", MAIN_TABLES.keys(), +) +def test_export_related_files_separated(obj_type, tmpdir): + config = {} + exporter({obj_type: TEST_OBJECTS[obj_type]}, config=config, tmpdir=tmpdir) + # check there are as many ORC files as objects + orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")] + assert len(orcfiles) == 1 + # check related ORC files are in their own directory + for related in RELATED.get(obj_type, ()): + related_orcfiles = [ + fname for fname in (tmpdir / related).listdir(f"{related}-*.orc") + ] + assert len(related_orcfiles) == len(orcfiles) + + @pytest.mark.parametrize("table_name", RELATION_TABLES.keys()) def test_export_invalid_max_rows(table_name): config = {"orc": {"max_rows": {table_name: 10}}}