diff --git a/swh/journal/serializers.py b/swh/journal/serializers.py
index ea23898..247b6a9 100644
--- a/swh/journal/serializers.py
+++ b/swh/journal/serializers.py
@@ -1,32 +1,34 @@
 # Copyright (C) 2016-2017 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

+from typing import Any, Dict, Union
+
 import msgpack

 from swh.core.api.serializers import msgpack_dumps, msgpack_loads


-def key_to_kafka(key):
+def key_to_kafka(key: Union[bytes, Dict]) -> bytes:
     """Serialize a key, possibly a dict, in a predictable way"""
     p = msgpack.Packer(use_bin_type=True)
     if isinstance(key, dict):
         return p.pack_map_pairs(sorted(key.items()))
     else:
         return p.pack(key)


-def kafka_to_key(kafka_key):
+def kafka_to_key(kafka_key: bytes) -> Union[bytes, Dict]:
     """Deserialize a key"""
     return msgpack.loads(kafka_key)


-def value_to_kafka(value):
+def value_to_kafka(value: Any) -> bytes:
     """Serialize some data for storage in kafka"""
     return msgpack_dumps(value)


-def kafka_to_value(kafka_value):
+def kafka_to_value(kafka_value: bytes) -> Any:
     """Deserialize some data stored in kafka"""
     return msgpack_loads(kafka_value)
diff --git a/swh/journal/tests/test_kafka_writer.py b/swh/journal/tests/test_kafka_writer.py
index 90861a2..575c040 100644
--- a/swh/journal/tests/test_kafka_writer.py
+++ b/swh/journal/tests/test_kafka_writer.py
@@ -1,161 +1,162 @@
 # Copyright (C) 2018-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 from collections import defaultdict
 import datetime

 from confluent_kafka import Consumer, KafkaException

 from subprocess import Popen
 from typing import List, Tuple

 from swh.storage import get_storage

 from swh.journal.replay import object_converter_fn
 from swh.journal.serializers import (
     kafka_to_key, kafka_to_value
 )
-from swh.journal.writer.kafka import KafkaJournalWriter
+from swh.journal.writer.kafka import KafkaJournalWriter, OBJECT_TYPES

 from swh.model.model import Content, Origin, BaseModel

 from .conftest import OBJECT_TYPE_KEYS

+MODEL_OBJECTS = {v: k for (k, v) in OBJECT_TYPES.items()}
+

 def assert_written(consumer, kafka_prefix, expected_messages):
     consumed_objects = defaultdict(list)
     fetched_messages = 0
     retries_left = 1000

     while fetched_messages < expected_messages:
         if retries_left == 0:
             raise ValueError('Timed out fetching messages from kafka')

         msg = consumer.poll(timeout=0.01)

         if not msg:
             retries_left -= 1
             continue

         error = msg.error()
         if error is not None:
             if error.fatal():
                 raise KafkaException(error)
             retries_left -= 1
             continue

         fetched_messages += 1
         consumed_objects[msg.topic()].append(
             (kafka_to_key(msg.key()), kafka_to_value(msg.value()))
         )

     for (object_type, (key_name, objects)) in OBJECT_TYPE_KEYS.items():
         topic = kafka_prefix + '.' + object_type
         (keys, values) = zip(*consumed_objects[topic])
         if key_name:
             assert list(keys) == [object_[key_name] for object_ in objects]
         else:
             pass  # TODO

         if object_type == 'origin_visit':
             for value in values:
                 del value['visit']
         elif object_type == 'content':
             for value in values:
                 del value['ctime']

         for object_ in objects:
             assert object_ in values


 def test_kafka_writer(
         kafka_prefix: str, kafka_server: Tuple[Popen, int],
         consumer: Consumer):
     kafka_prefix += '.swh.journal.objects'

-    config = {
-        'brokers': ['localhost:%d' % kafka_server[1]],
-        'client_id': 'kafka_writer',
-        'prefix': kafka_prefix,
-        'producer_config': {
+    writer = KafkaJournalWriter(
+        brokers=[f'localhost:{kafka_server[1]}'],
+        client_id='kafka_writer',
+        prefix=kafka_prefix,
+        producer_config={
             'message.max.bytes': 100000000,
-        }
-    }
-
-    writer = KafkaJournalWriter(**config)
+        })

     expected_messages = 0

     for (object_type, (_, objects)) in OBJECT_TYPE_KEYS.items():
-        for (num, object_) in enumerate(objects):
+        for (num, object_d) in enumerate(objects):
             if object_type == 'origin_visit':
-                object_ = {**object_, 'visit': num}
+                object_d = {**object_d, 'visit': num}
             if object_type == 'content':
-                object_ = {**object_, 'ctime': datetime.datetime.now()}
+                object_d = {**object_d, 'ctime': datetime.datetime.now()}
+            object_ = MODEL_OBJECTS[object_type].from_dict(object_d)
+
             writer.write_addition(object_type, object_)
             expected_messages += 1

     assert_written(consumer, kafka_prefix, expected_messages)


 def test_storage_direct_writer(
         kafka_prefix: str, kafka_server: Tuple[Popen, int],
         consumer: Consumer):
     kafka_prefix += '.swh.journal.objects'

     writer_config = {
         'cls': 'kafka',
         'brokers': ['localhost:%d' % kafka_server[1]],
         'client_id': 'kafka_writer',
         'prefix': kafka_prefix,
         'producer_config': {
             'message.max.bytes': 100000000,
         }
     }
     storage_config = {
         'cls': 'pipeline',
         'steps': [
             {'cls': 'memory', 'journal_writer': writer_config},
         ]
     }

     storage = get_storage(**storage_config)

     expected_messages = 0

     for (object_type, (_, objects)) in OBJECT_TYPE_KEYS.items():
         method = getattr(storage, object_type + '_add')
         if object_type in ('content', 'directory', 'revision', 'release',
                            'snapshot', 'origin'):
             objects_: List[BaseModel]
             if object_type == 'content':
                 objects_ = [
                     Content.from_dict({
                         **obj, 'data': b''})
                     for obj in objects
                 ]
             else:
                 objects_ = [
                     object_converter_fn[object_type](obj)
                     for obj in objects
                 ]
             method(objects_)
             expected_messages += len(objects)
         elif object_type in ('origin_visit',):
             for object_ in objects:
                 object_ = object_.copy()
                 origin_url = object_.pop('origin')
                 storage.origin_add_one(Origin(url=origin_url))

                 visit = method(origin_url, date=object_.pop('date'),
                                type=object_.pop('type'))
                 expected_messages += 1

                 storage.origin_visit_update(origin_url, visit.visit, **object_)
                 expected_messages += 1
         else:
             assert False, object_type

     assert_written(consumer, kafka_prefix, expected_messages)
diff --git a/swh/journal/writer/inmemory.py b/swh/journal/writer/inmemory.py
index 6c7b84c..175f473 100644
--- a/swh/journal/writer/inmemory.py
+++ b/swh/journal/writer/inmemory.py
@@ -1,30 +1,33 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import logging
 import copy
+
 from multiprocessing import Manager
+from typing import List

 from swh.model.model import BaseModel

+from .kafka import ModelObject
+
 logger = logging.getLogger(__name__)


 class InMemoryJournalWriter:
     def __init__(self):
         # Share the list of objects across processes, for RemoteAPI tests.
         self.manager = Manager()
         self.objects = self.manager.list()

-    def write_addition(self, object_type, object_):
-        if isinstance(object_, BaseModel):
-            object_ = object_.to_dict()
+    def write_addition(self, object_type: str, object_: ModelObject) -> None:
+        assert isinstance(object_, BaseModel)
         self.objects.append((object_type, copy.deepcopy(object_)))

     write_update = write_addition

-    def write_additions(self, object_type, objects):
+    def write_additions(self, object_type: str, objects: List[ModelObject]) -> None:
         for object_ in objects:
             self.write_addition(object_type, object_)
diff --git a/swh/journal/writer/kafka.py b/swh/journal/writer/kafka.py
index 648d20a..3c2949c 100644
--- a/swh/journal/writer/kafka.py
+++ b/swh/journal/writer/kafka.py
@@ -1,115 +1,163 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import logging
+from typing import Dict, Iterable, List, Type, Union, overload

 from confluent_kafka import Producer, KafkaException

 from swh.model.hashutil import DEFAULT_ALGORITHMS
-from swh.model.model import BaseModel
+from swh.model.model import (
+    BaseModel,
+    Content,
+    Directory,
+    Origin,
+    OriginVisit,
+    Release,
+    Revision,
+    SkippedContent,
+    Snapshot,
+)

 from swh.journal.serializers import key_to_kafka, value_to_kafka

 logger = logging.getLogger(__name__)

+OBJECT_TYPES: Dict[Type[BaseModel], str] = {
+    Content: "content",
+    Directory: "directory",
+    Origin: "origin",
+    OriginVisit: "origin_visit",
+    Release: "release",
+    Revision: "revision",
+    SkippedContent: "skipped_content",
+    Snapshot: "snapshot",
+}
+
+ModelObject = Union[
+    Content, Directory, Origin, OriginVisit, Release, Revision, SkippedContent, Snapshot
+]
+

 class KafkaJournalWriter:
     """This class is instantiated and used by swh-storage to write incoming
     new objects to Kafka before adding them to the storage backend
     (eg. postgresql) itself."""
-    def __init__(self, brokers, prefix, client_id, producer_config=None):
-        self._prefix = prefix
-
-        if isinstance(brokers, str):
-            brokers = [brokers]
+    def __init__(
+        self,
+        brokers: Iterable[str],
+        prefix: str,
+        client_id: str,
+        producer_config: Union[None, Dict] = None,
+    ):
+        self._prefix = prefix

         if not producer_config:
             producer_config = {}

-        self.producer = Producer({
-            'bootstrap.servers': ','.join(brokers),
-            'client.id': client_id,
-            'on_delivery': self._on_delivery,
-            'error_cb': self._error_cb,
-            'logger': logger,
-            'acks': 'all',
-            **producer_config,
-        })
+        self.producer = Producer(
+            {
+                "bootstrap.servers": ",".join(brokers),
+                "client.id": client_id,
+                "on_delivery": self._on_delivery,
+                "error_cb": self._error_cb,
+                "logger": logger,
+                "acks": "all",
+                **producer_config,
+            }
+        )

     def _error_cb(self, error):
         if error.fatal():
             raise KafkaException(error)
-        logger.info('Received non-fatal kafka error: %s', error)
+        logger.info("Received non-fatal kafka error: %s", error)

     def _on_delivery(self, error, message):
         if error is not None:
             self._error_cb(error)

-    def send(self, topic, key, value):
+    def send(self, topic: str, key, value):
         self.producer.produce(
-            topic=topic,
-            key=key_to_kafka(key),
-            value=value_to_kafka(value),
+            topic=topic, key=key_to_kafka(key), value=value_to_kafka(value),
         )
         # Need to service the callbacks regularly by calling poll
         self.producer.poll(0)

     def flush(self):
         self.producer.flush()

+    # These @overload'ed versions of the _get_key method help mypy figure out
+    # the correct typing.
+    @overload
+    def _get_key(
+        self, object_type: str, object_: Union[Revision, Release, Directory, Snapshot]
+    ) -> bytes:
+        ...
+
+    @overload
+    def _get_key(self, object_type: str, object_: Content) -> bytes:
+        ...
+
+    @overload
+    def _get_key(self, object_type: str, object_: SkippedContent) -> Dict[str, bytes]:
+        ...
+
+    @overload
+    def _get_key(self, object_type: str, object_: Origin) -> Dict[str, bytes]:
+        ...
+
+    @overload
+    def _get_key(self, object_type: str, object_: OriginVisit) -> Dict[str, str]:
+        ...
+
     def _get_key(self, object_type, object_):
-        if object_type in ('revision', 'release', 'directory', 'snapshot'):
-            return object_['id']
-        elif object_type == 'content':
-            return object_['sha1']  # TODO: use a dict of hashes
-        elif object_type == 'skipped_content':
-            return {
-                hash: object_[hash]
-                for hash in DEFAULT_ALGORITHMS
-            }
-        elif object_type == 'origin':
-            return {'url': object_['url']}
-        elif object_type == 'origin_visit':
+        if object_type in ("revision", "release", "directory", "snapshot"):
+            return object_.id
+        elif object_type == "content":
+            return object_.sha1  # TODO: use a dict of hashes
+        elif object_type == "skipped_content":
+            return {hash: getattr(object_, hash) for hash in DEFAULT_ALGORITHMS}
+        elif object_type == "origin":
+            return {"url": object_.url}
+        elif object_type == "origin_visit":
             return {
-                'origin': object_['origin'],
-                'date': str(object_['date']),
+                "origin": object_.origin,
+                "date": str(object_.date),
             }
         else:
-            raise ValueError('Unknown object type: %s.' % object_type)
-
-    def _sanitize_object(self, object_type, object_):
-        if object_type == 'origin_visit':
-            return {
-                **object_,
-                'date': str(object_['date']),
-            }
-        elif object_type == 'origin':
-            assert 'id' not in object_
-        return object_
-
-    def _write_addition(self, object_type, object_):
+            raise ValueError("Unknown object type: %s." % object_type)
+
+    def _sanitize_object(
+        self, object_type: str, object_: ModelObject
+    ) -> Dict[str, str]:
+        dict_ = object_.to_dict()
+        if object_type == "origin_visit":
+            # :(
+            dict_["date"] = str(dict_["date"])
+        return dict_
+
+    def _write_addition(self, object_type: str, object_: ModelObject) -> None:
         """Write a single object to the journal"""
-        if isinstance(object_, BaseModel):
-            object_ = object_.to_dict()
-        topic = '%s.%s' % (self._prefix, object_type)
+        topic = f"{self._prefix}.{object_type}"
         key = self._get_key(object_type, object_)
         dict_ = self._sanitize_object(object_type, object_)
-        logger.debug('topic: %s, key: %s, value: %s', topic, key, dict_)
+        logger.debug("topic: %s, key: %s, value: %s", topic, key, dict_)
         self.send(topic, key=key, value=dict_)

-    def write_addition(self, object_type, object_):
+    def write_addition(self, object_type: str, object_: ModelObject) -> None:
         """Write a single object to the journal"""
         self._write_addition(object_type, object_)
         self.flush()

     write_update = write_addition

-    def write_additions(self, object_type, objects):
+    def write_additions(self, object_type: str, objects: List[ModelObject]) -> None:
         """Write a set of objects to the journal"""
         for object_ in objects:
             self._write_addition(object_type, object_)
         self.flush()
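
A minimal usage sketch of the new model-object write path (the broker address,
prefix and client id below are illustrative assumptions, not part of the patch;
the rest follows the code added above):

    from swh.journal.writer.kafka import KafkaJournalWriter
    from swh.model.model import Origin

    # Illustrative connection settings; adjust to the local Kafka setup.
    writer = KafkaJournalWriter(
        brokers=["localhost:9092"],
        prefix="swh.journal.objects",
        client_id="example-writer",
    )

    # write_addition now takes a model object rather than a dict: the writer
    # derives the key itself ({"url": ...} for origins) and sends
    # origin.to_dict() as the value on the "swh.journal.objects.origin" topic.
    origin = Origin(url="https://example.com/repo.git")
    writer.write_addition("origin", origin)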
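
A second sketch, showing why key_to_kafka packs sorted key/value pairs: dict
keys built in different insertion orders must serialize to byte-identical Kafka
keys (the field values here are made up):

    from swh.journal.serializers import key_to_kafka

    k1 = key_to_kafka({"origin": "https://example.com", "date": "2020-01-01"})
    k2 = key_to_kafka({"date": "2020-01-01", "origin": "https://example.com"})

    # pack_map_pairs(sorted(key.items())) makes the encoding independent of
    # insertion order, so equal logical keys map to the same partition.
    assert k1 == k2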