Paste P1545

stuck async code

Authored by vlorentz on Jan 4 2023, 10:00 AM.

class _ConcurrentCsvWritingTask(_BaseTask):
    """Base class for tasks writing a CSV using asyncio.

    asyncio is only used for gRPC requests to swh-graph; file writes are
    synchronous to keep the code simpler, as performance improvements from
    making them async would be negligible."""

    CSV_HEADER: Tuple[str, str]

    blob_filter = luigi.ChoiceParameter(choices=list(SELECTION_QUERIES))
    derived_datasets_path = luigi.PathParameter()
    graph_api = luigi.Parameter("localhost:50091")

    stub: "TraversalServiceStub"

    def requires(self) -> luigi.Task:
        """Returns an instance of :class:`SelectBlobs`"""
        return SelectBlobs(
            blob_filter=self.blob_filter,
            derived_datasets_path=self.derived_datasets_path,
        )

    def run(self) -> None:
        """Calls the :meth:`process_one` function, and writes its results as
        a two-column CSV to the target defined by :meth:`output`.
        """
        import asyncio

        asyncio.run(self._run_async())
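
    # Overview: _run_async wires a three-stage pipeline over two bounded
    # queues: _fill_input_queue produces rows, GRAPH_REQUEST_CONCURRENCY
    # worker tasks run process_one on them, and _write_results consumes the
    # results. maxsize=20 gives backpressure so no stage runs far ahead of
    # the others.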

    async def _run_async(self) -> None:
        import asyncio

        import grpc.aio

        import swh.graph.grpc.swhgraph_pb2_grpc as swhgraph_grpc

        input_queue: asyncio.Queue[Tuple[str, str, str]] = asyncio.Queue(maxsize=20)
        result_queue: asyncio.Queue[Tuple[str, str]] = asyncio.Queue(maxsize=20)

        async with grpc.aio.insecure_channel(self.graph_api) as channel:
            self.stub = swhgraph_grpc.TraversalServiceStub(channel)

            fill_queue_task = asyncio.create_task(self._fill_input_queue(input_queue))
            write_task = asyncio.create_task(self._write_results(result_queue))
            worker_tasks = [
                asyncio.create_task(self._worker(input_queue, result_queue))
                for _ in range(GRAPH_REQUEST_CONCURRENCY)
            ]

            print("await 1")
            await write_task  # wait for workers to write everything
            print("await 2")
            await fill_queue_task  # should be instant
            print("cancelling")
            for task in worker_tasks:
                task.cancel()
            print("gathering")
            await asyncio.gather(
                fill_queue_task,
                write_task,
                *worker_tasks,
                return_exceptions=True,
            )
            print("done")

    async def _fill_input_queue(
        self, input_queue: "asyncio.Queue[Tuple[str, str, str]]"
    ) -> None:
        for (swhid, sha1, name) in self.iter_blobs(with_tqdm=False, unique_sha1=True):
            print("got", swhid, sha1, name)
            if not swhid.startswith("swh:1:"):
                print("failed")
                raise ValueError(f"Invalid SWHID: {swhid}")
            print("putting")
            await input_queue.put((swhid, sha1, name))
            print("end loop")
        print("============== done filling")

    async def _worker(
        self,
        input_queue: "asyncio.Queue[Tuple[str, str, str]]",
        result_queue: "asyncio.Queue[Tuple[str, str]]",
    ) -> None:
        while True:  # exit via Task.cancel()
            row = await input_queue.get()
            (swhid, _sha1, _name) = row  # bind swhid for the error fallback below
            try:
                res = await self.process_one(row)
            except Exception:
                # catch only Exception, not BaseException: workers are
                # stopped via Task.cancel(), and a bare `except:` would
                # swallow the CancelledError
                res = (swhid, "")
                logger.exception("Error while processing %r", row)
            await result_queue.put(res)
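
    # sink: pops one result per expected blob and appends it to the
    # zstd-compressed CSV, with an async tqdm progress bar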

    async def _write_results(
        self, result_queue: "asyncio.Queue[Tuple[str, str]]"
    ) -> None:
        import tqdm.asyncio

        (target,) = self.output()
        result_path = Path(target.path)

        with atomic_csv_zstd_writer(result_path) as writer:
            writer.writerow(self.CSV_HEADER)

            assert (
                len(list(self.iter_blobs(with_tqdm=False, unique_sha1=True)))
                == self.blob_count()
            )
            assert self.blob_count() == len(list(tqdm.trange(self.blob_count())))

            blob_count = self.blob_count()
            async for i in tqdm.asyncio.trange(self.blob_count()):
                print(f"{i+1}/{blob_count}")
                (swhid, result) = await result_queue.get()
                print("got result", swhid, result)
                writer.writerow((swhid, result))
                print("wrote row")
            print("end for")