diff --git a/swh/journal/tests/test_inmemory.py b/swh/journal/tests/test_inmemory.py --- a/swh/journal/tests/test_inmemory.py +++ b/swh/journal/tests/test_inmemory.py @@ -11,34 +11,22 @@ value_sanitizer=model_object_dict_sanitizer ) expected = [] + priv_expected = [] for object_type, objects in TEST_OBJECTS.items(): writer.write_additions(object_type, objects) for object in objects: - expected.append((object_type, object)) + if object.anonymize(): + expected.append((object_type, object.anonymize())) + priv_expected.append((object_type, object)) + else: + expected.append((object_type, object)) - assert list(writer.privileged_objects) == [] + assert set(priv_expected) == set(writer.privileged_objects) assert set(expected) == set(writer.objects) -def test_write_additions_with_privileged_test_objects(): - writer = InMemoryJournalWriter[BaseModel]( - value_sanitizer=model_object_dict_sanitizer - ) - - expected = [] - - for object_type, objects in TEST_OBJECTS.items(): - writer.write_additions(object_type, objects, True) - - for object in objects: - expected.append((object_type, object)) - - assert list(writer.objects) == [] - assert set(expected) == set(writer.privileged_objects) - - def test_write_addition_errors_without_unique_key(): writer = InMemoryJournalWriter[BaseModel]( value_sanitizer=model_object_dict_sanitizer diff --git a/swh/journal/writer/__init__.py b/swh/journal/writer/__init__.py --- a/swh/journal/writer/__init__.py +++ b/swh/journal/writer/__init__.py @@ -3,25 +3,12 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Any, Dict, Optional, TypeVar +import os +import sys +from typing import Any, Dict, Type import warnings -from typing_extensions import Protocol - -from swh.model.model import KeyType - -TSelf = TypeVar("TSelf") - - -class ValueProtocol(Protocol): - def anonymize(self: TSelf) -> Optional[TSelf]: - ... - - def unique_key(self) -> KeyType: - ... - - def to_dict(self) -> Dict[str, Any]: - ... +from .interface import JournalWriterInterface def model_object_dict_sanitizer( @@ -33,7 +20,8 @@ return object_dict -def get_journal_writer(cls, **kwargs): +def get_journal_writer(cls, **kwargs) -> JournalWriterInterface: + if "args" in kwargs: warnings.warn( 'Explicit "args" key is deprecated, use keys directly instead.', @@ -48,12 +36,20 @@ "cls = 'inmemory' is deprecated, use 'memory' instead", DeprecationWarning ) cls = "memory" + + JournalWriter: Type[JournalWriterInterface] if cls == "memory": - from .inmemory import InMemoryJournalWriter as JournalWriter + from .inmemory import InMemoryJournalWriter + + JournalWriter = InMemoryJournalWriter elif cls == "kafka": - from .kafka import KafkaJournalWriter as JournalWriter + from .kafka import KafkaJournalWriter + + JournalWriter = KafkaJournalWriter elif cls == "stream": - from .stream import StreamJournalWriter as JournalWriter + from .stream import StreamJournalWriter + + JournalWriter = StreamJournalWriter assert "output_stream" in kwargs else: diff --git a/swh/journal/writer/inmemory.py b/swh/journal/writer/inmemory.py --- a/swh/journal/writer/inmemory.py +++ b/swh/journal/writer/inmemory.py @@ -5,16 +5,13 @@ import logging from multiprocessing import Manager -from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar +from typing import Any, Callable, Dict, Generic, Iterable, List, Tuple -from . import ValueProtocol +from .interface import TValue logger = logging.getLogger(__name__) -TValue = TypeVar("TValue", bound=ValueProtocol) - - class InMemoryJournalWriter(Generic[TValue]): objects: List[Tuple[str, TValue]] privileged_objects: List[Tuple[str, TValue]] @@ -27,19 +24,18 @@ self.objects = self.manager.list() self.privileged_objects = self.manager.list() - def write_addition( - self, object_type: str, object_: TValue, privileged: bool = False - ) -> None: + def write_addition(self, object_type: str, object_: TValue) -> None: object_.unique_key() # Check this does not error, to mimic the kafka writer - if privileged: + anon_object_ = object_.anonymize() + if anon_object_ is not None: self.privileged_objects.append((object_type, object_)) + self.objects.append((object_type, anon_object_)) else: self.objects.append((object_type, object_)) - write_update = write_addition - - def write_additions( - self, object_type: str, objects: List[TValue], privileged: bool = False - ) -> None: + def write_additions(self, object_type: str, objects: Iterable[TValue]) -> None: for object_ in objects: - self.write_addition(object_type, object_, privileged) + self.write_addition(object_type, object_) + + def flush(self) -> None: + pass diff --git a/swh/journal/writer/interface.py b/swh/journal/writer/interface.py new file mode 100644 --- /dev/null +++ b/swh/journal/writer/interface.py @@ -0,0 +1,41 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from typing import Any, Dict, Iterable, Optional, TypeVar + +from typing_extensions import Protocol, runtime_checkable + +from swh.model.model import KeyType + +TSelf = TypeVar("TSelf") + + +class ValueProtocol(Protocol): + def anonymize(self: TSelf) -> Optional[TSelf]: + ... + + def unique_key(self) -> KeyType: + ... + + def to_dict(self) -> Dict[str, Any]: + ... + + +TValue = TypeVar("TValue", bound=ValueProtocol) + + +@runtime_checkable +class JournalWriterInterface(Protocol): + def write_addition(self, object_type: str, object_: TValue) -> None: + """Add a SWH object of type object_type in the journal.""" + ... + + def write_additions(self, object_type: str, objects: Iterable[TValue]) -> None: + """Add a list of SWH objects of type object_type in the journal.""" + ... + + def flush(self) -> None: + """Flush the pending object additions in the backend, if any.""" + ... diff --git a/swh/journal/writer/kafka.py b/swh/journal/writer/kafka.py --- a/swh/journal/writer/kafka.py +++ b/swh/journal/writer/kafka.py @@ -15,14 +15,13 @@ NamedTuple, Optional, Type, - TypeVar, ) from confluent_kafka import KafkaException, Producer from swh.journal.serializers import KeyType, key_to_kafka, pprint_key, value_to_kafka -from . import ValueProtocol +from .interface import TValue logger = logging.getLogger(__name__) @@ -65,9 +64,6 @@ return f"KafkaDeliveryError({self.message}, [{self.pretty_failures()}])" -TValue = TypeVar("TValue", bound=ValueProtocol) - - class KafkaJournalWriter(Generic[TValue]): """This class is used to write serialized versions of value objects to a series of Kafka topics. The type parameter `TValue`, which must implement the @@ -222,7 +218,7 @@ return KafkaDeliveryError(message, ret) - def flush(self): + def flush(self) -> None: start = time.monotonic() self.producer.flush(self.flush_timeout) @@ -265,8 +261,6 @@ self._write_addition(object_type, object_) self.flush() - write_update = write_addition - def write_additions(self, object_type: str, objects: Iterable[TValue]) -> None: """Write a set of objects to the journal""" for object_ in objects: diff --git a/swh/journal/writer/stream.py b/swh/journal/writer/stream.py --- a/swh/journal/writer/stream.py +++ b/swh/journal/writer/stream.py @@ -4,18 +4,15 @@ # See top-level LICENSE file for more information import logging -from typing import Any, BinaryIO, Callable, Dict, Generic, List, TypeVar +from typing import Any, BinaryIO, Callable, Dict, Generic, Iterable from swh.journal.serializers import value_to_kafka -from . import ValueProtocol +from .interface import TValue logger = logging.getLogger(__name__) -TValue = TypeVar("TValue", bound=ValueProtocol) - - class StreamJournalWriter(Generic[TValue]): """A simple JournalWriter which serializes objects in a stream @@ -31,17 +28,14 @@ self.output = output_stream self.value_sanitizer = value_sanitizer - def write_addition( - self, object_type: str, object_: TValue, privileged: bool = False - ) -> None: + def write_addition(self, object_type: str, object_: TValue) -> None: object_.unique_key() # Check this does not error, to mimic the kafka writer dict_ = self.value_sanitizer(object_type, object_.to_dict()) self.output.write(value_to_kafka((object_type, dict_))) - write_update = write_addition - - def write_additions( - self, object_type: str, objects: List[TValue], privileged: bool = False - ) -> None: + def write_additions(self, object_type: str, objects: Iterable[TValue]) -> None: for object_ in objects: - self.write_addition(object_type, object_, privileged) + self.write_addition(object_type, object_) + + def flush(self) -> None: + self.output.flush()