diff --git a/swh/model/hashutil.py b/swh/model/hashutil.py
--- a/swh/model/hashutil.py
+++ b/swh/model/hashutil.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2017 The Software Heritage developers
+# Copyright (C) 2015-2018 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -10,7 +10,51 @@
 ALGORITHMS set. Any provided algorithms not in that list will result
 in a ValueError explaining the error.
 
-This modules defines the following hashing functions:
+This module defines a MultiHash class to ease computing the Software
+Heritage hashing algorithms. It computes hashes from a file object, a
+path or raw data, with an interface similar to the one provided by the
+standard hashlib module.
+
+Basic usage examples:
+
+- file object: MultiHash.from_file(
+                  file_object, hash_names=DEFAULT_ALGORITHMS,
+                  length=length).digest()
+
+- path (filepath): MultiHash.from_path(b'foo').hexdigest()
+
+- data (bytes): MultiHash.from_data(b'foo').bytehexdigest()
+
+
+"Complex" usage, defining a swh hashlib instance first:
+
+- To compute the length, add 'length' to the set of algorithms to
+  compute, for example:
+
+    h = MultiHash(hash_names=set(['length']).union(DEFAULT_ALGORITHMS),
+                  length=length)
+    with open(filepath, 'rb') as f:
+        while True:
+            chunk = f.read(HASH_BLOCK_SIZE)
+            if not chunk:
+                break
+            h.update(chunk)
+    hashes = h.digest()  # returns a dict of {hash_algo_name: hash_in_bytes}
+
+- Write data to a file while hashing it (e.g. from a stream):
+
+    h = MultiHash(length=length)
+    with open(filepath, 'wb') as f:
+        for chunk in r.iter_content():  # r: a stream of some sort
+            h.update(chunk)
+            f.write(chunk)
+    hashes = h.hexdigest()  # returns a dict of {hash_algo_name: hash_in_hex}
+
+  Note: previously, this required the chunk_cb argument (cf. hash_file,
+        hash_path).
+
+
+This module also defines the following (deprecated) hashing functions:
 
 - hash_file: Hash the contents of the given file object with the given
   algorithms (defaulting to DEFAULT_ALGORITHMS if none provided).
@@ -46,6 +90,120 @@
 _blake2_hash_cache = {}
 
 
+class MultiHash:
+    """Hashutil class to support multiple hashes computation.
+
+    Args:
+
+        hash_names (set): Set of hash algorithms (+ optionally length)
+                          to compute hashes (cf. DEFAULT_ALGORITHMS)
+        length (int): Length of the total sum of chunks to read
+
+    If 'length' is part of hash_names, the number of bytes fed to
+    update() is also tracked and returned alongside the digests.
+
+    """
+    def __init__(self, hash_names=DEFAULT_ALGORITHMS, length=None):
+        self.state = {}
+        self.track_length = False
+        for name in hash_names:
+            if name == 'length':
+                self.state['length'] = 0
+                self.track_length = True
+            else:
+                self.state[name] = _new_hash(name, length)
+
+    @classmethod
+    def from_state(cls, state, track_length):
+        """Build an instance from an existing state (used by copy)."""
+        ret = cls([])
+        ret.state = state
+        ret.track_length = track_length
+        return ret
+
+    @classmethod
+    def from_file(cls, file, hash_names=DEFAULT_ALGORITHMS, length=None):
+        """Hash the contents of the binary file object file.
+
+        The file is read in HASH_BLOCK_SIZE chunks (as hash_file does)
+        so that memory usage stays bounded: iterating over a binary
+        file would instead yield newline-delimited, unbounded chunks.
+
+        """
+        ret = cls(length=length, hash_names=hash_names)
+        while True:
+            chunk = file.read(HASH_BLOCK_SIZE)
+            if not chunk:
+                break
+            ret.update(chunk)
+        return ret
+
+    @classmethod
+    def from_path(cls, path, hash_names=DEFAULT_ALGORITHMS, length=None,
+                  track_length=True):
+        """Hash the contents of the file at path.
+
+        The file size is used as length when none is given. When
+        track_length is set, the result also contains the 'length'
+        key (for compatibility with `hash_path`).
+
+        """
+        if not length:
+            length = os.path.getsize(path)
+        with open(path, 'rb') as f:
+            ret = cls.from_file(f, hash_names=hash_names, length=length)
+        # For compatibility reason with `hash_path`
+        if track_length:
+            ret.state['length'] = length
+        return ret
+
+    @classmethod
+    def from_data(cls, data, hash_names=DEFAULT_ALGORITHMS, length=None):
+        """Hash data, a bytes object (length defaults to len(data))."""
+        if not length:
+            length = len(data)
+        fobj = BytesIO(data)
+        return cls.from_file(fobj, hash_names=hash_names, length=length)
+
+    def update(self, chunk):
+        """Feed the bytes chunk to every hash being computed."""
+        for name, h in self.state.items():
+            if name == 'length':
+                continue
+            h.update(chunk)
+        if self.track_length:
+            self.state['length'] += len(chunk)
+
+    def digest(self):
+        """Return a dict mapping each hash name to its bytes digest."""
+        return {
+            name: h.digest() if name != 'length' else h
+            for name, h in self.state.items()
+        }
+
+    def hexdigest(self):
+        """Return a dict mapping each hash name to its hex digest."""
+        return {
+            name: h.hexdigest() if name != 'length' else h
+            for name, h in self.state.items()
+        }
+
+    def bytehexdigest(self):
+        """Return a dict mapping each hash name to its bytehex digest."""
+        return {
+            name: hash_to_bytehex(h.digest()) if name != 'length' else h
+            for name, h in self.state.items()
+        }
+
+    def copy(self):
+        """Return an independent copy of the current hashing state."""
+        copied_state = {
+            name: h.copy() if name != 'length' else h
+            for name, h in self.state.items()
+        }
+        return self.from_state(copied_state, self.track_length)
+
+
 def _new_blake2_hash(algo):
     """Return a function that initializes a blake2 hash.
 
@@ -162,41 +320,48 @@
     return _new_hashlib_hash(algo)
 
 
-def hash_file(fobj, length=None, algorithms=DEFAULT_ALGORITHMS, chunk_cb=None):
-    """Hash the contents of the given file object with the given algorithms.
+def hash_file(fobj, length=None, algorithms=DEFAULT_ALGORITHMS,
+              chunk_cb=None):
+    """(Deprecated) cf. MultiHash.from_file
+
+    Hash the contents of the given file object with the given algorithms.
 
     Args:
         fobj: a file-like object
-        length: the length of the contents of the file-like object (for the
-          git-specific algorithms)
-        algorithms: the hashing algorithms to be used, as an iterable over
-          strings
+        length (int): the length of the contents of the file-like
+                      object (for the git-specific algorithms)
+        algorithms (set): the hashing algorithms to be used, as an
+                          iterable over strings
+        chunk_cb (fun): a callback function taking a chunk of data as
+                        parameter
 
-    Returns: a dict mapping each algorithm to a bytes digest.
+    Returns:
+        a dict mapping each algorithm to a bytes digest.
 
     Raises:
         ValueError if algorithms contains an unknown hash algorithm.
     """
-    hashes = {algo: _new_hash(algo, length) for algo in algorithms}
+    h = MultiHash(algorithms, length)
     while True:
         chunk = fobj.read(HASH_BLOCK_SIZE)
         if not chunk:
             break
-        for hash in hashes.values():
-            hash.update(chunk)
+        h.update(chunk)
         if chunk_cb:
             chunk_cb(chunk)
-    return {algo: hash.digest() for algo, hash in hashes.items()}
+    return h.digest()
 
 
 def hash_path(path, algorithms=DEFAULT_ALGORITHMS, chunk_cb=None):
-    """Hash the contents of the file at the given path with the given
-    algorithms.
+    """(deprecated) cf. MultiHash.from_path
+
+    Hash the contents of the file at the given path with the given
+    algorithms.
 
     Args:
-        path: the path of the file to hash
-        algorithms: the hashing algorithms used
-        chunk_cb: a callback
+        path (str): the path of the file to hash
+        algorithms (set): the hashing algorithms used
+        chunk_cb (fun): a callback function taking a chunk of data as parameter
 
     Returns: a dict mapping each algorithm to a bytes digest.
 
@@ -209,31 +374,33 @@
     """
     length = os.path.getsize(path)
     with open(path, 'rb') as fobj:
-        hash = hash_file(fobj, length, algorithms, chunk_cb)
-    hash['length'] = length
-    return hash
+        hashes = hash_file(fobj, length, algorithms, chunk_cb=chunk_cb)
+    hashes['length'] = length
+    return hashes
 
 
 def hash_data(data, algorithms=DEFAULT_ALGORITHMS, with_length=False):
-    """Hash the given binary blob with the given algorithms.
+    """(deprecated) cf. MultiHash.from_data
+
+    Hash the given binary blob with the given algorithms.
 
     Args:
         data (bytes): raw content to hash
-        algorithms (list): the hashing algorithms used
+        algorithms (set): the hashing algorithms used
         with_length (bool): add the length key in the resulting dict
 
     Returns: a dict mapping each algorithm to a bytes digest
 
     Raises:
         TypeError if data does not support the buffer interface.
         ValueError if algorithms contains an unknown hash algorithm.
+
     """
-    fobj = BytesIO(data)
-    length = len(data)
-    data = hash_file(fobj, length, algorithms)
-    if with_length:
-        data['length'] = length
-    return data
+    hashes = MultiHash.from_data(data, hash_names=algorithms).digest()
+    # Backward compatibility: prefer adding 'length' to algorithms
+    if with_length:
+        hashes['length'] = len(data)
+    return hashes
 
 
 def hash_git_data(data, git_type, base_algo='sha1'):
diff --git a/swh/model/tests/test_hashutil.py b/swh/model/tests/test_hashutil.py
--- a/swh/model/tests/test_hashutil.py
+++ b/swh/model/tests/test_hashutil.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2017 The Software Heritage developers
+# Copyright (C) 2015-2018 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -13,9 +13,10 @@
 from unittest.mock import patch
 
 from swh.model import hashutil
+from swh.model.hashutil import MultiHash
 
 
-class Hashutil(unittest.TestCase):
+class BaseHashutil(unittest.TestCase):
     def setUp(self):
         # Reset function cache
         hashutil._blake2_hash_cache = {}
@@ -35,6 +36,11 @@
             for type, cksum in self.hex_checksums.items()
         }
 
+        self.bytehex_checksums = {
+            type: hashutil.hash_to_bytehex(cksum)
+            for type, cksum in self.checksums.items()
+        }
+
         self.git_hex_checksums = {
             'blob': self.hex_checksums['sha1_git'],
             'tree': '5b2e883aa33d2efab98442693ea4dd5f1b8871b0',
@@ -47,6 +53,90 @@
             for type, cksum in self.git_hex_checksums.items()
         }
 
+
+class MultiHashTest(BaseHashutil):
+    @istest
+    def multi_hash_data(self):
+        checksums = MultiHash.from_data(self.data).digest()
+        self.assertEqual(checksums, self.checksums)
+        self.assertFalse('length' in checksums)
+
+    @istest
+    def multi_hash_data_with_length(self):
+        expected_checksums = self.checksums.copy()
+        expected_checksums['length'] = len(self.data)
+
+        algos = set(['length']).union(hashutil.DEFAULT_ALGORITHMS)
+        checksums = MultiHash.from_data(self.data, hash_names=algos).digest()
+
+        self.assertEqual(checksums, expected_checksums)
+        self.assertTrue('length' in checksums)
+
+    @istest
+    def multi_hash_data_unknown_hash(self):
+        with self.assertRaises(ValueError) as cm:
+            MultiHash.from_data(self.data, ['unknown-hash'])
+
+        self.assertIn('Unexpected hashing algorithm', cm.exception.args[0])
+        self.assertIn('unknown-hash', cm.exception.args[0])
+
+    @istest
+    def multi_hash_file(self):
+        fobj = io.BytesIO(self.data)
+
+        checksums = MultiHash.from_file(fobj, length=len(self.data)).digest()
+        self.assertEqual(checksums, self.checksums)
+
+    @istest
+    def multi_hash_file_hexdigest(self):
+        fobj = io.BytesIO(self.data)
+        length = len(self.data)
+        checksums = MultiHash.from_file(fobj, length=length).hexdigest()
+        self.assertEqual(checksums, self.hex_checksums)
+
+    @istest
+    def multi_hash_file_bytehexdigest(self):
+        fobj = io.BytesIO(self.data)
+        length = len(self.data)
+        checksums = MultiHash.from_file(fobj, length=length).bytehexdigest()
+        self.assertEqual(checksums, self.bytehex_checksums)
+
+    @istest
+    def multi_hash_file_missing_length(self):
+        fobj = io.BytesIO(self.data)
+        with self.assertRaises(ValueError) as cm:
+            MultiHash.from_file(fobj, hash_names=['sha1_git'])
+
+        self.assertIn('Missing length', cm.exception.args[0])
+
+    @istest
+    def multi_hash_path(self):
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            f.write(self.data)
+
+        try:
+            hashes = MultiHash.from_path(f.name).digest()
+        finally:
+            os.remove(f.name)
+
+        expected_checksums = self.checksums.copy()
+        expected_checksums['length'] = len(self.data)
+        self.assertEqual(expected_checksums, hashes)
+
+    @istest
+    def multi_hash_copy(self):
+        h = MultiHash(length=len(self.data))
+        h_copy = h.copy()
+        self.assertIsInstance(h_copy, MultiHash)
+
+        # Both objects must evolve independently after the copy
+        h.update(self.data)
+        h_copy.update(self.data)
+        self.assertEqual(h.digest(), self.checksums)
+        self.assertEqual(h_copy.digest(), self.checksums)
+
+
+class Hashutil(BaseHashutil):
     @istest
     def hash_data(self):
         checksums = hashutil.hash_data(self.data)
@@ -58,7 +148,8 @@
         expected_checksums = self.checksums.copy()
         expected_checksums['length'] = len(self.data)
 
-        checksums = hashutil.hash_data(self.data, with_length=True)
+        algos = set(['length']).union(hashutil.DEFAULT_ALGORITHMS)
+        checksums = hashutil.hash_data(self.data, algorithms=algos)
 
         self.assertEqual(checksums, expected_checksums)
         self.assertTrue('length' in checksums)