diff --git a/swh/journal/serializers.py b/swh/journal/serializers.py --- a/swh/journal/serializers.py +++ b/swh/journal/serializers.py @@ -111,3 +111,16 @@ strict_map_key=False, timestamp=3, # convert Timestamp in datetime objects (tz UTC) ) + + +def kafka_stream_to_value(stream) -> msgpack.Unpacker: + """Return a msgpack Unpacker yielding the values serialized in ``stream``""" + return msgpack.Unpacker( + stream, + raw=False, + object_hook=decode_types_bw, + ext_hook=msgpack_ext_hook, + strict_map_key=False, + use_list=False, + timestamp=3, # convert Timestamp in datetime objects (tz UTC) + ) diff --git a/swh/journal/tests/test_stream.py b/swh/journal/tests/test_stream.py --- a/swh/journal/tests/test_stream.py +++ b/swh/journal/tests/test_stream.py @@ -5,24 +5,13 @@ import io -import msgpack - -from swh.journal.serializers import msgpack_ext_hook +from swh.journal.serializers import kafka_stream_to_value from swh.journal.writer import get_journal_writer, model_object_dict_sanitizer from swh.model.tests.swh_model_data import TEST_OBJECTS -def test_write_additions_with_test_objects(): - outs = io.BytesIO() - - writer = get_journal_writer( - cls="stream", - value_sanitizer=model_object_dict_sanitizer, - output_stream=outs, - ) +def fill_writer(writer): expected = [] - - n = 0 for object_type, objects in TEST_OBJECTS.items(): writer.write_additions(object_type, objects) @@ -32,18 +21,57 @@ objd.pop("data") expected.append((object_type, objd)) - n += len(objects) + writer.flush() + return expected + + +def test_stream_journal_writer_stream(): + outs = io.BytesIO() + + writer = get_journal_writer( + cls="stream", + value_sanitizer=model_object_dict_sanitizer, + output_stream=outs, + ) + expected = fill_writer(writer) outs.seek(0, 0) - unpacker = msgpack.Unpacker( - outs, - raw=False, - ext_hook=msgpack_ext_hook, - strict_map_key=False, - use_list=False, - timestamp=3, # convert Timestamp in datetime objects (tz UTC) + unpacker = kafka_stream_to_value(outs) + for i, (objtype, objd) in
enumerate(unpacker, start=1): + assert (objtype, objd) in expected + assert len(expected) == i + + +def test_stream_journal_writer_filename(tmp_path): + out_fname = str(tmp_path / "journal.msgpack") + + writer = get_journal_writer( + cls="stream", + value_sanitizer=model_object_dict_sanitizer, + output_stream=out_fname, ) + expected = fill_writer(writer) + + with open(out_fname, "rb") as outs: + unpacker = kafka_stream_to_value(outs) + for i, (objtype, objd) in enumerate(unpacker, start=1): + assert (objtype, objd) in expected + assert len(expected) == i + + +def test_stream_journal_writer_stdout(capfdbinary): + writer = get_journal_writer( + cls="stream", + value_sanitizer=model_object_dict_sanitizer, + output_stream="-", + ) + expected = fill_writer(writer) + + captured = capfdbinary.readouterr() + assert captured.err == b"" + outs = io.BytesIO(captured.out) + unpacker = kafka_stream_to_value(outs) for i, (objtype, objd) in enumerate(unpacker, start=1): assert (objtype, objd) in expected assert len(expected) == i diff --git a/swh/journal/writer/__init__.py b/swh/journal/writer/__init__.py --- a/swh/journal/writer/__init__.py +++ b/swh/journal/writer/__init__.py @@ -3,7 +3,9 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Any, Dict, Optional, TypeVar +import os +import sys +from typing import Any, BinaryIO, Dict, Optional, TypeVar import warnings from typing_extensions import Protocol @@ -56,6 +58,16 @@ from .stream import StreamJournalWriter as JournalWriter assert "output_stream" in kwargs + outstream: BinaryIO + # "-" or b"-" means stdout; other str/bytes/path-like values are file paths + # NOTE(review): nothing here closes a file opened below; confirm ownership + if kwargs["output_stream"] in ("-", b"-"): + outstream = os.fdopen(sys.stdout.fileno(), "wb", closefd=False) + elif isinstance(kwargs["output_stream"], (str, bytes, os.PathLike)): + outstream = open(kwargs["output_stream"], "wb") + else: + outstream = kwargs["output_stream"] + kwargs["output_stream"] = outstream else: raise ValueError("Unknown journal writer class `%s`" % cls)
diff --git a/swh/journal/writer/stream.py b/swh/journal/writer/stream.py --- a/swh/journal/writer/stream.py +++ b/swh/journal/writer/stream.py @@ -45,3 +45,7 @@ ) -> None: for object_ in objects: self.write_addition(object_type, object_, privileged) + + def flush(self) -> None: + """Flush the underlying output stream (``self.output``).""" + self.output.flush()