diff --git a/swh/journal/tests/test_stream.py b/swh/journal/tests/test_stream.py
new file mode 100644
index 0000000..c9bfc90
--- /dev/null
+++ b/swh/journal/tests/test_stream.py
@@ -0,0 +1,50 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import io
+
+import msgpack
+
+from swh.journal.serializers import msgpack_ext_hook
+from swh.journal.writer import get_journal_writer, model_object_dict_sanitizer
+from swh.model.tests.swh_model_data import TEST_OBJECTS
+
+
+def test_write_additions_with_test_objects():
+    """Objects written by the stream writer can be read back with msgpack."""
+    outs = io.BytesIO()
+
+    writer = get_journal_writer(
+        cls="stream", value_sanitizer=model_object_dict_sanitizer, output_stream=outs,
+    )
+    expected = []
+
+    for object_type, objects in TEST_OBJECTS.items():
+        writer.write_additions(object_type, objects)
+
+        for obj in objects:
+            objd = obj.to_dict()
+            if object_type == "content":
+                # the sanitizer drops the "data" key of contents before writing
+                objd.pop("data")
+
+            expected.append((object_type, objd))
+
+    outs.seek(0, 0)
+    unpacker = msgpack.Unpacker(
+        outs,
+        raw=False,
+        ext_hook=msgpack_ext_hook,
+        strict_map_key=False,
+        use_list=False,
+        timestamp=3,  # convert Timestamp in datetime objects (tz UTC)
+    )
+
+    # Materialize first so that an empty stream fails the length check
+    # instead of leaving a loop variable unbound.
+    received = list(unpacker)
+    assert len(received) == len(expected)
+    for objtype, objd in received:
+        assert (objtype, objd) in expected
diff --git a/swh/journal/writer/__init__.py b/swh/journal/writer/__init__.py
index 92879e6..662fa80 100644
--- a/swh/journal/writer/__init__.py
+++ b/swh/journal/writer/__init__.py
@@ -1,58 +1,64 @@
-# Copyright (C) 2019 The Software Heritage developers
+# Copyright (C) 2019-2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from typing import Any, Dict, Optional, TypeVar
 import warnings
 
 from typing_extensions import Protocol
 
 from swh.model.model import KeyType
 
 TSelf = TypeVar("TSelf")
 
 
 class ValueProtocol(Protocol):
     def anonymize(self: TSelf) -> Optional[TSelf]:
         ...
 
     def unique_key(self) -> KeyType:
         ...
 
     def to_dict(self) -> Dict[str, Any]:
         ...
 
 
 def model_object_dict_sanitizer(
     object_type: str, object_dict: Dict[str, Any]
 ) -> Dict[str, str]:
     object_dict = object_dict.copy()
     if object_type == "content":
         object_dict.pop("data", None)
     return object_dict
 
 
 def get_journal_writer(cls, **kwargs):
     if "args" in kwargs:
         warnings.warn(
             'Explicit "args" key is deprecated, use keys directly instead.',
             DeprecationWarning,
         )
         kwargs = kwargs["args"]
 
     kwargs.setdefault("value_sanitizer", model_object_dict_sanitizer)
 
     if cls == "inmemory":  # FIXME: Remove inmemory in due time
         warnings.warn(
             "cls = 'inmemory' is deprecated, use 'memory' instead", DeprecationWarning
         )
         cls = "memory"
     if cls == "memory":
         from .inmemory import InMemoryJournalWriter as JournalWriter
     elif cls == "kafka":
         from .kafka import KafkaJournalWriter as JournalWriter
+    elif cls == "stream":
+        from .stream import StreamJournalWriter as JournalWriter
+
+        # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
+        if "output_stream" not in kwargs:
+            raise ValueError("The `stream` journal writer needs an `output_stream`")
     else:
         raise ValueError("Unknown journal writer class `%s`" % cls)
 
     return JournalWriter(**kwargs)
diff --git a/swh/journal/writer/stream.py b/swh/journal/writer/stream.py
new file mode 100644
index 0000000..202e13c
--- /dev/null
+++ b/swh/journal/writer/stream.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import logging
+from typing import Any, BinaryIO, Callable, Dict, Generic, Iterable, TypeVar
+
+from swh.journal.serializers import value_to_kafka
+
+from . import ValueProtocol
+
+logger = logging.getLogger(__name__)
+
+
+TValue = TypeVar("TValue", bound=ValueProtocol)
+
+
+class StreamJournalWriter(Generic[TValue]):
+    """A simple JournalWriter which serializes objects in a stream
+
+    Might be used to serialize a storage in a file to generate a test dataset.
+    """
+
+    def __init__(
+        self,
+        output_stream: BinaryIO,
+        value_sanitizer: Callable[[str, Dict[str, Any]], Dict[str, Any]],
+    ):
+        """
+        Args:
+            output_stream: binary stream the serialized objects are written to
+            value_sanitizer: called with the object type and the object's dict
+                form; its return value is what gets serialized
+        """
+        self.output = output_stream
+        self.value_sanitizer = value_sanitizer
+
+    def write_addition(
+        self, object_type: str, object_: TValue, privileged: bool = False
+    ) -> None:
+        """Serialize one ``(object_type, sanitized dict)`` pair to the stream.
+
+        ``privileged`` is not used by this writer.
+        """
+        object_.unique_key()  # Check this does not error, to mimic the kafka writer
+        dict_ = self.value_sanitizer(object_type, object_.to_dict())
+        self.output.write(value_to_kafka((object_type, dict_)))
+
+    write_update = write_addition
+
+    def write_additions(
+        self, object_type: str, objects: Iterable[TValue], privileged: bool = False
+    ) -> None:
+        """Serialize every object of ``objects``, in order, to the stream."""
+        for object_ in objects:
+            self.write_addition(object_type, object_, privileged)