diff --git a/swh/objstorage/backends/azure.py b/swh/objstorage/backends/azure.py --- a/swh/objstorage/backends/azure.py +++ b/swh/objstorage/backends/azure.py @@ -38,6 +38,9 @@ container_name: str, access_policy: str = "read_only", expiry: datetime.timedelta = datetime.timedelta(days=365), + container_url_template: str = ( + "https://{account_name}.blob.core.windows.net/{container_name}?{signature}" + ), **kwargs, ) -> str: """Get the full url, for the given container on the given account, with a @@ -76,7 +79,11 @@ expiry=current_time + expiry, ) - return f"https://{account_name}.blob.core.windows.net/{container_name}?{signature}" + return container_url_template.format( + account_name=account_name, + container_name=container_name, + signature=signature, + ) class AzureCloudObjStorage(ObjStorage): @@ -106,14 +113,15 @@ account_name: Optional[str] = None, api_secret_key: Optional[str] = None, container_name: Optional[str] = None, + connection_string: Optional[str] = None, compression="gzip", **kwargs, ): - if container_url is None: + if container_url is None and connection_string is None: if account_name is None or api_secret_key is None or container_name is None: raise ValueError( - "AzureCloudObjStorage must have a container_url or all three " - "account_name, api_secret_key and container_name" + "AzureCloudObjStorage must have a container_url, a connection_string, " + "or all three account_name, api_secret_key and container_name" ) else: warnings.warn( @@ -127,9 +135,16 @@ container_name=container_name, access_policy="full", ) + elif connection_string: + if container_name is None: + raise ValueError( + "container_name is required when using connection_string." + ) + self.container_name = container_name super().__init__(**kwargs) self.container_url = container_url + self.connection_string = connection_string self.compression = compression def get_container_client(self, hex_obj_id): @@ -140,7 +155,12 @@ client according to the prefix of the object id.
""" - return ContainerClient.from_container_url(self.container_url) + if self.connection_string: + return ContainerClient.from_connection_string( + self.connection_string, self.container_name + ) + else: + return ContainerClient.from_container_url(self.container_url) @contextlib.asynccontextmanager async def get_async_container_clients(self): @@ -148,7 +168,12 @@ ``get_async_blob_client``. Each container may not be used in more than one asyncio loop.""" - client = AsyncContainerClient.from_container_url(self.container_url) + if self.connection_string: + client = AsyncContainerClient.from_connection_string( + self.connection_string, self.container_name + ) + else: + client = AsyncContainerClient.from_container_url(self.container_url) async with client: yield {"": client} diff --git a/swh/objstorage/tests/test_objstorage_azure.py b/swh/objstorage/tests/test_objstorage_azure.py --- a/swh/objstorage/tests/test_objstorage_azure.py +++ b/swh/objstorage/tests/test_objstorage_azure.py @@ -1,4 +1,4 @@ -# Copyright (C) 2016-2021 The Software Heritage developers +# Copyright (C) 2016-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -7,11 +7,17 @@ import base64 import collections from dataclasses import dataclass +import os +import secrets +import shutil +import subprocess +import tempfile import unittest from unittest.mock import patch from urllib.parse import parse_qs, urlparse from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError +from azure.storage.blob import BlobServiceClient import pytest from swh.model.hashutil import hash_to_hex @@ -22,6 +28,10 @@ from .objstorage_testing import ObjStorageTestFixture +AZURITE_EXE = shutil.which( + "azurite-blob", path=os.environ.get("AZURITE_PATH", os.environ.get("PATH")) +) + @dataclass class MockListedObject: @@ -83,6 +93,96 @@
del self.container.blobs[self.blob] +@pytest.mark.skipif(not AZURITE_EXE, reason="azurite not found in AZURITE_PATH or PATH") +class TestAzuriteCloudObjStorage(ObjStorageTestFixture, unittest.TestCase): + """Run the objstorage test fixture against a local Azurite Blob emulator. + + One Azurite process is started per test class (see ``setUpClass``), and + each test uses a fresh, randomly-named container (see ``setUp``). + """ + + compression = "none" + + @classmethod + def setUpClass(cls): + super().setUpClass() + + host = "127.0.0.1" + + cls._azurite_path = tempfile.mkdtemp() + + cls._azurite_proc = subprocess.Popen( + [ + AZURITE_EXE, + "--blobHost", + host, + "--blobPort", + "0", + # Disable the access log: stdout is a pipe that is no longer + # read once the port is known, so logging every request would + # eventually fill the pipe buffer and stall Azurite. + "--silent", + ], + stdout=subprocess.PIPE, + cwd=cls._azurite_path, + ) + + prefix = b"Azurite Blob service successfully listens on " + for line in cls._azurite_proc.stdout: + if line.startswith(prefix): + base_url = line[len(prefix) :].decode().strip() + break + else: + # Azurite closed stdout (exited) before announcing its address; + # tearDownClass is not run when setUpClass fails, so clean up here. + cls._azurite_proc.kill() + cls._azurite_proc.wait() + cls._azurite_proc.stdout.close() + shutil.rmtree(cls._azurite_path) + raise RuntimeError("Did not get Azurite Blob service port.") + + # https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azurite#well-known-storage-account-and-key + account_name = "devstoreaccount1" + account_key = ( + "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq" + "/K1SZFPTOtr/KBHBeksoGMGw==" + ) + + blob_endpoint = f"{base_url}/{account_name}" + cls._connection_string = ( + f"DefaultEndpointsProtocol={urlparse(base_url).scheme};" + f"AccountName={account_name};" + f"AccountKey={account_key};" + f"BlobEndpoint={blob_endpoint};" + ) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._azurite_proc.kill() + cls._azurite_proc.wait(2) + cls._azurite_proc.stdout.close() + shutil.rmtree(cls._azurite_path) + + def setUp(self): + super().setUp() + self._container_name = secrets.token_hex(10) + client = BlobServiceClient.from_connection_string(self._connection_string) + with client: + client.create_container(self._container_name) + + self.storage = get_objstorage( + "azure", + connection_string=self._connection_string, + container_name=self._container_name, + compression=self.compression, + ) + + +class TestAzuriteCloudObjStorageGzip(TestAzuriteCloudObjStorage): + compression = "gzip" + + def get_MockContainerClient(): blobs = collections.defaultdict(dict) # {container_url: {blob_id:
blob}} @@ -122,7 +222,7 @@ return MockContainerClient -class TestAzureCloudObjStorage(ObjStorageTestFixture, unittest.TestCase): +class TestMockedAzureCloudObjStorage(ObjStorageTestFixture, unittest.TestCase): compression = "none" def setUp(self): @@ -179,19 +279,19 @@ assert "trailing data" in e.exception.args[0] -class TestAzureCloudObjStorageGzip(TestAzureCloudObjStorage): +class TestMockedAzureCloudObjStorageGzip(TestMockedAzureCloudObjStorage): compression = "gzip" -class TestAzureCloudObjStorageZlib(TestAzureCloudObjStorage): +class TestMockedAzureCloudObjStorageZlib(TestMockedAzureCloudObjStorage): compression = "zlib" -class TestAzureCloudObjStorageLzma(TestAzureCloudObjStorage): +class TestMockedAzureCloudObjStorageLzma(TestMockedAzureCloudObjStorage): compression = "lzma" -class TestAzureCloudObjStorageBz2(TestAzureCloudObjStorage): +class TestMockedAzureCloudObjStorageBz2(TestMockedAzureCloudObjStorage): compression = "bz2" diff --git a/tox.ini b/tox.ini --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,8 @@ [testenv] extras = testing +passenv = + AZURITE_PATH deps = pytest-cov dev: pdbpp