diff --git a/swh/dataset/exporters/orc.py b/swh/dataset/exporters/orc.py --- a/swh/dataset/exporters/orc.py +++ b/swh/dataset/exporters/orc.py @@ -119,14 +119,13 @@ config = self.config.get("orc", {}) self.max_rows = config.get("max_rows", {}) invalid_tables = [ - table_name for table_name in self.max_rows - if table_name not in MAIN_TABLES + table_name for table_name in self.max_rows if table_name not in MAIN_TABLES ] if invalid_tables: raise ValueError( "Limiting the number of secondary table (%s) is not supported " "for now.", - invalid_tables + invalid_tables, ) self._reset() @@ -176,12 +175,10 @@ self.uuids.pop(table) self.uuid_main_table.pop(uuid, None) - def get_writer_for(self, table_name: str, directory_name=None, unique_id=None): + def get_writer_for(self, table_name: str, unique_id=None): self.maybe_close_writer_for(table_name) if table_name not in self.writers: - if directory_name is None: - directory_name = table_name - object_type_dir = self.export_path / directory_name + object_type_dir = self.export_path / table_name object_type_dir.mkdir(exist_ok=True) if unique_id is None: unique_id = self.get_unique_file_id() @@ -246,9 +243,7 @@ # we want to store branches in the same directory as snapshot objects, # and have both files have the same UUID. snapshot_branch_writer = self.get_writer_for( - "snapshot_branch", - directory_name="snapshot", - unique_id=self.uuids["snapshot"], + "snapshot_branch", unique_id=self.uuids["snapshot"], ) for branch_name, branch in snapshot["branches"].items(): if branch is None: @@ -294,9 +289,7 @@ ) revision_history_writer = self.get_writer_for( - "revision_history", - directory_name="revision", - unique_id=self.uuids["revision"], + "revision_history", unique_id=self.uuids["revision"], ) for i, parent_id in enumerate(revision["parents"]): revision_history_writer.write( @@ -308,9 +301,7 @@ ) revision_header_writer = self.get_writer_for( - "revision_extra_headers", - directory_name="revision", - unique_id=self.uuids["revision"], + "revision_extra_headers", unique_id=self.uuids["revision"], ) for key, value in revision["extra_headers"]: revision_header_writer.write( @@ -324,9 +315,7 @@ ) directory_entry_writer = self.get_writer_for( - "directory_entry", - directory_name="directory", - unique_id=self.uuids["directory"], + "directory_entry", unique_id=self.uuids["directory"], ) for entry in directory["entries"]: directory_entry_writer.write( diff --git a/swh/dataset/test/test_orc.py b/swh/dataset/test/test_orc.py --- a/swh/dataset/test/test_orc.py +++ b/swh/dataset/test/test_orc.py @@ -237,9 +237,9 @@ ) @pytest.mark.parametrize("max_rows", (None, 1, 2, 10000)) def test_export_related_files(max_rows, obj_type, tmpdir): - config = {} + config = {"orc": {}} if max_rows is not None: - config["orc"] = {"max_rows": {obj_type: max_rows}} + config["orc"]["max_rows"] = {obj_type: max_rows} exporter({obj_type: TEST_OBJECTS[obj_type]}, config=config, tmpdir=tmpdir) # check there are as many ORC files as objects orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")] @@ -250,7 +250,7 @@ # check the number of related ORC files for related in RELATED.get(obj_type, ()): related_orcfiles = [ - fname for fname in (tmpdir / obj_type).listdir(f"{related}-*.orc") + fname for fname in (tmpdir / related).listdir(f"{related}-*.orc") ] assert len(related_orcfiles) == len(orcfiles) @@ -268,7 +268,7 @@ # check the related tables for related in RELATED.get(obj_type, ()): - orc_file = tmpdir / obj_type / f"{related}-{uuid}.orc" + orc_file = tmpdir / related / f"{related}-{uuid}.orc" with orc_file.open("rb") as orc_obj: reader = pyorc.Reader( orc_obj, @@ -281,6 +281,22 @@ assert row[0] in obj_ids +@pytest.mark.parametrize( + "obj_type", MAIN_TABLES.keys(), +) +def test_export_related_files_separated(obj_type, tmpdir): + exporter({obj_type: TEST_OBJECTS[obj_type]}, tmpdir=tmpdir) + # check there are as many ORC files as objects + orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")] + assert len(orcfiles) == 1 + # check related ORC files are in their own directory + for related in RELATED.get(obj_type, ()): + related_orcfiles = [ + fname for fname in (tmpdir / related).listdir(f"{related}-*.orc") + ] + assert len(related_orcfiles) == len(orcfiles) + + @pytest.mark.parametrize("table_name", RELATION_TABLES.keys()) def test_export_invalid_max_rows(table_name): config = {"orc": {"max_rows": {table_name: 10}}}