diff --git a/swh/objstorage/backends/azure.py b/swh/objstorage/backends/azure.py --- a/swh/objstorage/backends/azure.py +++ b/swh/objstorage/backends/azure.py @@ -1,8 +1,10 @@ -# Copyright (C) 2016-2020 The Software Heritage developers +# Copyright (C) 2016-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import asyncio +import contextlib import datetime from itertools import product import string @@ -15,6 +17,7 @@ ContainerSasPermissions, generate_container_sas, ) +from azure.storage.blob.aio import ContainerClient as AsyncContainerClient from swh.model import hashutil from swh.objstorage.exc import Error, ObjNotFoundError @@ -134,12 +137,28 @@ """ return ContainerClient.from_container_url(self.container_url) + @contextlib.asynccontextmanager + async def get_async_container_clients(self): + """Async context manager yielding a dict of container clients, to be + passed to ``get_async_blob_client``.
+ + Each client must not be used in more than one asyncio event loop.""" + client = AsyncContainerClient.from_container_url(self.container_url) + async with client: + yield {"": client} + def get_blob_client(self, hex_obj_id): """Get the azure blob client for the given hex obj id""" container_client = self.get_container_client(hex_obj_id) return container_client.get_blob_client(blob=hex_obj_id) + def get_async_blob_client(self, hex_obj_id, container_clients): + """Get the azure blob client for the given hex obj id and a collection + yielded by ``get_async_container_clients``.""" + + return container_clients[""].get_blob_client(blob=hex_obj_id) + def get_all_container_clients(self): """Get all active block_blob_services""" yield self.get_container_client("") @@ -254,6 +273,52 @@ raise Error("Corrupt object %s: trailing data found" % hex_obj_id) return ret + async def _get_async(self, obj_id, container_clients): + """Like ``get(obj_id)``, but asynchronous""" + hex_obj_id = self._internal_id(obj_id) + client = self.get_async_blob_client(hex_obj_id, container_clients) + + try: + download = await client.download_blob() + except ResourceNotFoundError: + raise ObjNotFoundError(obj_id) from None + else: + data = await download.content_as_bytes() + + decompressor = decompressors[self.compression]() + ret = decompressor.decompress(data) + if decompressor.unused_data: + raise Error("Corrupt object %s: trailing data found" % hex_obj_id) + return ret + + async def _get_async_or_none(self, obj_id, container_clients): + """Like ``_get_async``, but returns None instead of raising + ObjNotFoundError.
Used by ``get_batch`` so other blobs can be returned + even if one is missing.""" + try: + return await self._get_async(obj_id, container_clients) + except ObjNotFoundError: + return None + + async def _get_batch_async(self, obj_ids): + async with self.get_async_container_clients() as container_clients: + return await asyncio.gather(  # one result per obj_id, in order + *[ + self._get_async_or_none(obj_id, container_clients) + for obj_id in obj_ids + ] + ) + + def get_batch(self, obj_ids): + """Concurrently retrieve objects' raw content; None for missing objects.""" + loop = asyncio.new_event_loop()  # async clients are bound to one loop + try: + result = loop.run_until_complete(self._get_batch_async(obj_ids)) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + return result + def check(self, obj_id): """Check the content integrity. @@ -344,6 +409,30 @@ prefix = hex_obj_id[: self.prefix_len] return ContainerClient.from_container_url(self.container_urls[prefix]) + @contextlib.asynccontextmanager + async def get_async_container_clients(self): + # This is equivalent to: + # client1 = AsyncContainerClient.from_container_url(url1) + # ...
+ # client16 = AsyncContainerClient.from_container_url(url16) + # async with client1, ..., client16: + # yield {prefix1: client1, ..., prefix16: client16} + clients = { + prefix: AsyncContainerClient.from_container_url(url) + for (prefix, url) in self.container_urls.items() + } + async with contextlib.AsyncExitStack() as stack: + for client in clients.values(): + await stack.enter_async_context(client) + yield clients + + def get_async_blob_client(self, hex_obj_id, container_clients): + """Get the azure blob client for the given hex obj id and a collection + yielded by ``get_async_container_clients``.""" + + prefix = hex_obj_id[: self.prefix_len] + return container_clients[prefix].get_blob_client(blob=hex_obj_id) + def get_all_container_clients(self): """Get all active container clients""" # iterate on items() to sort blob services; diff --git a/swh/objstorage/tests/test_objstorage_azure.py b/swh/objstorage/tests/test_objstorage_azure.py --- a/swh/objstorage/tests/test_objstorage_azure.py +++ b/swh/objstorage/tests/test_objstorage_azure.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import asyncio import base64 import collections from dataclasses import dataclass @@ -27,6 +28,16 @@ name: str +class MockAsyncDownloadClient: + def __init__(self, blob_data): + self.blob_data = blob_data + + def content_as_bytes(self): + future = asyncio.Future() + future.set_result(self.blob_data) + return future + + class MockDownloadClient: def __init__(self, blob_data): self.blob_data = blob_data @@ -34,6 +45,10 @@ def content_as_bytes(self): return self.blob_data + def __await__(self): + yield from () + return MockAsyncDownloadClient(self.blob_data) + class MockBlobClient: def __init__(self, container, blob): @@ -93,6 +108,17 @@ def delete_blob(self, blob): self.get_blob_client(blob.name).delete_blob() + async def __aenter__(self): + return self + + def __await__(self): + future =
asyncio.Future() + future.set_result(self) + return (yield from future) + + async def __aexit__(self, *args): + return False  # do not suppress exceptions + return MockContainerClient @@ -108,6 +134,12 @@ patcher.start() self.addCleanup(patcher.stop) + patcher = patch( + "swh.objstorage.backends.azure.AsyncContainerClient", ContainerClient + ) + patcher.start() + self.addCleanup(patcher.stop) + self.storage = get_objstorage( "azure", { @@ -175,6 +207,12 @@ patcher.start() self.addCleanup(patcher.stop) + patcher = patch( + "swh.objstorage.backends.azure.AsyncContainerClient", self.ContainerClient + ) + patcher.start() + self.addCleanup(patcher.stop) + self.accounts = {} for prefix in "0123456789abcdef": self.accounts[prefix] = "https://bogus-container-url.example/" + prefix