def kafka_stream_to_value(file_like: BinaryIO) -> msgpack.Unpacker:
    """Build a streaming msgpack deserializer reading from ``file_like``.

    The unpacker is configured with this module's ``decode_types_bw`` as
    object hook and ``msgpack_ext_hook`` as extension hook. Callers iterate
    over the returned object to obtain the decoded entries one after the
    other.

    Args:
        file_like: binary file-like object holding the serialized entries.

    Returns:
        A ``msgpack.Unpacker`` bound to ``file_like``.
    """
    unpacker_options = {
        "raw": False,
        "object_hook": decode_types_bw,
        "ext_hook": msgpack_ext_hook,
        "strict_map_key": False,
        "use_list": False,
        "timestamp": 3,  # convert Timestamp in datetime objects (tz UTC)
    }
    return msgpack.Unpacker(file_like, **unpacker_options)
def test_stream_journal_writer_filename(tmp_path):
    """Stream writer given a file name: written objects can be read back.

    ``output_stream`` is passed as a path string, which ``get_journal_writer``
    opens in ``"wb"`` mode. ``fill_writer`` flushes the writer before
    returning, so the file can be reopened for reading right after.
    """
    out_fname = str(tmp_path / "journal.msgpack")

    writer = get_journal_writer(
        cls="stream",
        value_sanitizer=model_object_dict_sanitizer,
        output_stream=out_fname,
    )
    expected = fill_writer(writer)

    with open(out_fname, "rb") as outs:
        decoded = list(kafka_stream_to_value(outs))

    # Compare the counts first: relying on the loop variable after the loop
    # would raise an unbound-name error (instead of a readable assertion
    # failure) when nothing at all was decoded.
    assert len(decoded) == len(expected)
    for objtype, objd in decoded:
        assert (objtype, objd) in expected


def test_stream_journal_writer_stdout(capfdbinary):
    """Stream writer given ``"-"``: objects are written to standard output.

    ``get_journal_writer`` maps ``"-"`` to the process' stdout file
    descriptor, hence the fd-level binary capture fixture.
    """
    writer = get_journal_writer(
        cls="stream",
        value_sanitizer=model_object_dict_sanitizer,
        output_stream="-",
    )
    expected = fill_writer(writer)

    captured = capfdbinary.readouterr()
    assert captured.err == b""

    decoded = list(kafka_stream_to_value(io.BytesIO(captured.out)))

    # Same reasoning as above: an empty capture must fail on the count
    # assertion, not on an unbound loop variable.
    assert len(decoded) == len(expected)
    for objtype, objd in decoded:
        assert (objtype, objd) in expected
TSelf) -> Optional[TSelf]: - ... - - def unique_key(self) -> KeyType: - ... - - def to_dict(self) -> Dict[str, Any]: - ... +from .interface import JournalWriterInterface def model_object_dict_sanitizer( @@ -33,7 +20,8 @@ return object_dict -def get_journal_writer(cls, **kwargs): +def get_journal_writer(cls, **kwargs) -> JournalWriterInterface: + if "args" in kwargs: warnings.warn( 'Explicit "args" key is deprecated, use keys directly instead.', @@ -48,14 +36,30 @@ "cls = 'inmemory' is deprecated, use 'memory' instead", DeprecationWarning ) cls = "memory" + + JournalWriter: Type[JournalWriterInterface] if cls == "memory": - from .inmemory import InMemoryJournalWriter as JournalWriter + from .inmemory import InMemoryJournalWriter + + JournalWriter = InMemoryJournalWriter elif cls == "kafka": - from .kafka import KafkaJournalWriter as JournalWriter + from .kafka import KafkaJournalWriter + + JournalWriter = KafkaJournalWriter elif cls == "stream": - from .stream import StreamJournalWriter as JournalWriter + from .stream import StreamJournalWriter + + JournalWriter = StreamJournalWriter assert "output_stream" in kwargs + outstream: BinaryIO + if kwargs["output_stream"] in ("-", b"-"): + outstream = os.fdopen(sys.stdout.fileno(), "wb", closefd=False) + elif isinstance(kwargs["output_stream"], (str, bytes)): + outstream = open(kwargs["output_stream"], "wb") + else: + outstream = kwargs["output_stream"] + kwargs["output_stream"] = outstream else: raise ValueError("Unknown journal writer class `%s`" % cls) diff --git a/swh/journal/writer/inmemory.py b/swh/journal/writer/inmemory.py --- a/swh/journal/writer/inmemory.py +++ b/swh/journal/writer/inmemory.py @@ -5,16 +5,13 @@ import logging from multiprocessing import Manager -from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar +from typing import Any, Callable, Dict, Generic, Iterable, List, Tuple -from . 
TSelf = TypeVar("TSelf")


class ValueProtocol(Protocol):
    """Structural type of the values handled by the journal writers.

    Any object exposing the three methods below matches; no inheritance is
    required.
    """

    def anonymize(self: TSelf) -> Optional[TSelf]:
        """Return an object of the same type, or None.

        NOTE(review): presumably None means there is no anonymized variant of
        this object -- confirm against the implementers.
        """
        ...

    def unique_key(self) -> KeyType:
        """Return the key identifying this object."""
        ...

    def to_dict(self) -> Dict[str, Any]:
        """Return this object as a dictionary with string keys."""
        ...


# Type variable for writer signatures; restricted to ValueProtocol matches.
TValue = TypeVar("TValue", bound=ValueProtocol)


@runtime_checkable
class JournalWriterInterface(Protocol):
    """Methods every journal writer backend has to provide."""

    def write_addition(
        self, object_type: str, object_: TValue, privileged: bool = False
    ) -> None:
        """Add a SWH object of type object_type in the journal."""
        ...

    def write_additions(
        self, object_type: str, objects: Iterable[TValue], privileged: bool = False
    ) -> None:
        """Add an iterable of SWH objects of type object_type in the journal."""
        ...

    def flush(self) -> None:
        """Flush the pending object additions in the backend, if any."""
        ...
def write_addition(
    self, object_type: str, object_: TValue, privileged: bool = False
) -> None:
    """Write a single object to the journal, then flush.

    Args:
        object_type: type name of the object, forwarded to
            ``_write_addition``.
        object_: the object to write, forwarded to ``_write_addition``.
        privileged: accepted so that the signature matches
            ``JournalWriterInterface.write_addition``; this method does not
            read it.
    """
    self._write_addition(object_type, object_)
    # Every single write is followed by a flush.
    self.flush()


write_update = write_addition
def write_additions(
    self, object_type: str, objects: Iterable[TValue], privileged: bool = False
) -> None:
    """Write each object of ``objects`` through ``write_addition``.

    Args:
        object_type: type name passed along with every object.
        objects: iterable of objects to write; consumed once, in order.
        privileged: passed unchanged to ``write_addition`` for each object.
    """
    for item in objects:
        self.write_addition(object_type, item, privileged)


def flush(self) -> None:
    """Flush ``self.output``.

    NOTE(review): ``self.output`` is presumably the stream received as
    ``output_stream`` -- confirm in ``__init__``.
    """
    self.output.flush()