# The @overload'ed signatures of object_key below only exist to help mypy
# figure out the precise key type returned for each kind of model object; the
# single real implementation follows them.
@overload
def object_key(
    object_type: str, object_: Union[Content, Directory, Revision, Release, Snapshot]
) -> bytes:
    ...


# NOTE(review): the "origin" key is built as {"url": object_.url}; if Origin.url
# is a str, Origin belongs with the Dict[str, str] overload -- confirm against
# the model definition.
@overload
def object_key(
    object_type: str, object_: Union[Origin, SkippedContent]
) -> Dict[str, bytes]:
    ...


@overload
def object_key(object_type: str, object_: OriginVisit) -> Dict[str, str]:
    ...


def object_key(object_type: str, object_) -> KeyType:
    """Compute the journal key of a model object.

    Args:
        object_type: name of the kind of object; one of ``"revision"``,
            ``"release"``, ``"directory"``, ``"snapshot"``, ``"content"``,
            ``"skipped_content"``, ``"origin"`` or ``"origin_visit"``.
        object_: the object whose key is wanted; only the attributes relevant
            to ``object_type`` are read.

    Returns:
        - ``object_.id`` for revisions, releases, directories and snapshots;
        - ``object_.sha1`` for contents;
        - a dict mapping each name in ``DEFAULT_ALGORITHMS`` to the attribute
          of the same name for skipped contents;
        - ``{"url": object_.url}`` for origins;
        - ``{"origin": object_.origin, "date": str(object_.date)}`` for origin
          visits.

    Raises:
        ValueError: if ``object_type`` is not one of the supported names.
    """
    if object_type in ("revision", "release", "directory", "snapshot"):
        return object_.id
    elif object_type == "content":
        return object_.sha1  # TODO: use a dict of hashes
    elif object_type == "skipped_content":
        # one entry per hash algorithm; avoid shadowing the ``hash`` builtin
        return {algo: getattr(object_, algo) for algo in DEFAULT_ALGORITHMS}
    elif object_type == "origin":
        return {"url": object_.url}
    elif object_type == "origin_visit":
        return {
            "origin": object_.origin,
            "date": str(object_.date),
        }
    else:
        # Wrap the argument in a 1-tuple: with a bare ``% object_type`` a
        # tuple-valued object_type would be unpacked as several format
        # arguments and raise TypeError instead of the intended ValueError.
        raise ValueError("Unknown object type: %s." % (object_type,))
value_to_kafka, +) logger = logging.getLogger(__name__) @@ -36,10 +41,6 @@ Snapshot: "snapshot", } -ModelObject = Union[ - Content, Directory, Origin, OriginVisit, Release, Revision, SkippedContent, Snapshot -] - class KafkaJournalWriter: """This class is instantiated and used by swh-storage to write incoming @@ -105,47 +106,6 @@ def flush(self): self.producer.flush() - # these @overload'ed versions of the _get_key method aim at helping mypy figuring - # the correct type-ing. - @overload - def _get_key( - self, object_type: str, object_: Union[Revision, Release, Directory, Snapshot] - ) -> bytes: - ... - - @overload - def _get_key(self, object_type: str, object_: Content) -> bytes: - ... - - @overload - def _get_key(self, object_type: str, object_: SkippedContent) -> Dict[str, bytes]: - ... - - @overload - def _get_key(self, object_type: str, object_: Origin) -> Dict[str, bytes]: - ... - - @overload - def _get_key(self, object_type: str, object_: OriginVisit) -> Dict[str, str]: - ... - - def _get_key(self, object_type: str, object_) -> KeyType: - if object_type in ("revision", "release", "directory", "snapshot"): - return object_.id - elif object_type == "content": - return object_.sha1 # TODO: use a dict of hashes - elif object_type == "skipped_content": - return {hash: getattr(object_, hash) for hash in DEFAULT_ALGORITHMS} - elif object_type == "origin": - return {"url": object_.url} - elif object_type == "origin_visit": - return { - "origin": object_.origin, - "date": str(object_.date), - } - else: - raise ValueError("Unknown object type: %s." 
def _write_addition(self, object_type: str, object_: ModelObject) -> None:
    """Send one object to the journal topic named after its object type.

    The topic is ``<prefix>.<object_type>``; the message key comes from
    ``object_key`` and the message value from ``self._sanitize_object``.
    """
    destination = f"{self._prefix}.{object_type}"
    message_key = object_key(object_type, object_)
    payload = self._sanitize_object(object_type, object_)
    logger.debug(
        "topic: %s, key: %s, value: %s", destination, message_key, payload
    )
    self.send(destination, key=message_key, value=payload)