diff --git a/swh/journal/tests/test_stream.py b/swh/journal/tests/test_stream.py
new file mode 100644
--- /dev/null
+++ b/swh/journal/tests/test_stream.py
@@ -0,0 +1,51 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import io
+
+import msgpack
+
+from swh.journal.serializers import msgpack_ext_hook
+from swh.journal.writer import model_object_dict_sanitizer
+from swh.journal.writer.stream import StreamJournalWriter
+from swh.model.model import BaseModel
+from swh.model.tests.swh_model_data import TEST_OBJECTS
+
+
+def test_write_additions_with_test_objects():
+    """Objects written by StreamJournalWriter can be read back with msgpack."""
+    outs = io.BytesIO()
+    writer = StreamJournalWriter[BaseModel](
+        value_sanitizer=model_object_dict_sanitizer, output_stream=outs,
+    )
+    expected = []
+
+    for object_type, objects in TEST_OBJECTS.items():
+        writer.write_additions(object_type, objects)
+
+        for obj in objects:
+            objd = obj.to_dict()
+            if object_type == "content":
+                # the sanitizer is expected to drop the raw content data
+                objd.pop("data")
+
+            expected.append((object_type, objd))
+
+    outs.seek(0, 0)
+    unpacker = msgpack.Unpacker(
+        outs,
+        raw=False,
+        ext_hook=msgpack_ext_hook,
+        strict_map_key=False,
+        use_list=False,
+        timestamp=3,  # convert Timestamp in datetime objects (tz UTC)
+    )
+
+    # Every record read back must be one of the written objects, and the
+    # number of records must match the number of objects written.
+    written = list(unpacker)
+    for objtype, objd in written:
+        assert (objtype, objd) in expected
+    assert len(written) == len(expected)
diff --git a/swh/journal/writer/__init__.py b/swh/journal/writer/__init__.py
--- a/swh/journal/writer/__init__.py
+++ b/swh/journal/writer/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2019 The Software Heritage developers
+# Copyright (C) 2019-2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more
information
@@ -52,6 +52,11 @@
         from .inmemory import InMemoryJournalWriter as JournalWriter
     elif cls == "kafka":
         from .kafka import KafkaJournalWriter as JournalWriter
+    elif cls == "stream":
+        from .stream import StreamJournalWriter as JournalWriter
+
+        if "output_stream" not in kwargs:
+            raise ValueError("The stream journal writer requires `output_stream`")
     else:
         raise ValueError("Unknown journal writer class `%s`" % cls)
 
diff --git a/swh/journal/writer/stream.py b/swh/journal/writer/stream.py
new file mode 100644
--- /dev/null
+++ b/swh/journal/writer/stream.py
@@ -0,0 +1,59 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import logging
+from typing import Any, Generic, Iterable, TypeVar
+
+from swh.journal.serializers import value_to_kafka
+
+from . import ValueProtocol
+
+logger = logging.getLogger(__name__)
+
+
+TValue = TypeVar("TValue", bound=ValueProtocol)
+
+
+class StreamJournalWriter(Generic[TValue]):
+    """A simple JournalWriter that serializes objects into a stream.
+
+    Each object is converted with ``to_dict()``, passed through
+    ``value_sanitizer``, and the ``(object_type, dict)`` pair is encoded
+    with :func:`swh.journal.serializers.value_to_kafka` and written to
+    ``output_stream``; records are written back to back.
+
+    Args:
+        output_stream: writable file-like object receiving the encoded
+            records (e.g. :class:`io.BytesIO`)
+        value_sanitizer: callable ``(object_type, object_dict) -> dict``
+            applied to each object dict before encoding
+    """
+
+    def __init__(self, output_stream, value_sanitizer: Any):
+        # The stream is used as-is: it is never flushed nor closed by this
+        # writer, which is left to the caller.
+        self.output = output_stream
+        self.value_sanitizer = value_sanitizer
+
+    def write_addition(
+        self, object_type: str, object_: TValue, privileged: bool = False
+    ) -> None:
+        """Encode ``object_`` and write it to the output stream.
+
+        ``privileged`` is accepted for interface compatibility with the other
+        journal writers and is ignored.
+        """
+        object_.unique_key()  # Check this does not error, to mimic the kafka writer
+        dict_ = self.value_sanitizer(object_type, object_.to_dict())
+        self.output.write(value_to_kafka((object_type, dict_)))
+
+    write_update = write_addition
+
+    def write_additions(
+        self, object_type: str, objects: Iterable[TValue], privileged: bool = False
+    ) -> None:
+        """Write every object of ``objects``, in order."""
+        for object_ in objects:
+            self.write_addition(object_type, object_, privileged)