diff --git a/swh/journal/serializers.py b/swh/journal/serializers.py index 8a6a798..07ecca7 100644 --- a/swh/journal/serializers.py +++ b/swh/journal/serializers.py @@ -1,96 +1,67 @@ # Copyright (C) 2016-2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Any, Union import msgpack from swh.core.api.serializers import msgpack_dumps, msgpack_loads -from swh.model.model import ( - Content, - Directory, - KeyType, - MetadataAuthority, - MetadataFetcher, - Origin, - OriginVisit, - OriginVisitStatus, - RawExtrinsicMetadata, - Release, - Revision, - SkippedContent, - Snapshot, -) - -ModelObject = Union[ - Content, - Directory, - MetadataAuthority, - MetadataFetcher, - Origin, - OriginVisit, - OriginVisitStatus, - RawExtrinsicMetadata, - Release, - Revision, - SkippedContent, - Snapshot, -] +from swh.model.model import KeyType def stringify_key_item(k: str, v: Union[str, bytes]) -> str: """Turn the item of a dict key into a string""" if isinstance(v, str): return v if k == "url": return v.decode("utf-8") return v.hex() def pprint_key(key: KeyType) -> str: """Pretty-print a kafka key""" if isinstance(key, dict): return "{%s}" % ", ".join( f"{k}: {stringify_key_item(k, v)}" for k, v in key.items() ) elif isinstance(key, bytes): return key.hex() else: return key def key_to_kafka(key: KeyType) -> bytes: """Serialize a key, possibly a dict, in a predictable way""" p = msgpack.Packer(use_bin_type=True) if isinstance(key, dict): return p.pack_map_pairs(sorted(key.items())) else: return p.pack(key) def kafka_to_key(kafka_key: bytes) -> KeyType: """Deserialize a key""" return msgpack.loads(kafka_key, raw=False) def value_to_kafka(value: Any) -> bytes: """Serialize some data for storage in kafka""" return msgpack_dumps(value) def kafka_to_value(kafka_value: bytes) -> Any: """Deserialize some data stored in kafka""" value = msgpack_loads(kafka_value) return ensure_tuples(value) def ensure_tuples(value: Any) -> Any: if isinstance(value, (tuple, list)): return tuple(map(ensure_tuples, value)) elif isinstance(value, dict): return dict(ensure_tuples(list(value.items()))) else: return value diff --git a/swh/journal/tests/journal_data.py b/swh/journal/tests/journal_data.py index 2d5fc57..e9bd662 100644 --- a/swh/journal/tests/journal_data.py +++ b/swh/journal/tests/journal_data.py @@ -1,343 +1,343 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime from typing import Dict, Sequence import attr -from swh.journal.serializers import ModelObject from swh.model.hashutil import MultiHash, hash_to_bytes, hash_to_hex from swh.model.identifiers import SWHID from swh.model.model import ( + BaseModel, Content, Directory, DirectoryEntry, MetadataAuthority, MetadataAuthorityType, MetadataFetcher, MetadataTargetType, ObjectType, Origin, OriginVisit, OriginVisitStatus, Person, RawExtrinsicMetadata, Release, Revision, RevisionType, SkippedContent, Snapshot, SnapshotBranch, TargetType, Timestamp, TimestampWithTimezone, ) UTC = datetime.timezone.utc CONTENTS = [ Content( length=4, data=f"foo{i}".encode(), status="visible", **MultiHash.from_data(f"foo{i}".encode()).digest(), ) for i in range(10) ] + [ Content( length=14, data=f"forbidden foo{i}".encode(), status="hidden", **MultiHash.from_data(f"forbidden foo{i}".encode()).digest(), ) for i in range(10) ] SKIPPED_CONTENTS = [ SkippedContent( length=4, status="absent", reason=f"because chr({i}) != '*'", **MultiHash.from_data(f"bar{i}".encode()).digest(), ) for i in range(2) ] duplicate_content1 = Content( length=4, sha1=hash_to_bytes("44973274ccef6ab4dfaaf86599792fa9c3fe4689"), sha1_git=b"another-foo", blake2s256=b"another-bar", sha256=b"another-baz", status="visible", ) # Craft a sha1 collision sha1_array = bytearray(duplicate_content1.sha1_git) sha1_array[0] += 1 duplicate_content2 = attr.evolve(duplicate_content1, sha1_git=bytes(sha1_array)) DUPLICATE_CONTENTS = [duplicate_content1, duplicate_content2] COMMITTERS = [ Person(fullname=b"foo", name=b"foo", email=b""), Person(fullname=b"bar", name=b"bar", email=b""), ] DATES = [ TimestampWithTimezone( timestamp=Timestamp(seconds=1234567891, microseconds=0,), offset=120, negative_utc=False, ), TimestampWithTimezone( timestamp=Timestamp(seconds=1234567892, microseconds=0,), offset=120, negative_utc=False, ), ] REVISIONS = [ Revision( id=hash_to_bytes("4ca486e65eb68e4986aeef8227d2db1d56ce51b3"), message=b"hello", date=DATES[0], committer=COMMITTERS[0], author=COMMITTERS[0], committer_date=DATES[0], type=RevisionType.GIT, directory=b"\x01" * 20, synthetic=False, metadata=None, parents=(), ), Revision( id=hash_to_bytes("677063f5c405d6fc1781fc56379c9a9adf43d3a0"), message=b"hello again", date=DATES[1], committer=COMMITTERS[1], author=COMMITTERS[1], committer_date=DATES[1], type=RevisionType.MERCURIAL, directory=b"\x02" * 20, synthetic=False, metadata=None, parents=(), extra_headers=((b"foo", b"bar"),), ), ] RELEASES = [ Release( id=hash_to_bytes("8059dc4e17fcd0e51ca3bcd6b80f4577d281fd08"), name=b"v0.0.1", date=TimestampWithTimezone( timestamp=Timestamp(seconds=1234567890, microseconds=0,), offset=120, negative_utc=False, ), author=COMMITTERS[0], target_type=ObjectType.REVISION, target=b"\x04" * 20, message=b"foo", synthetic=False, ), ] ORIGINS = [ Origin(url="https://somewhere.org/den/fox",), Origin(url="https://overtherainbow.org/fox/den",), ] ORIGIN_VISITS = [ OriginVisit( origin=ORIGINS[0].url, date=datetime.datetime(2013, 5, 7, 4, 20, 39, 369271, tzinfo=UTC), visit=1, type="git", ), OriginVisit( origin=ORIGINS[1].url, date=datetime.datetime(2014, 11, 27, 17, 20, 39, tzinfo=UTC), visit=1, type="hg", ), OriginVisit( origin=ORIGINS[0].url, date=datetime.datetime(2018, 11, 27, 17, 20, 39, tzinfo=UTC), visit=2, type="git", ), OriginVisit( origin=ORIGINS[0].url, date=datetime.datetime(2018, 11, 27, 17, 20, 39, tzinfo=UTC), visit=3, type="git", ), OriginVisit( origin=ORIGINS[1].url, date=datetime.datetime(2015, 11, 27, 17, 20, 39, tzinfo=UTC), visit=2, type="hg", ), ] # The origin-visit-status dates needs to be shifted slightly in the future from their # visit dates counterpart. Otherwise, we are hitting storage-wise the "on conflict" # ignore policy (because origin-visit-add creates an origin-visit-status with the same # parameters from the origin-visit {origin, visit, date}... ORIGIN_VISIT_STATUSES = [ OriginVisitStatus( origin=ORIGINS[0].url, date=datetime.datetime(2013, 5, 7, 4, 20, 39, 432222, tzinfo=UTC), visit=1, status="ongoing", snapshot=None, metadata=None, ), OriginVisitStatus( origin=ORIGINS[1].url, date=datetime.datetime(2014, 11, 27, 17, 21, 12, tzinfo=UTC), visit=1, status="ongoing", snapshot=None, metadata=None, ), OriginVisitStatus( origin=ORIGINS[0].url, date=datetime.datetime(2018, 11, 27, 17, 20, 59, tzinfo=UTC), visit=2, status="ongoing", snapshot=None, metadata=None, ), OriginVisitStatus( origin=ORIGINS[0].url, date=datetime.datetime(2018, 11, 27, 17, 20, 49, tzinfo=UTC), visit=3, status="full", snapshot=hash_to_bytes("17d0066a4a80aba4a0e913532ee8ff2014f006a9"), metadata=None, ), OriginVisitStatus( origin=ORIGINS[1].url, date=datetime.datetime(2015, 11, 27, 17, 22, 18, tzinfo=UTC), visit=2, status="partial", snapshot=hash_to_bytes("8ce268b87faf03850693673c3eb5c9bb66e1ca38"), metadata=None, ), ] DIRECTORIES = [ Directory(id=hash_to_bytes("4b825dc642cb6eb9a060e54bf8d69288fbee4904"), entries=()), Directory( id=hash_to_bytes("21416d920e0ebf0df4a7888bed432873ed5cb3a7"), entries=( DirectoryEntry( name=b"file1.ext", perms=0o644, type="file", target=CONTENTS[0].sha1_git, ), DirectoryEntry( name=b"dir1", perms=0o755, type="dir", target=hash_to_bytes("4b825dc642cb6eb9a060e54bf8d69288fbee4904"), ), DirectoryEntry( name=b"subprepo1", perms=0o160000, type="rev", target=REVISIONS[1].id, ), ), ), ] SNAPSHOTS = [ Snapshot( id=hash_to_bytes("17d0066a4a80aba4a0e913532ee8ff2014f006a9"), branches={ b"master": SnapshotBranch( target_type=TargetType.REVISION, target=REVISIONS[0].id ) }, ), Snapshot( id=hash_to_bytes("8ce268b87faf03850693673c3eb5c9bb66e1ca38"), branches={ b"target/revision": SnapshotBranch( target_type=TargetType.REVISION, target=REVISIONS[0].id, ), b"target/alias": SnapshotBranch( target_type=TargetType.ALIAS, target=b"target/revision" ), b"target/directory": SnapshotBranch( target_type=TargetType.DIRECTORY, target=DIRECTORIES[0].id, ), b"target/release": SnapshotBranch( target_type=TargetType.RELEASE, target=RELEASES[0].id ), b"target/snapshot": SnapshotBranch( target_type=TargetType.SNAPSHOT, target=hash_to_bytes("17d0066a4a80aba4a0e913532ee8ff2014f006a9"), ), }, ), ] METADATA_AUTHORITIES = [ MetadataAuthority( type=MetadataAuthorityType.FORGE, url="http://example.org/", metadata={}, ), ] METADATA_FETCHERS = [ MetadataFetcher(name="test-fetcher", version="1.0.0", metadata={},) ] RAW_EXTRINSIC_METADATA = [ RawExtrinsicMetadata( type=MetadataTargetType.ORIGIN, target="http://example.org/foo.git", discovery_date=datetime.datetime(2020, 7, 30, 17, 8, 20, tzinfo=UTC), authority=attr.evolve(METADATA_AUTHORITIES[0], metadata=None), fetcher=attr.evolve(METADATA_FETCHERS[0], metadata=None), format="json", metadata=b'{"foo": "bar"}', ), RawExtrinsicMetadata( type=MetadataTargetType.CONTENT, target=SWHID( object_type="content", object_id=hash_to_hex(CONTENTS[0].sha1_git) ), discovery_date=datetime.datetime(2020, 7, 30, 17, 8, 20, tzinfo=UTC), authority=attr.evolve(METADATA_AUTHORITIES[0], metadata=None), fetcher=attr.evolve(METADATA_FETCHERS[0], metadata=None), format="json", metadata=b'{"foo": "bar"}', ), ] -TEST_OBJECTS: Dict[str, Sequence[ModelObject]] = { +TEST_OBJECTS: Dict[str, Sequence[BaseModel]] = { "content": CONTENTS, "directory": DIRECTORIES, "metadata_authority": METADATA_AUTHORITIES, "metadata_fetcher": METADATA_FETCHERS, "origin": ORIGINS, "origin_visit": ORIGIN_VISITS, "origin_visit_status": ORIGIN_VISIT_STATUSES, "raw_extrinsic_metadata": RAW_EXTRINSIC_METADATA, "release": RELEASES, "revision": REVISIONS, "snapshot": SNAPSHOTS, "skipped_content": SKIPPED_CONTENTS, } diff --git a/swh/journal/tests/test_kafka_writer.py b/swh/journal/tests/test_kafka_writer.py index 9fa0a8d..75735f3 100644 --- a/swh/journal/tests/test_kafka_writer.py +++ b/swh/journal/tests/test_kafka_writer.py @@ -1,172 +1,172 @@ # Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Iterable from confluent_kafka import Consumer, Producer import pytest from swh.journal.pytest_plugin import assert_all_objects_consumed, consume_messages from swh.journal.tests.journal_data import TEST_OBJECTS from swh.journal.writer import model_object_dict_sanitizer from swh.journal.writer.kafka import KafkaDeliveryError, KafkaJournalWriter -from swh.model.model import Directory, Release, Revision +from swh.model.model import BaseModel, Directory, Release, Revision def test_kafka_writer( kafka_prefix: str, kafka_server: str, consumer: Consumer, privileged_object_types: Iterable[str], ): - writer = KafkaJournalWriter( + writer = KafkaJournalWriter[BaseModel]( brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, value_sanitizer=model_object_dict_sanitizer, anonymize=False, ) expected_messages = 0 for object_type, objects in TEST_OBJECTS.items(): writer.write_additions(object_type, objects) expected_messages += len(objects) consumed_messages = consume_messages(consumer, kafka_prefix, expected_messages) assert_all_objects_consumed(consumed_messages) for key, obj_dict in consumed_messages["revision"]: obj = Revision.from_dict(obj_dict) for person in (obj.author, obj.committer): assert not ( len(person.fullname) == 32 and person.name is None and person.email is None ) for key, obj_dict in consumed_messages["release"]: obj = Release.from_dict(obj_dict) for person in (obj.author,): assert not ( len(person.fullname) == 32 and person.name is None and person.email is None ) def test_kafka_writer_anonymized( kafka_prefix: str, kafka_server: str, consumer: Consumer, privileged_object_types: Iterable[str], ): - writer = KafkaJournalWriter( + writer = KafkaJournalWriter[BaseModel]( brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, value_sanitizer=model_object_dict_sanitizer, anonymize=True, ) expected_messages = 0 for object_type, objects in TEST_OBJECTS.items(): writer.write_additions(object_type, objects) expected_messages += len(objects) if object_type in privileged_object_types: expected_messages += len(objects) consumed_messages = consume_messages(consumer, kafka_prefix, expected_messages) assert_all_objects_consumed(consumed_messages, exclude=["revision", "release"]) for key, obj_dict in consumed_messages["revision"]: obj = Revision.from_dict(obj_dict) for person in (obj.author, obj.committer): assert ( len(person.fullname) == 32 and person.name is None and person.email is None ) for key, obj_dict in consumed_messages["release"]: obj = Release.from_dict(obj_dict) for person in (obj.author,): assert ( len(person.fullname) == 32 and person.name is None and person.email is None ) def test_write_delivery_failure( kafka_prefix: str, kafka_server: str, consumer: Consumer ): class MockKafkaError: """A mocked kafka error""" def str(self): return "Mocked Kafka Error" def name(self): return "SWH_MOCK_ERROR" class KafkaJournalWriterFailDelivery(KafkaJournalWriter): """A journal writer which always fails delivering messages""" def _on_delivery(self, error, message): """Replace the inbound error with a fake delivery error""" super()._on_delivery(MockKafkaError(), message) kafka_prefix += ".swh.journal.objects" writer = KafkaJournalWriterFailDelivery( brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, value_sanitizer=model_object_dict_sanitizer, ) empty_dir = Directory(entries=()) with pytest.raises(KafkaDeliveryError) as exc: writer.write_addition("directory", empty_dir) assert "Failed deliveries" in exc.value.message assert len(exc.value.delivery_failures) == 1 delivery_failure = exc.value.delivery_failures[0] assert delivery_failure.key == empty_dir.id assert delivery_failure.code == "SWH_MOCK_ERROR" def test_write_delivery_timeout( kafka_prefix: str, kafka_server: str, consumer: Consumer ): produced = [] class MockProducer(Producer): """A kafka producer which pretends to produce messages, but never sends any delivery acknowledgements""" def produce(self, **kwargs): produced.append(kwargs) - writer = KafkaJournalWriter( + writer = KafkaJournalWriter[BaseModel]( brokers=[kafka_server], client_id="kafka_writer", prefix=kafka_prefix, value_sanitizer=model_object_dict_sanitizer, flush_timeout=1, producer_class=MockProducer, ) empty_dir = Directory(entries=()) with pytest.raises(KafkaDeliveryError) as exc: writer.write_addition("directory", empty_dir) assert len(produced) == 1 assert "timeout" in exc.value.message assert len(exc.value.delivery_failures) == 1 delivery_failure = exc.value.delivery_failures[0] assert delivery_failure.key == empty_dir.id assert delivery_failure.code == "SWH_FLUSH_TIMEOUT" diff --git a/swh/journal/writer/__init__.py b/swh/journal/writer/__init__.py index 7957970..92879e6 100644 --- a/swh/journal/writer/__init__.py +++ b/swh/journal/writer/__init__.py @@ -1,41 +1,58 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Any, Dict +from typing import Any, Dict, Optional, TypeVar import warnings +from typing_extensions import Protocol + +from swh.model.model import KeyType + +TSelf = TypeVar("TSelf") + + +class ValueProtocol(Protocol): + def anonymize(self: TSelf) -> Optional[TSelf]: + ... + + def unique_key(self) -> KeyType: + ... + + def to_dict(self) -> Dict[str, Any]: + ... + def model_object_dict_sanitizer( object_type: str, object_dict: Dict[str, Any] ) -> Dict[str, str]: object_dict = object_dict.copy() if object_type == "content": object_dict.pop("data", None) return object_dict def get_journal_writer(cls, **kwargs): if "args" in kwargs: warnings.warn( 'Explicit "args" key is deprecated, use keys directly instead.', DeprecationWarning, ) kwargs = kwargs["args"] kwargs.setdefault("value_sanitizer", model_object_dict_sanitizer) if cls == "inmemory": # FIXME: Remove inmemory in due time warnings.warn( "cls = 'inmemory' is deprecated, use 'memory' instead", DeprecationWarning ) cls = "memory" if cls == "memory": from .inmemory import InMemoryJournalWriter as JournalWriter elif cls == "kafka": from .kafka import KafkaJournalWriter as JournalWriter else: raise ValueError("Unknown journal writer class `%s`" % cls) return JournalWriter(**kwargs) diff --git a/swh/journal/writer/inmemory.py b/swh/journal/writer/inmemory.py index d8f2cae..baecadc 100644 --- a/swh/journal/writer/inmemory.py +++ b/swh/journal/writer/inmemory.py @@ -1,41 +1,42 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging from multiprocessing import Manager -from typing import Any, List, Tuple +from typing import Any, Generic, List, Tuple, TypeVar -from swh.journal.serializers import ModelObject -from swh.model.model import BaseModel +from . import ValueProtocol logger = logging.getLogger(__name__) -class InMemoryJournalWriter: - objects: List[Tuple[str, ModelObject]] - privileged_objects: List[Tuple[str, ModelObject]] +TValue = TypeVar("TValue", bound=ValueProtocol) + + +class InMemoryJournalWriter(Generic[TValue]): + objects: List[Tuple[str, TValue]] + privileged_objects: List[Tuple[str, TValue]] def __init__(self, value_sanitizer: Any): # Share the list of objects across processes, for RemoteAPI tests. self.manager = Manager() self.objects = self.manager.list() self.privileged_objects = self.manager.list() def write_addition( - self, object_type: str, object_: ModelObject, privileged: bool = False + self, object_type: str, object_: TValue, privileged: bool = False ) -> None: - assert isinstance(object_, BaseModel) if privileged: self.privileged_objects.append((object_type, object_)) else: self.objects.append((object_type, object_)) write_update = write_addition def write_additions( - self, object_type: str, objects: List[ModelObject], privileged: bool = False + self, object_type: str, objects: List[TValue], privileged: bool = False ) -> None: for object_ in objects: self.write_addition(object_type, object_, privileged) diff --git a/swh/journal/writer/kafka.py b/swh/journal/writer/kafka.py index 8c6a1b7..3df3efb 100644 --- a/swh/journal/writer/kafka.py +++ b/swh/journal/writer/kafka.py @@ -1,233 +1,249 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import time -from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Type +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + NamedTuple, + Optional, + Type, + TypeVar, +) from confluent_kafka import KafkaException, Producer -from swh.journal.serializers import ( - KeyType, - ModelObject, - key_to_kafka, - pprint_key, - value_to_kafka, -) +from swh.journal.serializers import KeyType, key_to_kafka, pprint_key, value_to_kafka + +from . import ValueProtocol logger = logging.getLogger(__name__) class DeliveryTag(NamedTuple): """Unique tag allowing us to check for a message delivery""" topic: str kafka_key: bytes class DeliveryFailureInfo(NamedTuple): """Verbose information for failed deliveries""" object_type: str key: KeyType message: str code: str def get_object_type(topic: str) -> str: """Get the object type from a topic string""" return topic.rsplit(".", 1)[-1] class KafkaDeliveryError(Exception): """Delivery failed on some kafka messages.""" def __init__(self, message: str, delivery_failures: Iterable[DeliveryFailureInfo]): self.message = message self.delivery_failures = list(delivery_failures) def pretty_failures(self) -> str: return ", ".join( f"{f.object_type} {pprint_key(f.key)} ({f.message})" for f in self.delivery_failures ) def __str__(self): return f"KafkaDeliveryError({self.message}, [{self.pretty_failures()}])" -class KafkaJournalWriter: - """This class is used to write serialized versions of swh.model.model objects to a - series of Kafka topics. +TValue = TypeVar("TValue", bound=ValueProtocol) + + +class KafkaJournalWriter(Generic[TValue]): + """This class is used to write serialized versions of value objects to a + series of Kafka topics. The type parameter `TValue`, which must implement the + `ValueProtocol`, is the type of values this writer will write. + Typically, `TValue` will be `swh.model.model.BaseModel`. Topics used to send objects representations are built from a ``prefix`` plus the type of the object: ``{prefix}.{object_type}`` Objects can be sent as is, or can be anonymized. The anonymization feature, when - activated, will write anonymized versions of model objects in the main topic, and + activated, will write anonymized versions of value objects in the main topic, and stock (non-anonymized) objects will be sent to a dedicated (privileged) set of topics: ``{prefix}_privileged.{object_type}`` - The anonymization of a swh.model object is the result of calling its - ``BaseModel.anonymize()`` method. An object is considered anonymizable if this + The anonymization of a value object is the result of calling its + ``anonymize()`` method. An object is considered anonymizable if this method returns a (non-None) value. Args: brokers: list of broker addresses and ports. prefix: the prefix used to build the topic names for objects. client_id: the id of the writer sent to kafka. + value_sanitizer: a function that takes the object type and the dict + representation of an object as argument, and returns an other dict + that should be actually stored in the journal (eg. removing keys + that do no belong there) producer_config: extra configuration keys passed to the `Producer`. flush_timeout: timeout, in seconds, after which the `flush` operation will fail if some message deliveries are still pending. producer_class: override for the kafka producer class. anonymize: if True, activate the anonymization feature. """ def __init__( self, brokers: Iterable[str], prefix: str, client_id: str, value_sanitizer: Callable[[str, Dict[str, Any]], Dict[str, Any]], producer_config: Optional[Dict] = None, flush_timeout: float = 120, producer_class: Type[Producer] = Producer, anonymize: bool = False, ): self._prefix = prefix self._prefix_privileged = f"{self._prefix}_privileged" self.anonymize = anonymize if not producer_config: producer_config = {} if "message.max.bytes" not in producer_config: producer_config = { "message.max.bytes": 100 * 1024 * 1024, **producer_config, } self.producer = producer_class( { "bootstrap.servers": ",".join(brokers), "client.id": client_id, "on_delivery": self._on_delivery, "error_cb": self._error_cb, "logger": logger, "acks": "all", **producer_config, } ) # Delivery management self.flush_timeout = flush_timeout # delivery tag -> original object "key" mapping self.deliveries_pending: Dict[DeliveryTag, KeyType] = {} # List of (object_type, key, error_msg, error_name) for failed deliveries self.delivery_failures: List[DeliveryFailureInfo] = [] self.value_sanitizer = value_sanitizer def _error_cb(self, error): if error.fatal(): raise KafkaException(error) logger.info("Received non-fatal kafka error: %s", error) def _on_delivery(self, error, message): (topic, key) = delivery_tag = DeliveryTag(message.topic(), message.key()) sent_key = self.deliveries_pending.pop(delivery_tag, None) if error is not None: self.delivery_failures.append( DeliveryFailureInfo( get_object_type(topic), sent_key, error.str(), error.name() ) ) def send(self, topic: str, key: KeyType, value): kafka_key = key_to_kafka(key) self.producer.produce( topic=topic, key=kafka_key, value=value_to_kafka(value), ) self.deliveries_pending[DeliveryTag(topic, kafka_key)] = key def delivery_error(self, message) -> KafkaDeliveryError: """Get all failed deliveries, and clear them""" ret = self.delivery_failures self.delivery_failures = [] while self.deliveries_pending: delivery_tag, orig_key = self.deliveries_pending.popitem() (topic, kafka_key) = delivery_tag ret.append( DeliveryFailureInfo( get_object_type(topic), orig_key, "No delivery before flush() timeout", "SWH_FLUSH_TIMEOUT", ) ) return KafkaDeliveryError(message, ret) def flush(self): start = time.monotonic() self.producer.flush(self.flush_timeout) while self.deliveries_pending: if time.monotonic() - start > self.flush_timeout: break self.producer.poll(0.1) if self.deliveries_pending: # Delivery timeout raise self.delivery_error( "flush() exceeded timeout (%ss)" % self.flush_timeout, ) elif self.delivery_failures: raise self.delivery_error("Failed deliveries after flush()") - def _write_addition(self, object_type: str, object_: ModelObject) -> None: + def _write_addition(self, object_type: str, object_: TValue) -> None: """Write a single object to the journal""" key = object_.unique_key() if self.anonymize: anon_object_ = object_.anonymize() if anon_object_: # can be either None, or an anonymized object # if the object is anonymizable, send the non-anonymized version in the # privileged channel topic = f"{self._prefix_privileged}.{object_type}" dict_ = self.value_sanitizer(object_type, object_.to_dict()) logger.debug("topic: %s, key: %s, value: %s", topic, key, dict_) self.send(topic, key=key, value=dict_) object_ = anon_object_ topic = f"{self._prefix}.{object_type}" dict_ = self.value_sanitizer(object_type, object_.to_dict()) logger.debug("topic: %s, key: %s, value: %s", topic, key, dict_) self.send(topic, key=key, value=dict_) - def write_addition(self, object_type: str, object_: ModelObject) -> None: + def write_addition(self, object_type: str, object_: TValue) -> None: """Write a single object to the journal""" self._write_addition(object_type, object_) self.flush() write_update = write_addition - def write_additions(self, object_type: str, objects: Iterable[ModelObject]) -> None: + def write_additions(self, object_type: str, objects: Iterable[TValue]) -> None: """Write a set of objects to the journal""" for object_ in objects: self._write_addition(object_type, object_) self.flush()