diff --git a/swh/journal/client.py b/swh/journal/client.py
--- a/swh/journal/client.py
+++ b/swh/journal/client.py
@@ -4,10 +4,14 @@
 # See top-level LICENSE file for more information
 
 from collections import defaultdict
+from enum import Enum
 from importlib import import_module
 from itertools import cycle
 import logging
 import os
+import queue
+from threading import Thread
+import time
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 from confluent_kafka import Consumer, KafkaError, KafkaException
@@ -20,7 +24,6 @@
 logger = logging.getLogger(__name__)
 rdkafka_logger = logging.getLogger(__name__ + ".rdkafka")
 
-
 # Only accepted offset reset policy accepted
 ACCEPTED_OFFSET_RESET = ["earliest", "latest"]
 
@@ -34,6 +37,14 @@
 JOURNAL_STATUS_METRIC = "swh_journal_client_status"
 
 
+class PollThreadState(Enum):  # states of the background Kafka poll thread
+    INITIALIZING = "INITIALIZING"  # initial state; the thread stays idle
+    WAITING = "WAITING"  # queued around consume(); paused partitions resume
+    PROCESSING = "PROCESSING"  # queued before a batch is handled
+    REBALANCING = "REBALANCING"  # queued by the rebalance callback
+    TERMINATING = "TERMINATING"  # queued by close(); the thread exits
+
+
 def get_journal_client(cls: str, **kwargs: Any):
     """Factory function to instantiate a journal client object.
 
@@ -215,6 +226,14 @@
             logger.debug(" %s: %s", k, v)
 
         self.consumer = Consumer(consumer_settings)
+
+        # The poll thread is a daemon thread: if close() is never reached (for
+        # instance because setting up the client fails later on), a regular
+        # thread would keep the interpreter from exiting.
+        self.poll_thread_queue: "queue.Queue[PollThreadState]" = queue.Queue()
+        self.poll_thread = Thread(target=self.kafka_poll_thread, daemon=True)
+        self.poll_thread.start()
+
         if privileged:
             privileged_prefix = f"{prefix}_privileged"
         else:  # do not attempt to subscribe to privileged topics
@@ -266,13 +285,90 @@
                 "JournalClient; please remove it from your configuration.",
             )
 
+    def kafka_poll_thread(self):
+        """Keep the Kafka consumer polled while a long batch is processed.
+
+        Runs in ``self.poll_thread`` and follows the states pushed on
+        ``self.poll_thread_queue`` by :meth:`process`, :meth:`rebalance_cb`
+        and :meth:`close`. Once the client has been ``PROCESSING`` for more
+        than 15 seconds, the assigned partitions are paused and the thread
+        calls ``poll()`` itself; the partitions are resumed as soon as the
+        client is ``WAITING`` again. ``TERMINATING`` ends the thread.
+        """
+        prev_state = PollThreadState.INITIALIZING
+        state = PollThreadState.INITIALIZING
+        last_state_change = time.monotonic()
+        paused_partitions = False
+        while True:
+            try:
+                new_state = self.poll_thread_queue.get(timeout=0.1)
+            except queue.Empty:
+                pass
+            else:
+                if new_state != state:
+                    logger.debug("Poll thread now %s", new_state)
+                    if state != PollThreadState.REBALANCING:
+                        prev_state = state
+                    state = new_state
+                    last_state_change = time.monotonic()
+
+            now = time.monotonic()
+
+            if state == PollThreadState.INITIALIZING:
+                continue
+            elif state == PollThreadState.REBALANCING:
+                # rebalance_cb() paused the partitions it was handed: record it so
+                # that they get resumed the next time the client is WAITING.
+                paused_partitions = True
+                # Return to the interrupted state directly. Re-queueing it would
+                # order it after any state process() pushed in the meantime,
+                # leaving this thread in a stale state.
+                state = prev_state
+            elif state == PollThreadState.WAITING:
+                if paused_partitions:
+                    self.consumer.resume(self.consumer.assignment())
+                    paused_partitions = False
+            elif state == PollThreadState.PROCESSING:
+                if not paused_partitions and now - last_state_change > 15:
+                    self.consumer.pause(self.consumer.assignment())
+                    paused_partitions = True
+                if paused_partitions:
+                    # Short timeout: while poll() blocks, state changes (WAITING,
+                    # TERMINATING) are not seen and the partitions stay paused.
+                    msg = self.consumer.poll(timeout=1)
+                    # Message defines __len__ (payload size), so test for None
+                    # explicitly rather than relying on truthiness.
+                    if msg is None:
+                        continue
+                    error = msg.error()
+                    if not error:
+                        raise ValueError("poll thread got a non-error message?")
+                    _error_cb(error)
+            elif state == PollThreadState.TERMINATING:
+                break
+            else:
+                logger.warning("Unknown poll_thread_state: %s", state)
+
+    def rebalance_cb(self, consumer, partitions):
+        """Pause the rebalanced partitions and notify the poll thread.
+
+        Registered as both ``on_assign`` and ``on_revoke`` callback by
+        :meth:`subscribe`.
+        """
+        consumer.pause(partitions)
+        self.poll_thread_queue.put(PollThreadState.REBALANCING)
+
     def subscribe(self):
         """Subscribe to topics listed in self.subscription
 
         This can be overridden if you need, for instance, to manually assign partitions.
         """
         logger.debug(f"Subscribing to: {self.subscription}")
-        self.consumer.subscribe(topics=self.subscription)
+        self.consumer.subscribe(
+            topics=self.subscription,
+            on_assign=self.rebalance_cb,
+            on_revoke=self.rebalance_cb,
+        )
 
     def process(self, worker_fn):
         """Polls Kafka for a batch of messages, and calls the worker_fn
@@ -303,6 +399,7 @@
                     batch_size,
                 )
             set_status("waiting")
+            self.poll_thread_queue.put(PollThreadState.WAITING)
             for i in cycle(reversed(range(10))):
                 messages = self.consumer.consume(
                     timeout=timeout, num_messages=batch_size
@@ -322,9 +419,11 @@
                         break
             if messages:
                 set_status("processing")
+                self.poll_thread_queue.put(PollThreadState.PROCESSING)
                 batch_processed, at_eof = self.handle_messages(messages, worker_fn)
 
                 set_status("idle")
+                self.poll_thread_queue.put(PollThreadState.WAITING)
                 # report the number of handled messages
                 statsd.increment(
                     JOURNAL_MESSAGE_NUMBER_METRIC, value=batch_processed
@@ -374,4 +473,8 @@
         return self.value_deserializer(object_type, message.value())
 
     def close(self):
+        """Stop the poll thread, then close the Kafka consumer."""
+        # Stop the thread first: it calls into self.consumer.
+        self.poll_thread_queue.put(PollThreadState.TERMINATING)
+        self.poll_thread.join()
         self.consumer.close()