diff --git a/swh/model/merkle.py b/swh/model/merkle.py
new file mode 100644
--- /dev/null
+++ b/swh/model/merkle.py
@@ -0,0 +1,322 @@
+# Copyright (C) 2017 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""Merkle tree data structure"""
+
+import abc
+import collections.abc
+
+
+def deep_update(left, right):
+    """Recursively update the left mapping with deeply nested values from the right
+    mapping.
+
+    This function is useful to merge the results of several calls to
+    :func:`MerkleNode.collect`.
+
+    Nested mappings are merged key by key. Any other value from the right
+    mapping replaces the left one, and so does a right mapping merged onto a
+    left value that is not a mapping.
+
+    Arguments:
+      left: a mapping (modified by the update operation)
+      right: a mapping
+
+    Returns:
+      the left mapping, updated with nested values from the right mapping
+
+    Example:
+      >>> a = {
+      ...     'key1': {
+      ...         'key2': {
+      ...             'key3': 'value1/2/3',
+      ...         },
+      ...     },
+      ... }
+      >>> deep_update(a, {
+      ...     'key1': {
+      ...         'key2': {
+      ...             'key4': 'value1/2/4',
+      ...         },
+      ...     },
+      ... }) == {
+      ...     'key1': {
+      ...         'key2': {
+      ...             'key3': 'value1/2/3',
+      ...             'key4': 'value1/2/4',
+      ...         },
+      ...     },
+      ... }
+      True
+      >>> deep_update(a, {
+      ...     'key1': {
+      ...         'key2': {
+      ...             'key3': 'newvalue1/2/3',
+      ...         },
+      ...     },
+      ... }) == {
+      ...     'key1': {
+      ...         'key2': {
+      ...             'key3': 'newvalue1/2/3',
+      ...             'key4': 'value1/2/4',
+      ...         },
+      ...     },
+      ... }
+      True
+
+    """
+    for key, rvalue in right.items():
+        if isinstance(rvalue, collections.abc.Mapping):
+            lvalue = left.get(key)
+            if not isinstance(lvalue, collections.abc.Mapping):
+                # Missing key, or a non-mapping value being replaced by a
+                # mapping: merge into a fresh dict (never alias rvalue).
+                lvalue = {}
+            left[key] = deep_update(lvalue, rvalue)
+        else:
+            left[key] = rvalue
+    return left
+
+
+class MerkleNode(dict, metaclass=abc.ABCMeta):
+    """Representation of a node in a Merkle Tree.
+
+    A (generalized) `Merkle Tree`_ is a tree in which every node is labeled
+    with a hash of its own data and the hash of its children.
+
+    .. _Merkle Tree: https://en.wikipedia.org/wiki/Merkle_tree
+
+    In pseudocode::
+
+      node.hash = hash(node.data +
+                       sum(child.hash for child in node.children))
+
+    This class efficiently implements the Merkle Tree data structure on top of
+    a Python :class:`dict`, minimizing hash computations and new data
+    collections when updating nodes.
+
+    Node data is stored in the :attr:`data` attribute, while (named) children
+    are stored as items of the underlying dictionary.
+
+    Addition, update and removal of objects are instrumented to automatically
+    invalidate the hashes of the current node as well as its registered
+    parents. It also resets the collection status of the objects so the
+    updated objects can be collected.
+
+    Only item assignment, item deletion and :func:`update` are instrumented:
+    other :class:`dict` mutators (``pop``, ``popitem``, ``setdefault``,
+    ``clear``) bypass hash invalidation and parent bookkeeping.
+
+    The collection of updated data from the tree is implemented through the
+    :func:`collect` function and associated helpers.
+
+    Attributes:
+      data (dict): data associated to the current node
+      parents (list): known parents of the current node
+      collected (bool): whether the current node has been collected
+
+    """
+    __slots__ = ['parents', 'data', '__hash', 'collected']
+
+    type = None
+    """Type of the current node (used as a classifier for :func:`collect`)"""
+
+    def __init__(self, data=None):
+        super().__init__()
+        self.parents = []
+        self.data = data
+        self.__hash = None
+        self.collected = False
+
+    def invalidate_hash(self):
+        """Invalidate the cached hash of the current node.
+
+        This also marks the node as not collected, and propagates the
+        invalidation to all registered parents.
+        """
+        if self.__hash is None:
+            # Parents are hashed after their children, so when the tree is
+            # only edited via instrumented methods, ancestors are uncached too.
+            return
+
+        self.__hash = None
+        self.collected = False
+        for parent in self.parents:
+            parent.invalidate_hash()
+
+    def update_hash(self, *, force=False):
+        """Recursively compute the hash of the current node.
+
+        Args:
+          force (bool): invalidate the cache and force the computation for
+            this node and all children.
+
+        Returns:
+          the hash of the current node, as returned by :func:`compute_hash`
+        """
+        if self.__hash is not None and not force:
+            return self.__hash
+
+        if force:
+            self.invalidate_hash()
+
+        for child in self.values():
+            child.update_hash(force=force)
+
+        self.__hash = self.compute_hash()
+        return self.__hash
+
+    @property
+    def hash(self):
+        """The hash of the current node, as calculated by
+        :func:`compute_hash`.
+        """
+        return self.update_hash()
+
+    @abc.abstractmethod
+    def compute_hash(self):
+        """Compute the hash of the current node.
+
+        The hash should depend on the data of the node, as well as on hashes
+        of the children nodes.
+        """
+        raise NotImplementedError('Must implement compute_hash method')
+
+    def __setitem__(self, name, new_child):
+        """Add a child, invalidating the current hash.
+
+        If a child is already registered under `name`, the current node is
+        removed from that child's parents before the new child is attached.
+        """
+        self.invalidate_hash()
+
+        if name in self:
+            self[name].parents.remove(self)
+        super().__setitem__(name, new_child)
+
+        new_child.parents.append(self)
+
+    def __delitem__(self, name):
+        """Remove a child, invalidating the current hash.
+
+        Raises:
+          KeyError: if no child is registered under `name`
+        """
+        # Look the child up first: a missing name raises KeyError(name)
+        # before anything gets invalidated.
+        child = self[name]
+        self.invalidate_hash()
+        child.parents.remove(self)
+        super().__delitem__(name)
+
+    def update(self, new_children):
+        """Add several named children from a dictionary.
+
+        As with item assignment, the current node is removed from the parents
+        of any child it already had under one of the given names.
+        """
+        if not new_children:
+            return
+
+        self.invalidate_hash()
+
+        for name, new_child in new_children.items():
+            if name in self:
+                self[name].parents.remove(self)
+            new_child.parents.append(self)
+
+        super().update(new_children)
+
+    def get_data(self, **kwargs):
+        """Retrieve and format the collected data for the current node, for use by
+        :func:`collect`.
+
+        Can be overridden, for instance when you want the collected data to
+        contain information about the child nodes.
+
+        Arguments:
+          kwargs: allow subclasses to alter behaviour depending on how
+            :func:`collect` is called.
+
+        Returns:
+          data formatted for :func:`collect`
+        """
+        return self.data
+
+    def collect_node(self, **kwargs):
+        """Collect the data for the current node, for use by :func:`collect`.
+
+        Arguments:
+          kwargs: passed as-is to :func:`get_data`.
+
+        Returns:
+          A :class:`dict` compatible with :func:`collect`, or an empty
+          :class:`dict` if the node has already been collected.
+        """
+        if not self.collected:
+            self.collected = True
+            return {self.type: {self.hash: self.get_data(**kwargs)}}
+        else:
+            return {}
+
+    def collect(self, **kwargs):
+        """Collect the data for all nodes in the subtree rooted at `self`.
+
+        The data is deduplicated by type and by hash.
+
+        Arguments:
+          kwargs: passed as-is to :func:`get_data`.
+
+        Returns:
+          A :class:`dict` with the following structure::
+
+            {
+              'typeA': {
+                node1.hash: node1.get_data(),
+                node2.hash: node2.get_data(),
+              },
+              'typeB': {
+                node3.hash: node3.get_data(),
+                ...
+              },
+              ...
+            }
+        """
+        ret = self.collect_node(**kwargs)
+        for child in self.values():
+            deep_update(ret, child.collect(**kwargs))
+
+        return ret
+
+    def reset_collect(self):
+        """Recursively unmark collected nodes in the subtree rooted at `self`.
+
+        This lets the caller use :func:`collect` again.
+        """
+        self.collected = False
+
+        for child in self.values():
+            child.reset_collect()
+
+
+class MerkleLeaf(MerkleNode):
+    """A leaf to a Merkle tree.
+
+    A Merkle leaf is simply a Merkle node with children disabled.
+    """
+    __slots__ = []
+
+    def __setitem__(self, name, child):
+        raise ValueError('%s is a leaf' % self.__class__.__name__)
+
+    def __getitem__(self, name):
+        raise ValueError('%s is a leaf' % self.__class__.__name__)
+
+    def __delitem__(self, name):
+        raise ValueError('%s is a leaf' % self.__class__.__name__)
+
+    def update(self, new_children):
+        """Children update operation. Disabled for leaves."""
+        raise ValueError('%s is a leaf' % self.__class__.__name__)
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
new file mode 100644
--- /dev/null
+++ b/swh/model/tests/test_merkle.py
@@ -0,0 +1,250 @@
+# Copyright (C) 2017 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import unittest
+
+from swh.model import merkle
+
+
+class TestedMerkleNode(merkle.MerkleNode):
+    type = 'tested_merkle_node_type'
+
+    def __init__(self, data):
+        super().__init__(data)
+        self.compute_hash_called = 0
+
+    def compute_hash(self):
+        self.compute_hash_called += 1
+        child_data = [
+            child + b'=' + self[child].hash
+            for child in sorted(self)
+        ]
+
+        return (
+            b'hash(' +
+            b', '.join([self.data['value']] + child_data) +
+            b')'
+        )
+
+
+class TestedMerkleLeaf(merkle.MerkleLeaf):
+    type = 'tested_merkle_leaf_type'
+
+    def __init__(self, data):
+        super().__init__(data)
+        self.compute_hash_called = 0
+
+    def compute_hash(self):
+        self.compute_hash_called += 1
+        return b'hash(' + self.data['value'] + b')'
+
+
+class TestDeepUpdate(unittest.TestCase):
+    def test_replace_scalar_with_mapping(self):
+        left = {'key1': 'scalar', 'key2': 'kept'}
+        ret = merkle.deep_update(left, {'key1': {'key3': 'value'}})
+        self.assertIs(ret, left)
+        self.assertEqual(left, {'key1': {'key3': 'value'}, 'key2': 'kept'})
+
+
+class TestMerkleLeaf(unittest.TestCase):
+    def setUp(self):
+        self.data = {'value': b'value'}
+        self.instance = TestedMerkleLeaf(self.data)
+
+    def test_hash(self):
+        self.assertEqual(self.instance.compute_hash_called, 0)
+        instance_hash = self.instance.hash
+        self.assertEqual(self.instance.compute_hash_called, 1)
+        instance_hash2 = self.instance.hash
+        self.assertEqual(self.instance.compute_hash_called, 1)
+        self.assertEqual(instance_hash, instance_hash2)
+
+    def test_data(self):
+        self.assertEqual(self.instance.get_data(), self.data)
+
+    def test_collect(self):
+        collected = self.instance.collect()
+        self.assertEqual(
+            collected, {
+                self.instance.type: {
+                    self.instance.hash: self.instance.get_data(),
+                },
+            },
+        )
+        collected2 = self.instance.collect()
+        self.assertEqual(collected2, {})
+        self.instance.reset_collect()
+        collected3 = self.instance.collect()
+        self.assertEqual(collected, collected3)
+
+    def test_leaf(self):
+        with self.assertRaisesRegex(ValueError, 'is a leaf'):
+            self.instance[b'key1'] = 'Test'
+
+        with self.assertRaisesRegex(ValueError, 'is a leaf'):
+            del self.instance[b'key1']
+
+        with self.assertRaisesRegex(ValueError, 'is a leaf'):
+            self.instance[b'key1']
+
+        with self.assertRaisesRegex(ValueError, 'is a leaf'):
+            self.instance.update(self.data)
+
+
+class TestMerkleNode(unittest.TestCase):
+    maxDiff = None
+
+    def setUp(self):
+        self.root = TestedMerkleNode({'value': b'root'})
+        self.nodes = {b'root': self.root}
+        for i in (b'a', b'b', b'c'):
+            value = b'root/' + i
+            node = TestedMerkleNode({
+                'value': value,
+            })
+            self.root[i] = node
+            self.nodes[value] = node
+            for j in (b'a', b'b', b'c'):
+                value2 = value + b'/' + j
+                node2 = TestedMerkleNode({
+                    'value': value2,
+                })
+                node[j] = node2
+                self.nodes[value2] = node2
+                for k in (b'a', b'b', b'c'):
+                    value3 = value2 + b'/' + k
+                    node3 = TestedMerkleNode({
+                        'value': value3,
+                    })
+                    node2[k] = node3
+                    self.nodes[value3] = node3
+
+    def test_hash(self):
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 0)
+
+        # Root hash will compute hash for all the nodes
+        hash = self.root.hash
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 1)
+            self.assertIn(node.data['value'], hash)
+
+        # Should use the cached value
+        hash2 = self.root.hash
+        self.assertEqual(hash, hash2)
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 1)
+
+        # Should still use the cached value
+        hash3 = self.root.update_hash(force=False)
+        self.assertEqual(hash, hash3)
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 1)
+
+        # Force update of the cached value for a deeply nested node
+        self.root[b'a'][b'b'].update_hash(force=True)
+        for key, node in self.nodes.items():
+            # update_hash rehashes all children
+            if key.startswith(b'root/a/b'):
+                self.assertEqual(node.compute_hash_called, 2)
+            else:
+                self.assertEqual(node.compute_hash_called, 1)
+
+        hash4 = self.root.hash
+        self.assertEqual(hash, hash4)
+        for key, node in self.nodes.items():
+            # update_hash also invalidates all parents
+            if key in (b'root', b'root/a') or key.startswith(b'root/a/b'):
+                self.assertEqual(node.compute_hash_called, 2)
+            else:
+                self.assertEqual(node.compute_hash_called, 1)
+
+    def test_collect(self):
+        collected = self.root.collect()
+        self.assertEqual(len(collected[self.root.type]), len(self.nodes))
+        for node in self.nodes.values():
+            self.assertTrue(node.collected)
+        collected2 = self.root.collect()
+        self.assertEqual(collected2, {})
+
+    def test_get(self):
+        for key in (b'a', b'b', b'c'):
+            self.assertEqual(self.root[key], self.nodes[b'root/' + key])
+
+        with self.assertRaisesRegex(KeyError, "b'nonexistent'"):
+            self.root[b'nonexistent']
+
+    def test_del(self):
+        hash_root = self.root.hash
+        hash_a = self.nodes[b'root/a'].hash
+        del self.root[b'a'][b'c']
+        hash_root2 = self.root.hash
+        hash_a2 = self.nodes[b'root/a'].hash
+
+        self.assertNotEqual(hash_root, hash_root2)
+        self.assertNotEqual(hash_a, hash_a2)
+
+        self.assertEqual(self.nodes[b'root/a/c'].parents, [])
+
+        with self.assertRaisesRegex(KeyError, "b'nonexistent'"):
+            del self.root[b'nonexistent']
+
+    def test_setitem_replace(self):
+        hash_root = self.root.hash
+        old_child = self.root[b'a']
+        new_child = TestedMerkleNode({'value': b'root/new_a'})
+
+        self.root[b'a'] = new_child
+
+        # The replaced child must not keep a dangling parent reference
+        self.assertEqual(old_child.parents, [])
+        self.assertEqual(new_child.parents, [self.root])
+        self.assertNotEqual(hash_root, self.root.hash)
+        self.assertIn(b'root/new_a', self.root.hash)
+
+    def test_update(self):
+        hash_root = self.root.hash
+        hash_b = self.root[b'b'].hash
+        new_children = {
+            b'c': TestedMerkleNode({'value': b'root/b/new_c'}),
+            b'd': TestedMerkleNode({'value': b'root/b/d'}),
+        }
+
+        # collect all nodes
+        self.root.collect()
+
+        self.root[b'b'].update(new_children)
+
+        # Ensure everyone got reparented
+        self.assertEqual(new_children[b'c'].parents, [self.root[b'b']])
+        self.assertEqual(new_children[b'd'].parents, [self.root[b'b']])
+        self.assertEqual(self.nodes[b'root/b/c'].parents, [])
+
+        hash_root2 = self.root.hash
+        self.assertNotEqual(hash_root, hash_root2)
+        self.assertIn(b'root/b/new_c', hash_root2)
+        self.assertIn(b'root/b/d', hash_root2)
+
+        hash_b2 = self.root[b'b'].hash
+        self.assertNotEqual(hash_b, hash_b2)
+
+        for key, node in self.nodes.items():
+            if key in (b'root', b'root/b'):
+                self.assertEqual(node.compute_hash_called, 2)
+            else:
+                self.assertEqual(node.compute_hash_called, 1)
+
+        # Ensure we collected root, root/b, and both new children
+        collected_after_update = self.root.collect()
+        self.assertCountEqual(
+            collected_after_update[TestedMerkleNode.type],
+            [self.nodes[b'root'].hash, self.nodes[b'root/b'].hash,
+             new_children[b'c'].hash, new_children[b'd'].hash],
+        )
+
+        # test that noop updates don't invalidate anything
+        self.root[b'a'][b'b'].update({})
+        self.assertEqual(self.root.collect(), {})