diff --git a/swh/dataset/exporters/orc.py b/swh/dataset/exporters/orc.py
--- a/swh/dataset/exporters/orc.py
+++ b/swh/dataset/exporters/orc.py
@@ -1,10 +1,12 @@
-# Copyright (C) 2020 The Software Heritage developers
+# Copyright (C) 2020-2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from datetime import datetime
+import logging
 import math
+from types import TracebackType
 from typing import Any, Optional, Tuple, Type, cast
 
 from pkg_resources import get_distribution
@@ -23,7 +25,7 @@
 from pyorc.converters import ORCConverter
 
 from swh.dataset.exporter import ExporterDispatch
-from swh.dataset.relational import TABLES
+from swh.dataset.relational import MAIN_TABLES, TABLES
 from swh.dataset.utils import remove_pull_requests
 from swh.model.hashutil import hash_to_hex
 from swh.model.model import TimestampWithTimezone
@@ -48,6 +50,9 @@
 }
 
 
+logger = logging.getLogger(__name__)
+
+
 def hash_to_hex_or_none(hash):
     return hash_to_hex(hash) if hash is not None else None
 
@@ -111,31 +116,87 @@
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.max_rows = self.config.get("max_rows", {})
+        if any(table_name not in MAIN_TABLES for table_name in self.max_rows):
+            raise ValueError(
+                "Limiting the number of rows is not supported for secondary "
+                "tables: %s"
+                % [
+                    table_name
+                    for table_name in self.max_rows
+                    if table_name not in MAIN_TABLES
+                ]
+            )
+        self._reset()
+
+    def _reset(self):
         self.writers = {}
+        self.writer_files = {}
         self.uuids = {}
+        self.uuid_main_table = {}
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> Optional[bool]:
+        for writer in self.writers.values():
+            writer.close()
+        for fileobj in self.writer_files.values():
+            fileobj.close()
+        self._reset()
+        return super().__exit__(exc_type, exc_value, traceback)
+
+    def maybe_close_writer_for(self, table_name: str):
+        uuid = self.uuids.get(table_name)
+        if (
+            uuid is not None
+            and table_name in self.max_rows
+            and self.writers[table_name].current_row >= self.max_rows[table_name]
+        ):
+            main_table = self.uuid_main_table[uuid]
+            if table_name != main_table:
+                logger.warning(
+                    "Limiting the number of rows is not supported for secondary "
+                    "table %s (size limit ignored).",
+                    table_name,
+                )
+            else:
+                # sync/close all tables having the current uuid (aka main and
+                # related tables)
+                for table in [
+                    tname for tname, tuuid in self.uuids.items() if tuuid == uuid
+                ]:
+                    # close the writer and remove from the writers dict
+                    self.writers.pop(table).close()
+                    self.writer_files.pop(table).close()
+                    # and clean uuids dicts
+                    self.uuids.pop(table)
+                    self.uuid_main_table.pop(uuid, None)
 
     def get_writer_for(self, table_name: str, directory_name=None, unique_id=None):
+        self.maybe_close_writer_for(table_name)
         if table_name not in self.writers:
             if directory_name is None:
                 directory_name = table_name
             object_type_dir = self.export_path / directory_name
             object_type_dir.mkdir(exist_ok=True)
             if unique_id is None:
-                unique_id = unique_id = self.get_unique_file_id()
+                unique_id = self.get_unique_file_id()
+            self.uuid_main_table[unique_id] = table_name
             export_file = object_type_dir / (f"{table_name}-{unique_id}.orc")
-            export_obj = self.exit_stack.enter_context(export_file.open("wb"))
-            self.writers[table_name] = self.exit_stack.enter_context(
-                Writer(
-                    export_obj,
-                    EXPORT_SCHEMA[table_name],
-                    compression=CompressionKind.ZSTD,
-                    converters={
-                        TypeKind.TIMESTAMP: cast(
-                            Type[ORCConverter], SWHTimestampConverter
-                        )
-                    },
-                )
+            export_obj = export_file.open("wb")
+            self.writer_files[table_name] = export_obj
+            self.writers[table_name] = Writer(
+                export_obj,
+                EXPORT_SCHEMA[table_name],
+                compression=CompressionKind.ZSTD,
+                converters={
+                    TypeKind.TIMESTAMP: cast(Type[ORCConverter], SWHTimestampConverter)
+                },
             )
+
             self.writers[table_name].set_user_metadata(
                 swh_object_type=table_name.encode(),
                 swh_uuid=unique_id.encode(),
@@ -258,10 +319,7 @@
     def process_directory(self, directory):
         directory_writer = self.get_writer_for("directory")
         directory_writer.write(
-            (
-                hash_to_hex_or_none(directory["id"]),
-                directory.get("raw_manifest"),
-            )
+            (hash_to_hex_or_none(directory["id"]), directory.get("raw_manifest"),)
         )
 
         directory_entry_writer = self.get_writer_for(
diff --git a/swh/dataset/relational.py b/swh/dataset/relational.py
--- a/swh/dataset/relational.py
+++ b/swh/dataset/relational.py
@@ -4,7 +4,7 @@
 # See top-level LICENSE file for more information
 
 # fmt: off
-TABLES = {
+MAIN_TABLES = {
     "origin": [
         ("url", "string"),
     ],
@@ -25,12 +25,7 @@
     "snapshot": [
         ("id", "string"),
     ],
-    "snapshot_branch": [
-        ("snapshot_id", "string"),
-        ("name", "binary"),
-        ("target", "string"),
-        ("target_type", "string"),
-    ],
+    # snapshot_branch is in RELATION_TABLES
     "release": [
         ("id", "string"),
         ("name", "binary"),
@@ -58,27 +53,13 @@
         ("type", "string"),
         ("raw_manifest", "binary"),
     ],
-    "revision_history": [
-        ("id", "string"),
-        ("parent_id", "string"),
-        ("parent_rank", "int"),
-    ],
-    "revision_extra_headers": [
-        ("id", "string"),
-        ("key", "binary"),
-        ("value", "binary"),
-    ],
+    # revision_history is in RELATION_TABLES
+    # revision_extra_headers is in RELATION_TABLES
     "directory": [
         ("id", "string"),
         ("raw_manifest", "binary"),
     ],
-    "directory_entry": [
-        ("directory_id", "string"),
-        ("name", "binary"),
-        ("type", "string"),
-        ("target", "string"),
-        ("perms", "int"),
-    ],
+    # directory_entry is in RELATION_TABLES
     "content": [
         ("sha1", "string"),
         ("sha1_git", "string"),
@@ -97,4 +78,33 @@
         ("reason", "string"),
     ],
 }
+
+RELATION_TABLES = {
+    "snapshot_branch": [
+        ("snapshot_id", "string"),
+        ("name", "binary"),
+        ("target", "string"),
+        ("target_type", "string"),
+    ],
+    "revision_history": [
+        ("id", "string"),
+        ("parent_id", "string"),
+        ("parent_rank", "int"),
+    ],
+    "revision_extra_headers": [
+        ("id", "string"),
+        ("key", "binary"),
+        ("value", "binary"),
+    ],
+    "directory_entry": [
+        ("directory_id", "string"),
+        ("name", "binary"),
+        ("type", "string"),
+        ("target", "string"),
+        ("perms", "int"),
+    ],
+}
+
+TABLES = MAIN_TABLES | RELATION_TABLES
+
 # fmt: on
diff --git a/swh/dataset/test/test_orc.py b/swh/dataset/test/test_orc.py
--- a/swh/dataset/test/test_orc.py
+++ b/swh/dataset/test/test_orc.py
@@ -1,9 +1,16 @@
+# Copyright (C) 2020-2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
 import collections
 from contextlib import contextmanager
+import math
 from pathlib import Path
 import tempfile
 
 import pyorc
+import pytest
 
 from swh.dataset.exporters.orc import (
     ORCExporter,
@@ -12,24 +19,35 @@
     hash_to_hex_or_none,
     swh_date_to_tuple,
 )
+from swh.dataset.relational import MAIN_TABLES, RELATION_TABLES
 from swh.model.tests.swh_model_data import TEST_OBJECTS
 
 
 @contextmanager
-def orc_export(messages, config=None):
-    with tempfile.TemporaryDirectory() as tmpname:
-        tmppath = Path(tmpname)
+def orc_tmpdir(tmpdir):
+    if tmpdir:
+        yield Path(tmpdir)
+    else:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            yield Path(tmpdir)
+
+
+@contextmanager
+def orc_export(messages, config=None, tmpdir=None):
+
+    with orc_tmpdir(tmpdir) as tmpdir:
         if config is None:
             config = {}
-        with ORCExporter(config, tmppath) as exporter:
+        with ORCExporter(config, tmpdir) as exporter:
             for object_type, objects in messages.items():
                 for obj in objects:
                     exporter.process_object(object_type, obj.to_dict())
-        yield tmppath
+        yield tmpdir
 
 
 def orc_load(rootdir):
     res = collections.defaultdict(list)
+    res["rootdir"] = rootdir
     for obj_type_dir in rootdir.iterdir():
         for orc_file in obj_type_dir.iterdir():
             with orc_file.open("rb") as orc_obj:
@@ -42,8 +60,8 @@
     return res
 
 
-def exporter(messages, config=None):
-    with orc_export(messages, config) as exportdir:
+def exporter(messages, config=None, tmpdir=None):
+    with orc_export(messages, config, tmpdir) as exportdir:
         return orc_load(exportdir)
 
 
@@ -204,3 +222,67 @@
         0,
         b"-0000",
     )
+
+
+# mapping of related tables for each main table (if any)
+RELATED = {
+    "snapshot": ["snapshot_branch"],
+    "revision": ["revision_history", "revision_extra_headers"],
+    "directory": ["directory_entry"],
+}
+
+
+@pytest.mark.parametrize(
+    "obj_type", MAIN_TABLES.keys(),
+)
+@pytest.mark.parametrize("max_rows", (None, 1, 2, 10000))
+def test_export_related_files(max_rows, obj_type, tmpdir):
+    config = {}
+    if max_rows is not None:
+        config["max_rows"] = {obj_type: max_rows}
+    exporter({obj_type: TEST_OBJECTS[obj_type]}, config=config, tmpdir=tmpdir)
+    # check the number of generated ORC files matches the max_rows setting
+    orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")]
+    if max_rows is None:
+        assert len(orcfiles) == 1
+    else:
+        assert len(orcfiles) == math.ceil(len(TEST_OBJECTS[obj_type]) / max_rows)
+    # check the number of related ORC files
+    for related in RELATED.get(obj_type, ()):
+        related_orcfiles = [
+            fname for fname in (tmpdir / obj_type).listdir(f"{related}-*.orc")
+        ]
+        assert len(related_orcfiles) == len(orcfiles)
+
+    # for each ORC file, check related files only reference objects in the
+    # corresponding main table
+    for orc_file in orcfiles:
+        with orc_file.open("rb") as orc_obj:
+            reader = pyorc.Reader(
+                orc_obj, converters={pyorc.TypeKind.TIMESTAMP: SWHTimestampConverter},
+            )
+            uuid = reader.user_metadata["swh_uuid"].decode()
+            assert orc_file.basename == f"{obj_type}-{uuid}.orc"
+            rows = list(reader)
+            obj_ids = [row[0] for row in rows]
+
+        # check the related tables
+        for related in RELATED.get(obj_type, ()):
+            orc_file = tmpdir / obj_type / f"{related}-{uuid}.orc"
+            with orc_file.open("rb") as orc_obj:
+                reader = pyorc.Reader(
+                    orc_obj,
+                    converters={pyorc.TypeKind.TIMESTAMP: SWHTimestampConverter},
+                )
+                assert reader.user_metadata["swh_uuid"].decode() == uuid
+                rows = list(reader)
+                # check rows in this file only reference objects in the main table
+                for row in rows:
+                    assert row[0] in obj_ids
+
+
+@pytest.mark.parametrize("table_name", RELATION_TABLES.keys())
+def test_export_invalid_max_rows(table_name):
+    config = {"max_rows": {table_name: 10}}
+    with pytest.raises(ValueError):
+        exporter({}, config=config)
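
For reference, a minimal usage sketch of the new max_rows option, modelled on how the tests above drive the exporter. The export path, the row limit, and the revision_objects iterable are illustrative assumptions, not values taken from the patch.

    from pathlib import Path

    from swh.dataset.exporters.orc import ORCExporter

    # Rotate the "revision" ORC file every 100_000 rows; related tables
    # (revision_history, revision_extra_headers) share the same uuid and are
    # closed together with the main table, so each revision-<uuid>.orc keeps
    # matching related files.
    config = {"max_rows": {"revision": 100_000}}

    with ORCExporter(config, Path("/srv/swh/export")) as exporter:
        for revision in revision_objects:  # illustrative iterable of model objects
            exporter.process_object("revision", revision.to_dict())

Only tables listed in MAIN_TABLES may appear under max_rows; limiting a RELATION_TABLES entry raises ValueError, as exercised by test_export_invalid_max_rows.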