diff --git a/swh/dataset/exporters/orc.py b/swh/dataset/exporters/orc.py
--- a/swh/dataset/exporters/orc.py
+++ b/swh/dataset/exporters/orc.py
@@ -4,7 +4,9 @@
 # See top-level LICENSE file for more information
 
 from datetime import datetime
+import logging
 import math
+from types import TracebackType
 from typing import Any, Optional, Tuple, Type, cast
 
 from pkg_resources import get_distribution
@@ -48,6 +50,9 @@
 }
 
 
+logger = logging.getLogger(__name__)
+
+
 def hash_to_hex_or_none(hash):
     return hash_to_hex(hash) if hash is not None else None
 
@@ -111,31 +116,77 @@
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.max_rows = self.config.get("max-rows", {})
+        self._reset()
+
+    def _reset(self):
         self.writers = {}
+        self.writer_files = {}
         self.uuids = {}
+        self.uuid_main_table = {}
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> Optional[bool]:
+        for writer in self.writers.values():
+            writer.close()
+        for fileobj in self.writer_files.values():
+            fileobj.close()
+        self._reset()
+        return super().__exit__(exc_type, exc_value, traceback)
+
+    def maybe_close_writer_for(self, table_name: str):
+        uuid = self.uuids.get(table_name)
+        if (
+            uuid is not None
+            and table_name in self.max_rows
+            and self.writers[table_name].current_row >= self.max_rows[table_name]
+        ):
+            main_table = self.uuid_main_table[uuid]
+            if table_name != main_table:
+                logger.warning(
+                    "Limiting the number of rows of a secondary table (%s) is "
+                    "not supported for now (size limit ignored).",
+                    table_name,
+                )
+            else:
+                # sync/close all tables having the current uuid (aka the main
+                # table and its related tables)
+                for table in [
+                    tname for tname, tuuid in self.uuids.items() if tuuid == uuid
+                ]:
+                    # close the writer and remove it from the writers dict
+                    self.writers.pop(table).close()
+                    self.writer_files.pop(table).close()
+                    # and clean up the uuids dicts
+                    self.uuids.pop(table)
+                    self.uuid_main_table.pop(uuid, None)
 
     def get_writer_for(self, table_name: str, directory_name=None, unique_id=None):
+        self.maybe_close_writer_for(table_name)
         if table_name not in self.writers:
             if directory_name is None:
                 directory_name = table_name
             object_type_dir = self.export_path / directory_name
             object_type_dir.mkdir(exist_ok=True)
             if unique_id is None:
-                unique_id = unique_id = self.get_unique_file_id()
+                unique_id = self.get_unique_file_id()
+            self.uuid_main_table[unique_id] = table_name
             export_file = object_type_dir / (f"{table_name}-{unique_id}.orc")
-            export_obj = self.exit_stack.enter_context(export_file.open("wb"))
-            self.writers[table_name] = self.exit_stack.enter_context(
-                Writer(
-                    export_obj,
-                    EXPORT_SCHEMA[table_name],
-                    compression=CompressionKind.ZSTD,
-                    converters={
-                        TypeKind.TIMESTAMP: cast(
-                            Type[ORCConverter], SWHTimestampConverter
-                        )
-                    },
-                )
+            export_obj = export_file.open("wb")
+            self.writer_files[table_name] = export_obj
+            self.writers[table_name] = Writer(
+                export_obj,
+                EXPORT_SCHEMA[table_name],
+                compression=CompressionKind.ZSTD,
+                converters={
+                    TypeKind.TIMESTAMP: cast(Type[ORCConverter], SWHTimestampConverter)
+                },
             )
+
             self.writers[table_name].set_user_metadata(
                 swh_object_type=table_name.encode(),
                 swh_uuid=unique_id.encode(),
@@ -258,10 +309,7 @@
     def process_directory(self, directory):
         directory_writer = self.get_writer_for("directory")
         directory_writer.write(
-            (
-                hash_to_hex_or_none(directory["id"]),
-                directory.get("raw_manifest"),
-            )
+            (hash_to_hex_or_none(directory["id"]), directory.get("raw_manifest"),)
        )
 
         directory_entry_writer = self.get_writer_for(
diff --git a/swh/dataset/test/test_orc.py b/swh/dataset/test/test_orc.py
--- a/swh/dataset/test/test_orc.py
+++ b/swh/dataset/test/test_orc.py
@@ -1,9 +1,11 @@
 import collections
 from contextlib import contextmanager
+import math
 from pathlib import Path
 import tempfile
 
 import pyorc
+import pytest
 
 from swh.dataset.exporters.orc import (
     ORCExporter,
@@ -16,20 +18,30 @@
 
 
 @contextmanager
-def orc_export(messages, config=None):
-    with tempfile.TemporaryDirectory() as tmpname:
-        tmppath = Path(tmpname)
+def orc_tmpdir(tmpdir):
+    if tmpdir:
+        yield Path(tmpdir)
+    else:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            yield Path(tmpdir)
+
+
+@contextmanager
+def orc_export(messages, config=None, tmpdir=None):
+
+    with orc_tmpdir(tmpdir) as tmpdir:
         if config is None:
             config = {}
-        with ORCExporter(config, tmppath) as exporter:
+        with ORCExporter(config, tmpdir) as exporter:
             for object_type, objects in messages.items():
                 for obj in objects:
                     exporter.process_object(object_type, obj.to_dict())
-        yield tmppath
+        yield tmpdir
 
 
 def orc_load(rootdir):
     res = collections.defaultdict(list)
+    res["rootdir"] = rootdir
     for obj_type_dir in rootdir.iterdir():
         for orc_file in obj_type_dir.iterdir():
             with orc_file.open("rb") as orc_obj:
@@ -42,8 +54,8 @@
     return res
 
 
-def exporter(messages, config=None):
-    with orc_export(messages, config) as exportdir:
+def exporter(messages, config=None, tmpdir=None):
+    with orc_export(messages, config, tmpdir) as exportdir:
         return orc_load(exportdir)
 
 
@@ -204,3 +216,71 @@
             0,
             b"-0000",
         )
+
+
+# mapping of related tables for each main table (if any)
+RELATED = {
+    "snapshot": ["snapshot_branch"],
+    "revision": ["revision_history", "revision_extra_headers"],
+    "directory": ["directory_entry"],
+}
+
+
+@pytest.mark.parametrize(
+    "obj_type",
+    (
+        "origin",
+        "origin_visit",
+        "origin_visit_status",
+        "snapshot",
+        "release",
+        "revision",
+        "directory",
+        "content",
+        "skipped_content",
+    ),
+)
+@pytest.mark.parametrize("max_rows", (None, 1, 2, 10000))
+def test_export_related_files(max_rows, obj_type, tmpdir):
+    config = {}
+    if max_rows is not None:
+        config["max-rows"] = {obj_type: max_rows}
+    exporter({obj_type: TEST_OBJECTS[obj_type]}, config=config, tmpdir=tmpdir)
+    # check the export produced the expected number of main-table ORC files
+    orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")]
+    if max_rows is None:
+        assert len(orcfiles) == 1
+    else:
+        assert len(orcfiles) == math.ceil(len(TEST_OBJECTS[obj_type]) / max_rows)
+    # check the number of related ORC files
+    for related in RELATED.get(obj_type, ()):
+        related_orcfiles = [
+            fname for fname in (tmpdir / obj_type).listdir(f"{related}-*.orc")
+        ]
+        assert len(related_orcfiles) == len(orcfiles)
+
+    # for each ORC file, check related files only reference objects in the
+    # corresponding main table
+    for orc_file in orcfiles:
+        with orc_file.open("rb") as orc_obj:
+            reader = pyorc.Reader(
+                orc_obj, converters={pyorc.TypeKind.TIMESTAMP: SWHTimestampConverter},
+            )
+            uuid = reader.user_metadata["swh_uuid"].decode()
+            assert orc_file.basename == f"{obj_type}-{uuid}.orc"
+            rows = list(reader)
+            obj_ids = [row[0] for row in rows]
+
+        # check the related tables
+        for related in RELATED.get(obj_type, ()):
+            orc_file = tmpdir / obj_type / f"{related}-{uuid}.orc"
+            with orc_file.open("rb") as orc_obj:
+                reader = pyorc.Reader(
+                    orc_obj,
+                    converters={pyorc.TypeKind.TIMESTAMP: SWHTimestampConverter},
+                )
+                assert reader.user_metadata["swh_uuid"].decode() == uuid
+                rows = list(reader)
+                # check rows here only reference objects (obj_ids) from the main table file
+                for row in rows:
+                    assert row[0] in obj_ids
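Note (not part of the diff): a minimal usage sketch of the new "max-rows" setting introduced above. It relies only on the ORCExporter constructor, the "max-rows" config key, and process_object() as shown in the changes; the TEST_OBJECTS import path, the temporary output directory, and the final listing are assumptions added for illustration.

# Minimal sketch, assuming swh.model's sample TEST_OBJECTS are available
# (the test above uses the same data; its import path is assumed here).
from pathlib import Path
from tempfile import TemporaryDirectory

from swh.dataset.exporters.orc import ORCExporter
from swh.model.tests.swh_model_data import TEST_OBJECTS  # assumed import path

# roll the main "revision" table over to a new ORC file every 2 rows; related
# tables (revision_history, revision_extra_headers) are rotated with it so
# each chunk keeps a consistent swh_uuid across its files
config = {"max-rows": {"revision": 2}}

with TemporaryDirectory() as tmp:
    export_path = Path(tmp)
    with ORCExporter(config, export_path) as exporter:
        for revision in TEST_OBJECTS["revision"]:
            exporter.process_object("revision", revision.to_dict())
    # the revision/ directory now holds revision-<uuid>.orc files of at most
    # 2 rows each, plus any related *-<uuid>.orc files sharing their swh_uuid
    print(sorted(p.name for p in (export_path / "revision").iterdir()))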