diff --git a/swh/objstorage/cloud/objstorage_cloud.py b/swh/objstorage/cloud/objstorage_cloud.py
--- a/swh/objstorage/cloud/objstorage_cloud.py
+++ b/swh/objstorage/cloud/objstorage_cloud.py
@@ -4,9 +4,11 @@
 # See top-level LICENSE file for more information
 
 import abc
+import collections.abc
 
 from swh.model import hashutil
 from swh.objstorage.objstorage import ObjStorage, compute_hash
+from swh.objstorage.objstorage import compressors, decompressors
 from swh.objstorage.exc import ObjNotFoundError, Error
 
 from libcloud.storage import providers
@@ -22,12 +24,18 @@
     https://libcloud.readthedocs.io/en/latest/storage/api.html).
 
     """
-    def __init__(self, container_name, **kwargs):
+    def __init__(self, container_name, compression=None, **kwargs):
+        # Reject unknown algorithms up front: otherwise the mistake only
+        # surfaces as a bare KeyError on the first read or write.
+        if compression not in compressors:
+            raise ValueError(
+                'Unknown compression algorithm: %r' % (compression,))
         super().__init__(**kwargs)
         self.driver = self._get_driver(**kwargs)
         self.container_name = container_name
         self.container = self.driver.get_container(
             container_name=container_name)
+        self.compression = compression
 
     def _get_driver(self, **kwargs):
         """Initialize a driver to communicate with the cloud
@@ -118,7 +126,10 @@
         return self.add(content, obj_id, check_presence=False)
 
     def get(self, obj_id):
-        return bytes(self._get_object(obj_id).as_stream())
+        # as_stream() yields the stored payload as chunks of bytes; join them
+        # so the one-shot decompressor sees the whole object.
+        obj = b''.join(self._get_object(obj_id).as_stream())
+        return decompressors[self.compression](obj)
 
     def check(self, obj_id):
         # Check that the file exists, as _get_object raises ObjNotFoundError
@@ -142,11 +153,29 @@
 
         """
         hex_obj_id = hashutil.hash_to_hex(obj_id)
+
         try:
             return self.driver.get_object(self.container_name, hex_obj_id)
         except ObjectDoesNotExistError:
             raise ObjNotFoundError(obj_id)
 
+    def _compressor(self, data):
+        """Lazily compress an iterable of bytes chunks.
+
+        Yields every non-empty chunk the compressor selected by
+        `self.compression` emits for `data`, then whatever it still
+        holds once flushed.
+
+        """
+        comp = compressors[self.compression]()
+        for chunk in data:
+            cchunk = comp.compress(chunk)
+            if cchunk:
+                yield cchunk
+        trail = comp.flush()
+        if trail:
+            yield trail
+
     def _put_object(self, content, obj_id):
         """Create an object in the cloud storage.
 
@@ -155,11 +184,14 @@
 
         """
         hex_obj_id = hashutil.hash_to_hex(obj_id)
-        self.driver.upload_object_via_stream(iter(content), self.container,
-                                             hex_obj_id)
 
-    def list_content(self):
-        return iter(self)
+        # Anything that is not already an iterator is treated as one single
+        # payload and wrapped, so the compressor always gets chunks.
+        if not isinstance(content, collections.abc.Iterator):
+            content = (content,)
+        self.driver.upload_object_via_stream(
+            self._compressor(content),
+            self.container, hex_obj_id)
 
 
 class AwsCloudObjStorage(CloudObjStorage):
diff --git a/swh/objstorage/objstorage.py b/swh/objstorage/objstorage.py
--- a/swh/objstorage/objstorage.py
+++ b/swh/objstorage/objstorage.py
@@ -5,6 +5,9 @@
 
 import abc
 from itertools import dropwhile, islice
+import bz2
+import lzma
+import zlib
 
 from swh.model import hashutil
 
@@ -34,6 +37,35 @@
     ).digest().get(ID_HASH_ALGO)
 
 
+class NullCompressor:
+    """Pass-through stand-in exposing the compress()/flush() interface."""
+
+    def compress(self, data):
+        return data
+
+    def flush(self):
+        return b''
+
+
+# One-shot decompression function for each supported algorithm name;
+# None maps to the identity function (data is returned unchanged).
+decompressors = {
+    'bz2': bz2.decompress,
+    'lzma': lzma.decompress,
+    'zlib': zlib.decompress,
+    None: lambda x: x,
+    }
+
+# Factory of compressor objects (compress()/flush()) for each supported
+# algorithm name; keep its keys in sync with `decompressors`.
+compressors = {
+    'bz2': bz2.BZ2Compressor,
+    'lzma': lzma.LZMACompressor,
+    'zlib': zlib.compressobj,
+    None: NullCompressor,
+    }
+
+
 class ObjStorage(metaclass=abc.ABCMeta):
     """ High-level API to manipulate the Software Heritage object storage.
 
diff --git a/swh/objstorage/tests/test_objstorage_cloud.py b/swh/objstorage/tests/test_objstorage_cloud.py
--- a/swh/objstorage/tests/test_objstorage_cloud.py
+++ b/swh/objstorage/tests/test_objstorage_cloud.py
@@ -4,10 +4,14 @@
 # See top-level LICENSE file for more information
 
 import unittest
+import bz2
+import lzma
+import zlib
 
 from libcloud.common.types import InvalidCredsError
 from libcloud.storage.types import (ContainerDoesNotExistError,
                                     ObjectDoesNotExistError)
+from swh.model import hashutil
 from swh.objstorage.cloud.objstorage_cloud import CloudObjStorage
 
 from .objstorage_testing import ObjStorageTestFixture
@@ -96,3 +100,57 @@
             CONTAINER_NAME,
             api_key=API_KEY, api_secret_key=API_SECRET_KEY,
         )
+
+    def test_compression(self):
+        content, obj_id = self.hash_content(b'add_get_w_id')
+        self.storage.add(content, obj_id=obj_id)
+        data = self.storage.driver.containers[CONTAINER_NAME]
+        obj_id = hashutil.hash_to_hex(obj_id)
+        self.assertEqual(b''.join(data[obj_id].content), content)
+
+
+class CompressedCloudObjStorageMixin:
+    """Shared fixture for the per-algorithm compressed storage tests.
+
+    Subclasses set `compression` to the algorithm name handed to the
+    storage and `decompress` to the matching one-shot stdlib function.
+
+    """
+    compression = None
+    decompress = None
+
+    def setUp(self):
+        super().setUp()
+        self.storage = MockCloudObjStorage(
+            CONTAINER_NAME,
+            compression=self.compression,
+            api_key=API_KEY, api_secret_key=API_SECRET_KEY,
+        )
+
+    def test_compression(self):
+        # The raw payload held by the mock driver must decompress, with the
+        # plain stdlib function, back to the content that was added.
+        content, obj_id = self.hash_content(b'add_get_w_id')
+        self.storage.add(content, obj_id=obj_id)
+        data = self.storage.driver.containers[CONTAINER_NAME]
+        obj_id = hashutil.hash_to_hex(obj_id)
+        self.assertEqual(self.decompress(b''.join(data[obj_id].content)),
+                         content)
+
+
+class TestCloudObjStorageBz2(CompressedCloudObjStorageMixin,
+                             ObjStorageTestFixture, unittest.TestCase):
+    compression = 'bz2'
+    decompress = staticmethod(bz2.decompress)
+
+
+class TestCloudObjStorageLzma(CompressedCloudObjStorageMixin,
+                              ObjStorageTestFixture, unittest.TestCase):
+    compression = 'lzma'
+    decompress = staticmethod(lzma.decompress)
+
+
+class TestCloudObjStorageZlib(CompressedCloudObjStorageMixin,
+                              ObjStorageTestFixture, unittest.TestCase):
+    compression = 'zlib'
+    decompress = staticmethod(zlib.decompress)