diff --git a/swh/objstorage/backends/pathslicing.py b/swh/objstorage/backends/pathslicing.py
--- a/swh/objstorage/backends/pathslicing.py
+++ b/swh/objstorage/backends/pathslicing.py
@@ -1,11 +1,9 @@
-# Copyright (C) 2015-2018  The Software Heritage developers
+# Copyright (C) 2015-2019  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import functools
 import os
-import gzip
 import tempfile
 import random
 import collections
@@ -16,12 +14,13 @@
 from swh.model import hashutil
 from swh.objstorage.objstorage import (
+    compressors, decompressors,
     ObjStorage, compute_hash, ID_HASH_ALGO,
     ID_HASH_LENGTH, DEFAULT_CHUNK_SIZE, DEFAULT_LIMIT)
 from swh.objstorage.exc import ObjNotFoundError, Error
 
-GZIP_BUFSIZ = 1048576
+BUFSIZ = 1048576
 
 DIR_MODE = 0o755
 FILE_MODE = 0o644
@@ -32,8 +31,7 @@
     """ Context manager for writing object files to the object storage.
 
     During writing, data are written to a temporary file, which is atomically
-    renamed to the right file name after closing. This context manager also
-    takes care of (gzip) compressing the data on the fly.
+    renamed to the right file name after closing.
 
     Usage sample:
         with _write_obj_file(hex_obj_id, objstorage):
@@ -54,8 +52,7 @@
 
     # Open the file and yield it for writing.
     tmp_f = os.fdopen(tmp, 'wb')
-    with gzip.GzipFile(filename=tmp_path, fileobj=tmp_f) as f:
-        yield f
+    yield tmp_f
 
     # Make sure the contents of the temporary file are written to disk
     tmp_f.flush()
@@ -64,13 +61,12 @@
     else:
         os.fsync(tmp)
 
-    # Then close the temporary file and move it to the right directory.
+    # Then close the temporary file and move it to the right path.
     tmp_f.close()
 
     os.chmod(tmp_path, FILE_MODE)
     os.rename(tmp_path, path)
 
 
-@contextmanager
 def _read_obj_file(hex_obj_id, objstorage):
     """ Context manager for reading object file in the object storage.
@@ -82,8 +78,8 @@
         a file-like object open for reading bytes.
     """
     path = objstorage._obj_path(hex_obj_id)
-    with gzip.GzipFile(path, 'rb') as f:
-        yield f
+
+    return open(path, 'rb')
 
 
 class PathSlicingObjStorage(ObjStorage):
@@ -113,7 +109,7 @@
 
     """
 
-    def __init__(self, root, slicing, **kwargs):
+    def __init__(self, root, slicing, compression='gzip', **kwargs):
         """ Create an object to access a hash-slicing based object storage.
         Args:
@@ -134,6 +130,7 @@
         ]
 
         self.use_fdatasync = hasattr(os, 'fdatasync')
+        self.compression = compression
 
         self.check_config(check_write=False)
 
@@ -160,6 +157,10 @@
                 'PathSlicingObjStorage root "%s" is not writable' % root
             )
 
+        if self.compression not in compressors:
+            raise ValueError('Unknown compression algorithm "%s" for '
+                             'PathSlicingObjStorage' % self.compression)
+
         return True
 
     def __contains__(self, obj_id):
@@ -236,10 +237,13 @@
             return obj_id
 
         hex_obj_id = hashutil.hash_to_hex(obj_id)
-        if isinstance(content, collections.Iterator):
-            content = b''.join(content)
+        if not isinstance(content, collections.Iterator):
+            content = [content]
+        compressor = compressors[self.compression]()
         with _write_obj_file(hex_obj_id, self) as f:
-            f.write(content)
+            for chunk in content:
+                f.write(compressor.compress(chunk))
+            f.write(compressor.flush())
 
         return obj_id
 
@@ -249,42 +253,35 @@
 
         # Open the file and return its content as bytes
        hex_obj_id = hashutil.hash_to_hex(obj_id)
+        d = decompressors[self.compression]()
         with _read_obj_file(hex_obj_id, self) as f:
-            return f.read()
+            out = d.decompress(f.read())
+        if d.unused_data:
+            raise Error('Corrupt object %s: trailing data found'
+                        % hex_obj_id,)
+
+        return out
 
     def check(self, obj_id):
-        if obj_id not in self:
-            raise ObjNotFoundError(obj_id)
+        try:
+            data = self.get(obj_id)
+        except OSError:
+            hex_obj_id = hashutil.hash_to_hex(obj_id)
+            raise Error(
+                'Corrupt object %s: not a proper compressed file' % hex_obj_id,
+            )
+
+        checksums = hashutil.MultiHash.from_data(
+            data, hash_names=[ID_HASH_ALGO]).digest()
+
+        actual_obj_id = checksums[ID_HASH_ALGO]
         hex_obj_id = hashutil.hash_to_hex(obj_id)
 
-        try:
-            with gzip.open(self._obj_path(hex_obj_id)) as f:
-                length = None
-                if ID_HASH_ALGO.endswith('_git'):
-                    # if the hashing algorithm is git-like, we need to know the
-                    # content size to hash on the fly. Do a first pass here to
-                    # compute the size
-                    length = 0
-                    while True:
-                        chunk = f.read(GZIP_BUFSIZ)
-                        length += len(chunk)
-                        if not chunk:
-                            break
-                    f.rewind()
-
-                checksums = hashutil.MultiHash.from_file(
-                    f, hash_names=[ID_HASH_ALGO], length=length).digest()
-                actual_obj_id = checksums[ID_HASH_ALGO]
-                if hex_obj_id != hashutil.hash_to_hex(actual_obj_id):
-                    raise Error(
-                        'Corrupt object %s should have id %s'
-                        % (hashutil.hash_to_hex(obj_id),
-                           hashutil.hash_to_hex(actual_obj_id))
-                    )
-        except (OSError, IOError):
-            # IOError is for compatibility with older python versions
-            raise Error('Corrupt object %s is not a gzip file' % hex_obj_id)
+        if hex_obj_id != hashutil.hash_to_hex(actual_obj_id):
+            raise Error(
+                'Corrupt object %s should have id %s'
+                % (hashutil.hash_to_hex(obj_id),
+                   hashutil.hash_to_hex(actual_obj_id))
+            )
 
     def delete(self, obj_id):
         super().delete(obj_id)  # Check delete permission
@@ -331,8 +328,10 @@
     @contextmanager
     def chunk_writer(self, obj_id):
         hex_obj_id = hashutil.hash_to_hex(obj_id)
+        compressor = compressors[self.compression]()
         with _write_obj_file(hex_obj_id, self) as f:
-            yield f.write
+            yield lambda c: f.write(compressor.compress(c))
+            f.write(compressor.flush())
 
     def add_stream(self, content_iter, obj_id, check_presence=True):
         if check_presence and obj_id in self:
@@ -349,9 +348,16 @@
             raise ObjNotFoundError(obj_id)
 
         hex_obj_id = hashutil.hash_to_hex(obj_id)
+        decompressor = decompressors[self.compression]()
         with _read_obj_file(hex_obj_id, self) as f:
-            reader = functools.partial(f.read, chunk_size)
-            yield from iter(reader, b'')
+            while True:
+                raw = f.read(chunk_size)
+                if not raw:
+                    break
+                r = decompressor.decompress(raw)
+                if not r:
+                    continue
+                yield r
 
     def list_content(self, last_obj_id=None, limit=DEFAULT_LIMIT):
         if last_obj_id:
diff --git a/swh/objstorage/tests/objstorage_testing.py b/swh/objstorage/tests/objstorage_testing.py
--- a/swh/objstorage/tests/objstorage_testing.py
+++ b/swh/objstorage/tests/objstorage_testing.py
@@ -156,7 +156,6 @@
             return
         self.assertTrue(isinstance(r, collections.Iterator))
         r = list(r)
-        self.assertEqual(len(r), 9)
         self.assertEqual(b''.join(r), content)
 
     def test_add_batch(self):
diff --git a/swh/objstorage/tests/test_objstorage_pathslicing.py b/swh/objstorage/tests/test_objstorage_pathslicing.py
--- a/swh/objstorage/tests/test_objstorage_pathslicing.py
+++ b/swh/objstorage/tests/test_objstorage_pathslicing.py
@@ -7,7 +7,8 @@
 import tempfile
 import unittest
 from unittest.mock import patch, DEFAULT
-import gzip
+
+from typing import Optional
 
 from swh.model import hashutil
 from swh.objstorage import exc, get_objstorage, ID_HASH_LENGTH
@@ -16,14 +17,18 @@
 
 class TestPathSlicingObjStorage(ObjStorageTestFixture, unittest.TestCase):
+    compression = None  # type: Optional[str]
 
     def setUp(self):
         super().setUp()
         self.slicing = '0:2/2:4/4:6'
         self.tmpdir = tempfile.mkdtemp()
         self.storage = get_objstorage(
-            'pathslicing',
-            {'root': self.tmpdir, 'slicing': self.slicing}
+            'pathslicing', {
+                'root': self.tmpdir,
+                'slicing': self.slicing,
+                'compression': self.compression,
+            }
         )
 
     def tearDown(self):
@@ -49,28 +54,15 @@
     def test_check_ok(self):
         content, obj_id = self.hash_content(b'check_ok')
         self.storage.add(content, obj_id=obj_id)
-        self.storage.check(obj_id)
-        self.storage.check(obj_id.hex())
-
-    def test_check_not_gzip(self):
-        content, obj_id = self.hash_content(b'check_not_gzip')
-        self.storage.add(content, obj_id=obj_id)
-        with open(self.content_path(obj_id), 'ab') as f:  # Add garbage.
-            f.write(b'garbage')
-        with self.assertRaises(exc.Error) as error:
-            self.storage.check(obj_id)
-        self.assertEquals((
-            'Corrupt object %s is not a gzip file' % obj_id.hex(),),
-            error.exception.args)
+        assert self.storage.check(obj_id) is None
+        assert self.storage.check(obj_id.hex()) is None
 
     def test_check_id_mismatch(self):
         content, obj_id = self.hash_content(b'check_id_mismatch')
-        self.storage.add(content, obj_id=obj_id)
-        with gzip.open(self.content_path(obj_id), 'wb') as f:
-            f.write(b'unexpected content')
+        self.storage.add(b'unexpected content', obj_id=obj_id)
         with self.assertRaises(exc.Error) as error:
             self.storage.check(obj_id)
-        self.assertEquals((
+        self.assertEqual((
             'Corrupt object %s should have id '
             '12ebb2d6c81395bcc5cab965bdff640110cb67ff' % obj_id.hex(),),
             error.exception.args)
@@ -137,3 +129,31 @@
         self.storage.add(content, obj_id=obj_id)
         assert patched['fdatasync'].call_count == 0
         assert patched['fsync'].call_count == 1
+
+    def test_check_not_compressed(self):
+        content, obj_id = self.hash_content(b'check_not_compressed')
+        self.storage.add(content, obj_id=obj_id)
+        with open(self.content_path(obj_id), 'ab') as f:  # Add garbage.
+            f.write(b'garbage')
+        with self.assertRaises(exc.Error) as error:
+            self.storage.check(obj_id)
+        if self.compression is None:
+            self.assertIn('Corrupt object', error.exception.args[0])
+        else:
+            self.assertIn('trailing data found', error.exception.args[0])
+
+
+class TestPathSlicingObjStorageGzip(TestPathSlicingObjStorage):
+    compression = 'gzip'
+
+
+class TestPathSlicingObjStorageZlib(TestPathSlicingObjStorage):
+    compression = 'zlib'
+
+
+class TestPathSlicingObjStorageBz2(TestPathSlicingObjStorage):
+    compression = 'bz2'
+
+
+class TestPathSlicingObjStorageLzma(TestPathSlicingObjStorage):
+    compression = 'lzma'
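
Note for reviewers: the `compressors` / `decompressors` registries imported at
the top of pathslicing.py are defined in swh/objstorage/objstorage.py and are
not part of this diff. For reference, a minimal sketch of the interface this
patch relies on could look like the following; the names, the wbits choices,
and the pass-through entry backing the test fixture's `compression = None`
default are assumptions inferred from how the registries are used above, not
the actual upstream definitions:

    import bz2
    import functools
    import lzma
    import zlib


    class NullCompressor:
        """Hypothetical pass-through codec for compression=None."""
        def compress(self, data):
            return data

        def flush(self):
            return b''


    class NullDecompressor:
        """Hypothetical pass-through counterpart for reads."""
        unused_data = b''

        def decompress(self, data):
            return data


    # Each value is a zero-argument factory returning an incremental codec
    # exposing the stdlib API used by add()/get()/chunk_writer()/get_stream():
    # compress(bytes) + flush() when writing, decompress(bytes) plus the
    # .unused_data attribute when reading. wbits=31 selects the gzip
    # container, wbits=15 the zlib container.
    compressors = {
        'bz2': bz2.BZ2Compressor,
        'lzma': lzma.LZMACompressor,
        'gzip': functools.partial(zlib.compressobj, wbits=31),
        'zlib': functools.partial(zlib.compressobj, wbits=15),
        None: NullCompressor,
    }

    decompressors = {
        'bz2': bz2.BZ2Decompressor,
        'lzma': lzma.LZMADecompressor,
        'gzip': functools.partial(zlib.decompressobj, wbits=31),
        'zlib': functools.partial(zlib.decompressobj, wbits=15),
        None: NullDecompressor,
    }

With registries of this shape, `compressors[self.compression]()` in add() and
`decompressors[self.compression]()` in get() resolve to a fresh codec instance
per object, which is why get() can rely on `unused_data` to detect trailing
garbage after a well-formed compressed stream.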