diff --git a/swh/journal/serializers.py b/swh/journal/serializers.py
--- a/swh/journal/serializers.py
+++ b/swh/journal/serializers.py
@@ -9,8 +9,10 @@
 
 from swh.core.api.serializers import msgpack_dumps, msgpack_loads
 
+KeyType = Union[Dict[str, str], Dict[str, bytes], bytes]
 
-def key_to_kafka(key: Union[bytes, Dict]) -> bytes:
+
+def key_to_kafka(key: KeyType) -> bytes:
     """Serialize a key, possibly a dict, in a predictable way"""
     p = msgpack.Packer(use_bin_type=True)
     if isinstance(key, dict):
@@ -19,7 +21,7 @@
         return p.pack(key)
 
 
-def kafka_to_key(kafka_key: bytes) -> Union[bytes, Dict]:
+def kafka_to_key(kafka_key: bytes) -> KeyType:
     """Deserialize a key"""
     return msgpack.loads(kafka_key)
 
diff --git a/swh/journal/writer/kafka.py b/swh/journal/writer/kafka.py
--- a/swh/journal/writer/kafka.py
+++ b/swh/journal/writer/kafka.py
@@ -21,7 +21,7 @@
     Snapshot,
 )
 
-from swh.journal.serializers import key_to_kafka, value_to_kafka
+from swh.journal.serializers import KeyType, key_to_kafka, value_to_kafka
 
 logger = logging.getLogger(__name__)
 
@@ -79,9 +79,10 @@
         if error is not None:
             self._error_cb(error)
 
-    def send(self, topic: str, key, value):
+    def send(self, topic: str, key: KeyType, value):
+        kafka_key = key_to_kafka(key)
         self.producer.produce(
-            topic=topic, key=key_to_kafka(key), value=value_to_kafka(value),
+            topic=topic, key=kafka_key, value=value_to_kafka(value),
         )
 
         # Need to service the callbacks regularly by calling poll
@@ -114,7 +115,7 @@
     def _get_key(self, object_type: str, object_: OriginVisit) -> Dict[str, str]:
         ...
 
-    def _get_key(self, object_type, object_):
+    def _get_key(self, object_type: str, object_) -> KeyType:
         if object_type in ("revision", "release", "directory", "snapshot"):
             return object_.id
         elif object_type == "content":