diff --git a/swh/journal/tests/test_stream.py b/swh/journal/tests/test_stream.py index 4991a4f..fb74505 100644 --- a/swh/journal/tests/test_stream.py +++ b/swh/journal/tests/test_stream.py @@ -1,43 +1,79 @@ # Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import io from typing import Dict, List, Tuple from swh.journal.serializers import kafka_stream_to_value from swh.journal.writer import get_journal_writer, model_object_dict_sanitizer from swh.journal.writer.interface import JournalWriterInterface from swh.model.tests.swh_model_data import TEST_OBJECTS def fill_writer(writer: JournalWriterInterface) -> List[Tuple[str, Dict]]: expected = [] for object_type, objects in TEST_OBJECTS.items(): writer.write_additions(object_type, objects) for object in objects: objd = object.to_dict() if object_type == "content": objd.pop("data") expected.append((object_type, objd)) + writer.flush() return expected def test_stream_journal_writer_stream(): outs = io.BytesIO() writer = get_journal_writer( cls="stream", value_sanitizer=model_object_dict_sanitizer, output_stream=outs, ) expected = fill_writer(writer) outs.seek(0, 0) unpacker = kafka_stream_to_value(outs) for i, (objtype, objd) in enumerate(unpacker, start=1): assert (objtype, objd) in expected assert len(expected) == i + + +def test_stream_journal_writer_filename(tmp_path): + out_fname = str(tmp_path / "journal.msgpack") + + writer = get_journal_writer( + cls="stream", + value_sanitizer=model_object_dict_sanitizer, + output_stream=out_fname, + ) + expected = fill_writer(writer) + + with open(out_fname, "rb") as outs: + unpacker = kafka_stream_to_value(outs) + for i, (objtype, objd) in enumerate(unpacker, start=1): + assert (objtype, objd) in expected + assert len(expected) == i + + +def test_stream_journal_writer_stdout(capfdbinary): + writer = get_journal_writer( + cls="stream", + value_sanitizer=model_object_dict_sanitizer, + output_stream="-", + ) + expected = fill_writer(writer) + + captured = capfdbinary.readouterr() + assert captured.err == b"" + outs = io.BytesIO(captured.out) + + unpacker = kafka_stream_to_value(outs) + for i, (objtype, objd) in enumerate(unpacker, start=1): + assert (objtype, objd) in expected + assert len(expected) == i diff --git a/swh/journal/writer/__init__.py b/swh/journal/writer/__init__.py index 13d44f6..bf82eab 100644 --- a/swh/journal/writer/__init__.py +++ b/swh/journal/writer/__init__.py @@ -1,56 +1,66 @@ # Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Any, Dict, Type +import os +import sys +from typing import Any, BinaryIO, Dict, Type import warnings from .interface import JournalWriterInterface def model_object_dict_sanitizer( object_type: str, object_dict: Dict[str, Any] ) -> Dict[str, str]: object_dict = object_dict.copy() if object_type == "content": object_dict.pop("data", None) return object_dict def get_journal_writer(cls, **kwargs) -> JournalWriterInterface: if "args" in kwargs: warnings.warn( 'Explicit "args" key is deprecated, use keys directly instead.', DeprecationWarning, ) kwargs = kwargs["args"] kwargs.setdefault("value_sanitizer", model_object_dict_sanitizer) if cls == "inmemory": # FIXME: Remove inmemory in due time warnings.warn( "cls = 'inmemory' is deprecated, use 'memory' instead", DeprecationWarning ) cls = "memory" JournalWriter: Type[JournalWriterInterface] if cls == "memory": from .inmemory import InMemoryJournalWriter JournalWriter = InMemoryJournalWriter elif cls == "kafka": from .kafka import KafkaJournalWriter JournalWriter = KafkaJournalWriter elif cls == "stream": from .stream import StreamJournalWriter JournalWriter = StreamJournalWriter assert "output_stream" in kwargs + outstream: BinaryIO + if kwargs["output_stream"] in ("-", b"-"): + outstream = os.fdopen(sys.stdout.fileno(), "wb", closefd=False) + elif isinstance(kwargs["output_stream"], (str, bytes)): + outstream = open(kwargs["output_stream"], "wb") + else: + outstream = kwargs["output_stream"] + kwargs["output_stream"] = outstream else: raise ValueError("Unknown journal writer class `%s`" % cls) return JournalWriter(**kwargs)