Changeset View
Changeset View
Standalone View
Standalone View
swh/objstorage/backends/azure.py
# Copyright (C) 2016-2020 The Software Heritage developers | # Copyright (C) 2016-2020 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
import logging | |||||
import string | import string | ||||
from itertools import dropwhile, islice, product | from itertools import product | ||||
from azure.storage.blob import BlockBlobService | from azure.storage.blob import ContainerClient | ||||
from azure.common import AzureMissingResourceHttpError | from azure.core.exceptions import ResourceNotFoundError | ||||
import requests | |||||
from swh.objstorage.objstorage import ( | from swh.objstorage.objstorage import ( | ||||
ObjStorage, | ObjStorage, | ||||
compute_hash, | compute_hash, | ||||
DEFAULT_LIMIT, | |||||
compressors, | compressors, | ||||
decompressors, | decompressors, | ||||
) | ) | ||||
from swh.objstorage.exc import ObjNotFoundError, Error | from swh.objstorage.exc import ObjNotFoundError, Error | ||||
from swh.model import hashutil | from swh.model import hashutil | ||||
logging.getLogger("azure.storage").setLevel(logging.CRITICAL) | |||||
class AzureCloudObjStorage(ObjStorage): | class AzureCloudObjStorage(ObjStorage): | ||||
"""ObjStorage with azure abilities. | """ObjStorage with azure abilities. | ||||
""" | """ | ||||
def __init__( | def __init__(self, container_url, compression="gzip", **kwargs): | ||||
self, account_name, api_secret_key, container_name, compression="gzip", **kwargs | |||||
): | |||||
super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
self.block_blob_service = BlockBlobService( | self.container_client = ContainerClient.from_container_url(container_url) | ||||
account_name=account_name, | |||||
account_key=api_secret_key, | |||||
request_session=requests.Session(), | |||||
) | |||||
self.container_name = container_name | |||||
self.compression = compression | self.compression = compression | ||||
def get_blob_service(self, hex_obj_id): | def get_container_client(self, hex_obj_id): | ||||
"""Get the block_blob_service and container that contains the object with | """Get the container client for the container that contains the object with | ||||
internal id hex_obj_id | internal id hex_obj_id | ||||
""" | """ | ||||
return self.block_blob_service, self.container_name | return self.container_client | ||||
douardda: this method looks weird as is. Why does it exists? (not read the whole file yet, but my bet is… | |||||
def get_blob_client(self, hex_obj_id): | |||||
"""Get the azure blob client for the given hex obj id""" | |||||
container_client = self.get_container_client(hex_obj_id) | |||||
def get_all_blob_services(self): | return container_client.get_blob_client(blob=hex_obj_id) | ||||
def get_all_container_clients(self): | |||||
"""Get all active block_blob_services""" | """Get all active block_blob_services""" | ||||
yield self.block_blob_service, self.container_name | yield self.container_client | ||||
def _internal_id(self, obj_id): | def _internal_id(self, obj_id): | ||||
"""Internal id is the hex version in objstorage. | """Internal id is the hex version in objstorage. | ||||
""" | """ | ||||
return hashutil.hash_to_hex(obj_id) | return hashutil.hash_to_hex(obj_id) | ||||
def check_config(self, *, check_write): | def check_config(self, *, check_write): | ||||
"""Check the configuration for this object storage""" | """Check the configuration for this object storage""" | ||||
for service, container in self.get_all_blob_services(): | for container_client in self.get_all_container_clients(): | ||||
props = service.get_container_properties(container) | props = container_client.get_container_properties() | ||||
# FIXME: check_write is ignored here | # FIXME: check_write is ignored here | ||||
if not props: | if not props: | ||||
return False | return False | ||||
return True | return True | ||||
def __contains__(self, obj_id): | def __contains__(self, obj_id): | ||||
"""Does the storage contains the obj_id. | """Does the storage contains the obj_id. | ||||
""" | """ | ||||
hex_obj_id = self._internal_id(obj_id) | hex_obj_id = self._internal_id(obj_id) | ||||
service, container = self.get_blob_service(hex_obj_id) | client = self.get_blob_client(hex_obj_id) | ||||
return service.exists(container_name=container, blob_name=hex_obj_id) | try: | ||||
client.get_blob_properties() | |||||
except ResourceNotFoundError: | |||||
return False | |||||
else: | |||||
return True | |||||
def __iter__(self): | def __iter__(self): | ||||
"""Iterate over the objects present in the storage. | """Iterate over the objects present in the storage. | ||||
""" | """ | ||||
for service, container in self.get_all_blob_services(): | for client in self.get_all_container_clients(): | ||||
for obj in service.list_blobs(container): | for obj in client.list_blobs(): | ||||
yield hashutil.hash_to_bytes(obj.name) | yield hashutil.hash_to_bytes(obj.name) | ||||
def __len__(self): | def __len__(self): | ||||
"""Compute the number of objects in the current object storage. | """Compute the number of objects in the current object storage. | ||||
Returns: | Returns: | ||||
number of objects contained in the storage. | number of objects contained in the storage. | ||||
Show All 10 Lines | def add(self, content, obj_id=None, check_presence=True): | ||||
if check_presence and obj_id in self: | if check_presence and obj_id in self: | ||||
return obj_id | return obj_id | ||||
hex_obj_id = self._internal_id(obj_id) | hex_obj_id = self._internal_id(obj_id) | ||||
# Send the compressed content | # Send the compressed content | ||||
compressor = compressors[self.compression]() | compressor = compressors[self.compression]() | ||||
blob = [compressor.compress(content), compressor.flush()] | data = compressor.compress(content) | ||||
data += compressor.flush() | |||||
service, container = self.get_blob_service(hex_obj_id) | client = self.get_blob_client(hex_obj_id) | ||||
service.create_blob_from_bytes( | client.upload_blob(data=data, length=len(data)) | ||||
container_name=container, blob_name=hex_obj_id, blob=b"".join(blob), | |||||
) | |||||
return obj_id | return obj_id | ||||
def restore(self, content, obj_id=None): | def restore(self, content, obj_id=None): | ||||
"""Restore a content. | """Restore a content. | ||||
""" | """ | ||||
if obj_id is None: | |||||
# Checksum is missing, compute it on the fly. | |||||
obj_id = compute_hash(content) | |||||
if obj_id in self: | |||||
self.delete(obj_id) | |||||
return self.add(content, obj_id, check_presence=False) | return self.add(content, obj_id, check_presence=False) | ||||
def get(self, obj_id): | def get(self, obj_id): | ||||
"""Retrieve blob's content if found. | """Retrieve blob's content if found. | ||||
""" | """ | ||||
hex_obj_id = self._internal_id(obj_id) | hex_obj_id = self._internal_id(obj_id) | ||||
service, container = self.get_blob_service(hex_obj_id) | client = self.get_blob_client(hex_obj_id) | ||||
try: | try: | ||||
blob = service.get_blob_to_bytes( | download = client.download_blob() | ||||
container_name=container, blob_name=hex_obj_id | except ResourceNotFoundError: | ||||
) | raise ObjNotFoundError(obj_id) from None | ||||
except AzureMissingResourceHttpError: | else: | ||||
raise ObjNotFoundError(obj_id) | data = download.content_as_bytes() | ||||
decompressor = decompressors[self.compression]() | decompressor = decompressors[self.compression]() | ||||
ret = decompressor.decompress(blob.content) | ret = decompressor.decompress(data) | ||||
if decompressor.unused_data: | if decompressor.unused_data: | ||||
raise Error("Corrupt object %s: trailing data found" % hex_obj_id) | raise Error("Corrupt object %s: trailing data found" % hex_obj_id) | ||||
return ret | return ret | ||||
def check(self, obj_id): | def check(self, obj_id): | ||||
"""Check the content integrity. | """Check the content integrity. | ||||
""" | """ | ||||
obj_content = self.get(obj_id) | obj_content = self.get(obj_id) | ||||
content_obj_id = compute_hash(obj_content) | content_obj_id = compute_hash(obj_content) | ||||
if content_obj_id != obj_id: | if content_obj_id != obj_id: | ||||
raise Error(obj_id) | raise Error(obj_id) | ||||
def delete(self, obj_id): | def delete(self, obj_id): | ||||
"""Delete an object.""" | """Delete an object.""" | ||||
super().delete(obj_id) # Check delete permission | super().delete(obj_id) # Check delete permission | ||||
hex_obj_id = self._internal_id(obj_id) | hex_obj_id = self._internal_id(obj_id) | ||||
service, container = self.get_blob_service(hex_obj_id) | client = self.get_blob_client(hex_obj_id) | ||||
try: | try: | ||||
service.delete_blob(container_name=container, blob_name=hex_obj_id) | client.delete_blob() | ||||
except AzureMissingResourceHttpError: | except ResourceNotFoundError: | ||||
raise ObjNotFoundError("Content {} not found!".format(hex_obj_id)) | raise ObjNotFoundError(obj_id) from None | ||||
return True | return True | ||||
def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT): | |||||
all_blob_services = self.get_all_blob_services() | |||||
if last_obj_id: | |||||
last_obj_id = self._internal_id(last_obj_id) | |||||
last_service, _ = self.get_blob_service(last_obj_id) | |||||
all_blob_services = dropwhile( | |||||
lambda srv: srv[0] != last_service, all_blob_services | |||||
) | |||||
else: | |||||
last_service = None | |||||
def iterate_blobs(): | |||||
for service, container in all_blob_services: | |||||
marker = last_obj_id if service == last_service else None | |||||
for obj in service.list_blobs( | |||||
container, marker=marker, maxresults=limit | |||||
): | |||||
yield hashutil.hash_to_bytes(obj.name) | |||||
return islice(iterate_blobs(), limit) | |||||
class PrefixedAzureCloudObjStorage(AzureCloudObjStorage): | class PrefixedAzureCloudObjStorage(AzureCloudObjStorage): | ||||
"""ObjStorage with azure capabilities, striped by prefix. | """ObjStorage with azure capabilities, striped by prefix. | ||||
accounts is a dict containing entries of the form: | accounts is a dict containing entries of the form: | ||||
<prefix>: | <prefix>: <container_url_for_prefix> | ||||
account_name: <account_name> | |||||
api_secret_key: <api_secret_key> | |||||
container_name: <container_name> | |||||
""" | """ | ||||
def __init__(self, accounts, compression="gzip", **kwargs): | def __init__(self, accounts, compression="gzip", **kwargs): | ||||
# shortcut AzureCloudObjStorage __init__ | # shortcut AzureCloudObjStorage __init__ | ||||
ObjStorage.__init__(self, **kwargs) | ObjStorage.__init__(self, **kwargs) | ||||
self.compression = compression | self.compression = compression | ||||
Show All 15 Lines | def __init__(self, accounts, compression="gzip", **kwargs): | ||||
) | ) | ||||
missing_prefixes = expected_prefixes - set(accounts) | missing_prefixes = expected_prefixes - set(accounts) | ||||
if missing_prefixes: | if missing_prefixes: | ||||
raise ValueError( | raise ValueError( | ||||
"Missing prefixes %s" % ", ".join(sorted(missing_prefixes)) | "Missing prefixes %s" % ", ".join(sorted(missing_prefixes)) | ||||
) | ) | ||||
self.prefixes = {} | self.prefixes = {} | ||||
request_session = requests.Session() | for prefix, container_url in accounts.items(): | ||||
for prefix, account in accounts.items(): | self.prefixes[prefix] = ContainerClient.from_container_url(container_url) | ||||
self.prefixes[prefix] = ( | |||||
BlockBlobService( | |||||
account_name=account["account_name"], | |||||
account_key=account["api_secret_key"], | |||||
request_session=request_session, | |||||
), | |||||
account["container_name"], | |||||
) | |||||
def get_blob_service(self, hex_obj_id): | def get_container_client(self, hex_obj_id): | ||||
"""Get the block_blob_service and container that contains the object with | """Get the block_blob_service and container that contains the object with | ||||
internal id hex_obj_id | internal id hex_obj_id | ||||
""" | """ | ||||
return self.prefixes[hex_obj_id[: self.prefix_len]] | return self.prefixes[hex_obj_id[: self.prefix_len]] | ||||
def get_all_blob_services(self): | def get_all_container_clients(self): | ||||
"""Get all active block_blob_services""" | """Get all active container clients""" | ||||
# iterate on items() to sort blob services; | # iterate on items() to sort blob services; | ||||
# needed to be able to paginate in the list_content() method | # needed to be able to paginate in the list_content() method | ||||
yield from (v for _, v in sorted(self.prefixes.items())) | yield from (v for _, v in sorted(self.prefixes.items())) |
this method looks weird as is. Why does it exists? (not read the whole file yet, but my bet is a dispatching mechanism implemented in a subclass.
I've noted this was already like this before, nonetheless, it could benefit from a comment or something IMHO.