diff --git a/swh/dataset/journalprocessor.py b/swh/dataset/journalprocessor.py
--- a/swh/dataset/journalprocessor.py
+++ b/swh/dataset/journalprocessor.py
@@ -12,7 +12,7 @@
 import multiprocessing
 from pathlib import Path
 import time
-from typing import Any, Dict, Mapping, Sequence, Tuple, Type
+from typing import Any, Container, Dict, Mapping, Optional, Sequence, Tuple, Type
 
 from confluent_kafka import TopicPartition
 import tqdm
@@ -56,7 +56,7 @@
         self.refresh_every = refresh_every
         self.assignment = assignment
         self.count = None
-        self.topic_name = None
+        self.topic_name: Optional[str] = None
         kwargs["stop_on_eof"] = True  # Stop when the assignment is empty
         super().__init__(*args, **kwargs)
 
@@ -68,6 +68,19 @@
             [TopicPartition(self.topic_name, pid) for pid in self.assignment]
         )
 
+    def unsubscribe(self, partitions: Container[int]) -> None:
+        """Remove the given partitions from this consumer's assignment.
+
+        Drops every partition id contained in ``partitions`` from the current
+        assignment and re-assigns the consumer to the remaining partitions.
+        """
+        assert self.assignment is not None
+        self.assignment = [pid for pid in self.assignment if pid not in partitions]
+        self.consumer.assign(
+            [TopicPartition(self.topic_name, pid) for pid in self.assignment]
+        )
+
     def process(self, worker_fn):
         self.count = 0
         try:
@@ -90,10 +102,7 @@
         if offset >= self.offset_ranges[partition_id][1] - 1:
             if partition_id in self.assignment:
                 self.progress_queue.put({partition_id: offset})
-                self.assignment = [
-                    pid for pid in self.assignment if pid != partition_id
-                ]
-                self.subscribe()  # Actually, unsubscribes from the partition_id
+                self.unsubscribe([partition_id])
 
     def deserialize_message(self, message):
         """