diff --git a/swh/vault/cookers/base.py b/swh/vault/cookers/base.py
--- a/swh/vault/cookers/base.py
+++ b/swh/vault/cookers/base.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2016-2017 The Software Heritage developers
+# Copyright (C) 2016-2018 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -6,6 +6,7 @@
 import abc
 import io
 import itertools
+import logging
 import os
 import tarfile
 import tempfile
@@ -27,10 +28,21 @@
             'url': 'http://localhost:5002/',
         },
     }),
-    'vault_url': ('str', 'http://localhost:5005/')
+    'vault_url': ('str', 'http://localhost:5005/'),
+    'max_bundle_size': ('int', 2 ** 29),  # 512 MiB
 }
 
 
+class PolicyError(Exception):
+    """Raised when the bundle violates the cooking policy."""
+    pass
+
+
+class BundleTooLargeError(PolicyError):
+    """Raised when the bundle is too large to be cooked."""
+    pass
+
+
 class BaseVaultCooker(metaclass=abc.ABCMeta):
     """Abstract base class for the vault's bundle creators
 
@@ -60,6 +72,7 @@
         self.obj_id = hashutil.hash_to_bytes(obj_id)
         self.backend = RemoteVaultClient(self.config['vault_url'])
         self.storage = get_storage(**self.config['storage'])
+        self.max_bundle_size = self.config['max_bundle_size']
 
     @abc.abstractmethod
     def check_exists(self):
@@ -82,15 +95,31 @@
         """
         self.backend.set_status(self.obj_type, self.obj_id, 'pending')
         self.backend.set_progress(self.obj_type, self.obj_id, 'Processing...')
-        content_iter = self.prepare_bundle()
-        # TODO: use proper content streaming
         try:
-            bundle = b''.join(content_iter)
+            chunks = []
+            bundle_size = 0
+            # Accumulate chunks in a list (joined once at the end, to avoid
+            # quadratic bytes concatenation) and keep a running total so that
+            # oversized bundles are rejected as early as possible.
+            for chunk in self.prepare_bundle():
+                chunks.append(chunk)
+                bundle_size += len(chunk)
+                if bundle_size > self.max_bundle_size:
+                    raise BundleTooLargeError(
+                        "The requested bundle exceeds the maximum allowed "
+                        "size of {} bytes.".format(self.max_bundle_size))
+            bundle = b''.join(chunks)
+        except PolicyError as e:
+            self.backend.set_status(self.obj_type, self.obj_id, 'failed')
+            self.backend.set_progress(self.obj_type, self.obj_id, str(e))
         except Exception as e:
             self.backend.set_status(self.obj_type, self.obj_id, 'failed')
-            self.backend.set_progress(self.obj_type, self.obj_id, e.message)
+            self.backend.set_progress(
+                self.obj_type, self.obj_id,
+                "Internal Server Error. This incident will be reported.")
+            logging.exception("Bundle cooking failed.")
         else:
+            # TODO: use proper content streaming instead of put_bundle()
             self.backend.put_bundle(self.CACHE_TYPE_KEY, self.obj_id, bundle)
             self.backend.set_status(self.obj_type, self.obj_id, 'done')
             self.backend.set_progress(self.obj_type, self.obj_id, None)
         self.backend.send_notif(self.obj_type, self.obj_id)
@@ -129,11 +158,31 @@
     return list(storage.content_get([file_data['sha1']]))[0]['data']
 
 
-def get_tar_bytes(path, arcname=None):
+# TODO: We should use something like that for all the IO done by the cookers.
+# Instead of using generators to yield chunks, we should just write() the
+# chunks in an object like this, which would give us way better control over
+# the buffering, and allow for streaming content to the objstorage.
+class BytesIOBundleSizeLimit(io.BytesIO):
+    def __init__(self, *args, size_limit=None, **kwargs):
+        # Forward to BytesIO so an initial_bytes argument is not dropped.
+        super().__init__(*args, **kwargs)
+        self.size_limit = size_limit
+
+    def write(self, chunk):
+        if (self.size_limit is not None
+                and self.getbuffer().nbytes + len(chunk) > self.size_limit):
+            raise BundleTooLargeError(
+                "The requested bundle exceeds the maximum allowed "
+                "size of {} bytes.".format(self.size_limit))
+        return super().write(chunk)
+
+
+# TODO: Once the BytesIO buffer is put in BaseCooker, we can just pass it here
+# as a fileobj parameter instead of passing size_limit
+def get_tar_bytes(path, arcname=None, size_limit=None):
     path = Path(path)
     if not arcname:
         arcname = path.name
-    tar_buffer = io.BytesIO()
+    tar_buffer = BytesIOBundleSizeLimit(size_limit=size_limit)
     tar = tarfile.open(fileobj=tar_buffer, mode='w')
     tar.add(str(path), arcname=arcname)
     return tar_buffer.getbuffer()
@@ -149,7 +198,7 @@
     def __init__(self, storage):
         self.storage = storage
 
-    def get_directory_bytes(self, dir_id):
+    def get_directory_bytes(self, dir_id, size_limit=None):
         # Create temporary folder to retrieve the files into.
         root = bytes(tempfile.mkdtemp(prefix='directory.',
                                       suffix='.cook'), 'utf8')
@@ -158,7 +207,8 @@
         # a compressed directory.
         bundle_content = self._create_bundle_content(
             root,
-            hashutil.hash_to_hex(dir_id))
+            hashutil.hash_to_hex(dir_id),
+            size_limit=size_limit)
         return bundle_content
 
     def build_directory(self, dir_id, root):
@@ -220,7 +270,7 @@
         content = list(self.storage.content_get([obj_id]))[0]['data']
         return content
 
-    def _create_bundle_content(self, path, hex_dir_id):
+    def _create_bundle_content(self, path, hex_dir_id, size_limit=None):
         """Create a bundle from the given directory
 
         Args:
@@ -231,4 +281,4 @@
             bytes that represent the compressed directory as a bundle.
 
         """
-        return get_tar_bytes(path.decode(), hex_dir_id)
+        return get_tar_bytes(path.decode(), hex_dir_id, size_limit=size_limit)
diff --git a/swh/vault/cookers/directory.py b/swh/vault/cookers/directory.py
--- a/swh/vault/cookers/directory.py
+++ b/swh/vault/cookers/directory.py
@@ -15,4 +15,5 @@
 
     def prepare_bundle(self):
         directory_builder = DirectoryBuilder(self.storage)
-        yield directory_builder.get_directory_bytes(self.obj_id)
+        yield directory_builder.get_directory_bytes(self.obj_id,
+                                                    self.max_bundle_size)
diff --git a/swh/vault/cookers/revision_flat.py b/swh/vault/cookers/revision_flat.py
--- a/swh/vault/cookers/revision_flat.py
+++ b/swh/vault/cookers/revision_flat.py
@@ -34,4 +34,5 @@
             directory_builder.build_directory(revision['directory'],
                                               str(revdir).encode())
         # FIXME: stream the bytes! this tarball can be HUUUUUGE
-        yield get_tar_bytes(root_tmp, hashutil.hash_to_hex(self.obj_id))
+        yield get_tar_bytes(root_tmp, hashutil.hash_to_hex(self.obj_id),
+                            self.max_bundle_size)
diff --git a/swh/vault/tests/test_cookers_base.py b/swh/vault/tests/test_cookers_base.py
new file mode 100644
--- /dev/null
+++ b/swh/vault/tests/test_cookers_base.py
@@ -0,0 +1,92 @@
+# Copyright (C) 2018 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pathlib
+import tempfile
+import unittest
+from unittest.mock import MagicMock
+
+from swh.model import hashutil
+from swh.vault.cookers.base import (BaseVaultCooker, get_tar_bytes,
+                                    BundleTooLargeError)
+
+
+TEST_BUNDLE_CHUNKS = [b"test content 1\n",
+                      b"test content 2\n",
+                      b"test content 3\n"]
+TEST_BUNDLE_CONTENT = b''.join(TEST_BUNDLE_CHUNKS)
+TEST_OBJ_TYPE = 'test_type'
+TEST_HEX_ID = '17a3e48bce37be5226490e750202ad3a9a1a3fe9'
+TEST_OBJ_ID = hashutil.hash_to_bytes(TEST_HEX_ID)
+
+
+class BaseVaultCookerMock(BaseVaultCooker):
+    CACHE_TYPE_KEY = TEST_OBJ_TYPE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(self.CACHE_TYPE_KEY, TEST_OBJ_ID, *args, **kwargs)
+        self.storage = MagicMock()
+        self.backend = MagicMock()
+
+    def check_exists(self):
+        return True
+
+    def prepare_bundle(self):
+        for chunk in TEST_BUNDLE_CHUNKS:
+            yield chunk
+
+
+class TestBaseVaultCooker(unittest.TestCase):
+    def test_simple_cook(self):
+        cooker = BaseVaultCookerMock()
+        cooker.cook()
+        cooker.backend.put_bundle.assert_called_once_with(
+            TEST_OBJ_TYPE, TEST_OBJ_ID, TEST_BUNDLE_CONTENT)
+        cooker.backend.set_status.assert_called_with(
+            TEST_OBJ_TYPE, TEST_OBJ_ID, 'done')
+        cooker.backend.set_progress.assert_called_with(
+            TEST_OBJ_TYPE, TEST_OBJ_ID, None)
+        cooker.backend.send_notif.assert_called_with(
+            TEST_OBJ_TYPE, TEST_OBJ_ID)
+
+    def test_code_exception_cook(self):
+        cooker = BaseVaultCookerMock()
+        cooker.prepare_bundle = MagicMock()
+        cooker.prepare_bundle.side_effect = RuntimeError("Nope")
+        cooker.cook()
+
+        # Potentially remove this when we have objstorage streaming
+        cooker.backend.put_bundle.assert_not_called()
+
+        cooker.backend.set_status.assert_called_with(
+            TEST_OBJ_TYPE, TEST_OBJ_ID, 'failed')
+        self.assertNotIn("Nope", cooker.backend.set_progress.call_args[0][2])
+        cooker.backend.send_notif.assert_called_with(
+            TEST_OBJ_TYPE, TEST_OBJ_ID)
+
+    def test_policy_exception_cook(self):
+        cooker = BaseVaultCookerMock()
+        cooker.max_bundle_size = 8
+        cooker.cook()
+
+        # Potentially remove this when we have objstorage streaming
+        cooker.backend.put_bundle.assert_not_called()
+
+        cooker.backend.set_status.assert_called_with(
+            TEST_OBJ_TYPE, TEST_OBJ_ID, 'failed')
+        self.assertIn("exceeds", cooker.backend.set_progress.call_args[0][2])
+        cooker.backend.send_notif.assert_called_with(
+            TEST_OBJ_TYPE, TEST_OBJ_ID)
+
+
+class TestGetTarBytes(unittest.TestCase):
+    def test_tar_too_large(self):
+        with tempfile.TemporaryDirectory(prefix='tmp-vault-repo-') as td:
+            p = pathlib.Path(td)
+            (p / 'dir1/dir2').mkdir(parents=True)
+            (p / 'dir1/dir2/file').write_text('testtesttesttest')
+
+            with self.assertRaises(BundleTooLargeError):
+                get_tar_bytes(p, size_limit=8)