diff --git a/swh/journal/tests/test_kafka_writer.py b/swh/journal/tests/test_kafka_writer.py --- a/swh/journal/tests/test_kafka_writer.py +++ b/swh/journal/tests/test_kafka_writer.py @@ -10,6 +10,7 @@ from swh.journal.pytest_plugin import assert_all_objects_consumed, consume_messages from swh.journal.tests.journal_data import TEST_OBJECTS +from swh.journal.writer import model_object_dict_sanitizer from swh.journal.writer.kafka import KafkaDeliveryError, KafkaJournalWriter from swh.model.model import Directory, Release, Revision @@ -24,6 +25,7 @@ brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, + value_sanitizer=model_object_dict_sanitizer, anonymize=False, ) @@ -64,6 +66,7 @@ brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, + value_sanitizer=model_object_dict_sanitizer, anonymize=True, ) @@ -117,7 +120,10 @@ kafka_prefix += ".swh.journal.objects" writer = KafkaJournalWriterFailDelivery( - brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, + brokers=[kafka_server], + client_id="kafka_writer", + prefix=kafka_prefix, + value_sanitizer=model_object_dict_sanitizer, ) empty_dir = Directory(entries=()) @@ -148,6 +154,7 @@ brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, + value_sanitizer=model_object_dict_sanitizer, flush_timeout=1, producer_class=MockProducer, ) diff --git a/swh/journal/writer/__init__.py b/swh/journal/writer/__init__.py --- a/swh/journal/writer/__init__.py +++ b/swh/journal/writer/__init__.py @@ -3,9 +3,20 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Any, Dict import warnings +def model_object_dict_sanitizer( + object_type: str, object_dict: Dict[str, Any] +) -> Dict[str, Any]: + """Return a shallow copy of *object_dict*, without its ``data`` field when *object_type* is ``"content"``.""" + object_dict = object_dict.copy() + if object_type == "content": + object_dict.pop("data", None) + return object_dict + + def get_journal_writer(cls, **kwargs): if "args" in kwargs: warnings.warn(
@@ -14,6 +25,8 @@ ) kwargs = kwargs["args"] + kwargs.setdefault("value_sanitizer", model_object_dict_sanitizer) + if cls == "inmemory": # FIXME: Remove inmemory in due time warnings.warn( "cls = 'inmemory' is deprecated, use 'memory' instead", DeprecationWarning diff --git a/swh/journal/writer/inmemory.py b/swh/journal/writer/inmemory.py --- a/swh/journal/writer/inmemory.py +++ b/swh/journal/writer/inmemory.py @@ -5,7 +5,7 @@ import logging from multiprocessing import Manager -from typing import List +from typing import Any, List, Tuple from swh.journal.serializers import ModelObject from swh.model.model import BaseModel @@ -14,7 +14,10 @@ class InMemoryJournalWriter: - def __init__(self): + objects: List[Tuple[str, ModelObject]] + privileged_objects: List[Tuple[str, ModelObject]] + + def __init__(self, value_sanitizer: Any = None): # Share the list of objects across processes, for RemoteAPI tests. self.manager = Manager() self.objects = self.manager.list() diff --git a/swh/journal/writer/kafka.py b/swh/journal/writer/kafka.py --- a/swh/journal/writer/kafka.py +++ b/swh/journal/writer/kafka.py @@ -5,7 +5,7 @@ import logging import time -from typing import Dict, Iterable, List, NamedTuple, Optional, Type +from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Type from confluent_kafka import KafkaException, Producer @@ -95,6 +95,7 @@ brokers: Iterable[str], prefix: str, client_id: str, + value_sanitizer: Callable[[str, Dict[str, Any]], Dict[str, Any]], producer_config: Optional[Dict] = None, flush_timeout: float = 120, producer_class: Type[Producer] = Producer, @@ -134,6 +135,8 @@ # List of (object_type, key, error_msg, error_name) for failed deliveries self.delivery_failures: List[DeliveryFailureInfo] = [] + self.value_sanitizer = value_sanitizer + def _error_cb(self, error): if error.fatal(): raise KafkaException(error) @@ -195,14 +198,6 @@ elif self.delivery_failures: raise self.delivery_error("Failed deliveries after flush()") - def 
_sanitize_object( - self, object_type: str, object_: ModelObject - ) -> Dict[str, str]: - dict_ = object_.to_dict() - if object_type == "content": - dict_.pop("data", None) - return dict_ - def _write_addition(self, object_type: str, object_: ModelObject) -> None: """Write a single object to the journal""" key = object_.unique_key() @@ -213,13 +208,13 @@ # if the object is anonymizable, send the non-anonymized version in the # privileged channel topic = f"{self._prefix_privileged}.{object_type}" - dict_ = self._sanitize_object(object_type, object_) + dict_ = self.value_sanitizer(object_type, object_.to_dict()) logger.debug("topic: %s, key: %s, value: %s", topic, key, dict_) self.send(topic, key=key, value=dict_) object_ = anon_object_ topic = f"{self._prefix}.{object_type}" - dict_ = self._sanitize_object(object_type, object_) + dict_ = self.value_sanitizer(object_type, object_.to_dict()) logger.debug("topic: %s, key: %s, value: %s", topic, key, dict_) self.send(topic, key=key, value=dict_)