diff --git a/swh/dataset/journalprocessor.py b/swh/dataset/journalprocessor.py
--- a/swh/dataset/journalprocessor.py
+++ b/swh/dataset/journalprocessor.py
@@ -12,9 +12,9 @@
 import multiprocessing
 from pathlib import Path
 import time
-from typing import Any, Dict, Mapping, Sequence, Tuple, Type
+from typing import Any, Dict, List, Mapping, Sequence, Tuple, Type
 
-from confluent_kafka import TopicPartition
+from confluent_kafka import Message, TopicPartition
 import tqdm
 
 from swh.dataset.exporter import Exporter
@@ -55,6 +55,9 @@
         self.progress_queue = progress_queue
         self.refresh_every = refresh_every
         self.assignment = assignment
+        # Messages that reached the end of their partition's offset range;
+        # handle_messages() commits them once worker_fn has processed them.
+        self._messages_to_commit: List[Message] = []
         self.count = None
         self.topic_name = None
         kwargs["stop_on_eof"] = True  # Stop when the assignment is empty
@@ -68,6 +71,16 @@
             [TopicPartition(self.topic_name, pid) for pid in self.assignment]
         )
 
+    def unsubscribe(self, partitions):
+        """
+        Remove the given partition ids from the current assignment and
+        reassign the consumer to the remaining partitions.
+        """
+        self.assignment = [pid for pid in self.assignment if pid not in partitions]
+        self.consumer.assign(
+            [TopicPartition(self.topic_name, pid) for pid in self.assignment]
+        )
+
     def process(self, worker_fn):
         self.count = 0
         try:
@@ -76,11 +89,14 @@
         finally:
             self.progress_queue.put(None)
 
-    def handle_offset(self, partition_id, offset):
+    def handle_offset(self, message):
         """
         Check whether the client has reached the end of the current
         partition, and trigger a reassignment if that is the case.
         """
+        offset = message.offset()
+        partition_id = message.partition()
+
         if offset < 0:  # Uninitialized partition offset
             return
 
@@ -90,10 +106,11 @@
         if offset >= self.offset_ranges[partition_id][1] - 1:
             if partition_id in self.assignment:
                 self.progress_queue.put({partition_id: offset})
-                self.assignment = [
-                    pid for pid in self.assignment if pid != partition_id
-                ]
-                self.subscribe()  # Actually, unsubscribes from the partition_id
+                # unsubscribe from partition but make sure current message's
+                # offset will be committed after executing the worker_fn in
+                # process(); see handle_messages() below
+                self._messages_to_commit.append(message)
+                self.unsubscribe([partition_id])
 
     def deserialize_message(self, message):
         """
@@ -102,10 +119,22 @@
         We also return the raw objects instead of deserializing them because
         we will need the partition ID later.
         """
-        self.handle_offset(message.partition(), message.offset())
+        self.handle_offset(message)
         self.count += 1
         return message
 
+    def handle_messages(self, messages, worker_fn):
+        """
+        Delegate to the parent implementation, then commit each message that
+        handle_offset() recorded when unsubscribing from its partition.
+        Return the parent's (nb_processed, at_eof) result unchanged.
+        """
+        nb_processed, at_eof = super().handle_messages(messages, worker_fn)
+        for msg in self._messages_to_commit:
+            self.consumer.commit(message=msg)
+        self._messages_to_commit.clear()
+        return nb_processed, at_eof
+
 
 class ParallelJournalProcessor:
     """