diff --git a/swh/objstorage/backends/azure.py b/swh/objstorage/backends/azure.py --- a/swh/objstorage/backends/azure.py +++ b/swh/objstorage/backends/azure.py @@ -121,7 +121,7 @@ ) super().__init__(**kwargs) - self.container_client = ContainerClient.from_container_url(container_url) + self.container_url = container_url self.compression = compression def get_container_client(self, hex_obj_id): @@ -132,7 +132,7 @@ client according to the prefix of the object id. """ - return self.container_client + return ContainerClient.from_container_url(self.container_url) def get_blob_client(self, hex_obj_id): """Get the azure blob client for the given hex obj id""" @@ -142,7 +142,7 @@ def get_all_container_clients(self): """Get all active block_blob_services""" - yield self.container_client + yield self.get_container_client("") def _internal_id(self, obj_id): """Internal id is the hex version in objstorage. @@ -318,7 +318,7 @@ do_warning = False - self.prefixes = {} + self.container_urls = {} for prefix, container_url in accounts.items(): if isinstance(container_url, dict): do_warning = True @@ -328,7 +328,7 @@ container_name=container_url["container_name"], access_policy="full", ) - self.prefixes[prefix] = ContainerClient.from_container_url(container_url) + self.container_urls[prefix] = container_url if do_warning: warnings.warn( @@ -341,10 +341,13 @@ """Get the block_blob_service and container that contains the object with internal id hex_obj_id """ - return self.prefixes[hex_obj_id[: self.prefix_len]] + prefix = hex_obj_id[: self.prefix_len] + return ContainerClient.from_container_url(self.container_urls[prefix]) def get_all_container_clients(self): """Get all active container clients""" # iterate on items() to sort blob services; # needed to be able to paginate in the list_content() method - yield from (v for _, v in sorted(self.prefixes.items())) + yield from ( + self.get_container_client(prefix) for prefix in sorted(self.container_urls) + ) diff --git a/swh/objstorage/tests/test_objstorage_azure.py b/swh/objstorage/tests/test_objstorage_azure.py --- a/swh/objstorage/tests/test_objstorage_azure.py +++ b/swh/objstorage/tests/test_objstorage_azure.py @@ -1,9 +1,10 @@ -# Copyright (C) 2016-2020 The Software Heritage developers +# Copyright (C) 2016-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 +import collections from dataclasses import dataclass import unittest from unittest.mock import patch @@ -67,27 +68,32 @@ del self.container.blobs[self.blob] -class MockContainerClient: - def __init__(self, container_url): - self.container_url = container_url - self.blobs = {} +def get_MockContainerClient(): + blobs = collections.defaultdict(dict) # {container_url: {blob_id: blob}} - @classmethod - def from_container_url(cls, container_url): - return cls(container_url) + class MockContainerClient: + def __init__(self, container_url): + self.container_url = container_url + self.blobs = blobs[self.container_url] - def get_container_properties(self): - return {"exists": True} + @classmethod + def from_container_url(cls, container_url): + return cls(container_url) + + def get_container_properties(self): + return {"exists": True} + + def get_blob_client(self, blob): + return MockBlobClient(self, blob) - def get_blob_client(self, blob): - return MockBlobClient(self, blob) + def list_blobs(self): + for obj in sorted(self.blobs): + yield MockListedObject(obj) - def list_blobs(self): - for obj in sorted(self.blobs): - yield MockListedObject(obj) + def delete_blob(self, blob): + self.get_blob_client(blob.name).delete_blob() - def delete_blob(self, blob): - self.get_blob_client(blob.name).delete_blob() + return MockContainerClient class TestAzureCloudObjStorage(ObjStorageTestFixture, unittest.TestCase): @@ -95,8 +101,9 @@ def setUp(self): super().setUp() + ContainerClient = get_MockContainerClient() patcher = patch( - "swh.objstorage.backends.azure.ContainerClient", MockContainerClient, + "swh.objstorage.backends.azure.ContainerClient", ContainerClient ) patcher.start() self.addCleanup(patcher.stop) @@ -161,8 +168,9 @@ class TestPrefixedAzureCloudObjStorage(ObjStorageTestFixture, unittest.TestCase): def setUp(self): super().setUp() + self.ContainerClient = get_MockContainerClient() patcher = patch( - "swh.objstorage.backends.azure.ContainerClient", MockContainerClient + "swh.objstorage.backends.azure.ContainerClient", self.ContainerClient ) patcher.start() self.addCleanup(patcher.stop) @@ -193,7 +201,7 @@ hex_obj_id = hash_to_hex(obj_id) prefix = hex_obj_id[0] self.assertTrue( - self.storage.prefixes[prefix] + self.ContainerClient(self.storage.container_urls[prefix]) .get_blob_client(hex_obj_id) .get_blob_properties() ) @@ -231,7 +239,7 @@ def test_bwcompat_args(monkeypatch): monkeypatch.setattr( - swh.objstorage.backends.azure, "ContainerClient", MockContainerClient, + swh.objstorage.backends.azure, "ContainerClient", get_MockContainerClient(), ) with pytest.deprecated_call(): @@ -249,7 +257,7 @@ def test_bwcompat_args_prefixed(monkeypatch): monkeypatch.setattr( - swh.objstorage.backends.azure, "ContainerClient", MockContainerClient, + swh.objstorage.backends.azure, "ContainerClient", get_MockContainerClient(), ) accounts = {