class _ConcurrentCsvWritingTask(_BaseTask):
    """Base class for tasks writing a CSV using asyncio.

    asyncio is only used for gRPC requests to swh-graph; file writes are
    synchronous to keep the code simpler, as performance improvements from
    making them async would be negligible."""

    CSV_HEADER: Tuple[str, str]

    blob_filter = luigi.ChoiceParameter(choices=list(SELECTION_QUERIES))
    derived_datasets_path = luigi.PathParameter()
    graph_api = luigi.Parameter("localhost:50091")

    stub: "TraversalServiceStub"

    def requires(self) -> luigi.Task:
        """Returns an instance of :class:`SelectBlobs`"""
        return SelectBlobs(
            blob_filter=self.blob_filter,
            derived_datasets_path=self.derived_datasets_path,
        )

    def run(self) -> None:
        """Calls the :meth:`process_one` method on each blob, and writes the
        results as a two-column CSV to the target defined by :meth:`output`.
        """
        import asyncio

        asyncio.run(self._run_async())

    async def _run_async(self) -> None:
        import asyncio

        import grpc.aio

        import swh.graph.grpc.swhgraph_pb2_grpc as swhgraph_grpc

        input_queue: asyncio.Queue[Tuple[str, str, str]] = asyncio.Queue(maxsize=20)
        result_queue: asyncio.Queue[Tuple[str, str]] = asyncio.Queue(maxsize=20)

        async with grpc.aio.insecure_channel(self.graph_api) as channel:
            self.stub = swhgraph_grpc.TraversalServiceStub(channel)

            fill_queue_task = asyncio.create_task(self._fill_input_queue(input_queue))
            write_task = asyncio.create_task(self._write_results(result_queue))
            worker_tasks = [
                asyncio.create_task(self._worker(input_queue, result_queue))
                for _ in range(GRAPH_REQUEST_CONCURRENCY)
            ]

            await write_task  # wait for workers to write everything
            await fill_queue_task  # should be instant

            # Workers loop forever on the input queue; cancel them now that
            # every result has been written.
            for task in worker_tasks:
                task.cancel()

            await asyncio.gather(
                fill_queue_task,
                write_task,
                *worker_tasks,
                return_exceptions=True,
            )

    async def _fill_input_queue(
        self, input_queue: "asyncio.Queue[Tuple[str, str, str]]"
    ) -> None:
        for (swhid, sha1, name) in self.iter_blobs(with_tqdm=False, unique_sha1=True):
            if not swhid.startswith("swh:1:"):
                raise ValueError(f"Invalid SWHID: {swhid}")

            await input_queue.put((swhid, sha1, name))

    async def _worker(
        self,
        input_queue: "asyncio.Queue[Tuple[str, str, str]]",
        result_queue: "asyncio.Queue[Tuple[str, str]]",
    ) -> None:
        while True:  # exit via Task.cancel()
            row = await input_queue.get()
            try:
                res = await self.process_one(row)
            except Exception:
                (swhid, _sha1, _name) = row
                res = (swhid, "")
                logger.exception("Error while processing %r", row)

            await result_queue.put(res)

    async def _write_results(
        self, result_queue: "asyncio.Queue[Tuple[str, str]]"
    ) -> None:
        import tqdm.asyncio

        (target,) = self.output()
        result_path = Path(target.path)

        with atomic_csv_zstd_writer(result_path) as writer:
            writer.writerow(self.CSV_HEADER)

            blob_count = self.blob_count()
            async for _ in tqdm.asyncio.trange(blob_count):
                (swhid, result) = await result_queue.get()
                writer.writerow((swhid, result))
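

# A minimal sketch of a concrete subclass (hypothetical, not part of the
# original code): it illustrates what subclasses of _ConcurrentCsvWritingTask
# are expected to provide, namely a ``CSV_HEADER`` and a ``process_one``
# coroutine returning one ``(swhid, value)`` row per blob. The class name and
# the exact gRPC call are assumptions; it uses the swh.graph ``Traverse``
# RPC with ``TraversalRequest``/``GraphDirection`` messages to count the
# nodes reachable backward from each blob.
class _CountBlobAncestorsTask(_ConcurrentCsvWritingTask):
    CSV_HEADER = ("swhid", "ancestor_count")

    async def process_one(self, row: Tuple[str, str, str]) -> Tuple[str, str]:
        import swh.graph.grpc.swhgraph_pb2 as swhgraph

        (swhid, _sha1, _name) = row
        count = 0
        # Traverse is a server-streaming RPC; with grpc.aio the call object
        # is an async iterator over the reachable nodes.
        async for _node in self.stub.Traverse(
            swhgraph.TraversalRequest(
                src=[swhid], direction=swhgraph.GraphDirection.BACKWARD
            )
        ):
            count += 1
        return (swhid, str(count))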