diff --git a/swh/model/from_disk.py b/swh/model/from_disk.py
new file mode 100644
--- /dev/null
+++ b/swh/model/from_disk.py
@@ -0,0 +1,311 @@
+# Copyright (C) 2017 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import enum
+import os
+import stat
+
+from . import hashutil
+from .merkle import MerkleLeaf, MerkleNode
+from .identifiers import (
+    directory_identifier,
+    identifier_to_bytes as id_to_bytes,
+    identifier_to_str as id_to_str,
+)
+
+
+class DentryPerms(enum.IntEnum):
+    """Admissible permissions for directory entries."""
+    content = 0o100644
+    """Content"""
+    executable_content = 0o100755
+    """Executable content (e.g. executable script)"""
+    symlink = 0o120000
+    """Symbolic link"""
+    directory = 0o040000
+    """Directory"""
+    revision = 0o160000
+    """Revision (e.g. submodule)"""
+
+
+def mode_to_perms(mode):
+    """Convert a file mode to a permission compatible with Software Heritage
+    directory entries.
+
+    Args:
+        mode (int): a file mode as returned by :func:`os.stat` in
+            :attr:`os.stat_result.st_mode`
+
+    Returns:
+        DentryPerms: one of the following values:
+
+        - :const:`DentryPerms.content`: plain file
+        - :const:`DentryPerms.executable_content`: executable file
+        - :const:`DentryPerms.symlink`: symbolic link
+        - :const:`DentryPerms.directory`: directory
+
+    """
+    if stat.S_ISLNK(mode):
+        return DentryPerms.symlink
+    if stat.S_ISDIR(mode):
+        return DentryPerms.directory
+    if mode & 0o111:
+        # the file is executable in some way
+        return DentryPerms.executable_content
+    return DentryPerms.content
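A quick sanity check of the mapping above, as a sketch; `mode_to_perms` only
inspects the file-type and executable bits of the mode:

    import stat

    from swh.model.from_disk import DentryPerms, mode_to_perms

    # a regular file with any executable bit set
    assert mode_to_perms(stat.S_IFREG | 0o755) == DentryPerms.executable_content
    # a character device without executable bits maps to plain content
    assert mode_to_perms(stat.S_IFCHR | 0o644) == DentryPerms.content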
+
+
+class Content(MerkleLeaf):
+    """Representation of a Software Heritage content as a node in a Merkle
+    tree.
+
+    The current data structure uses the `sha1_git` hash as a key.
+
+    """
+    __slots__ = []
+    type = 'content'
+
+    @classmethod
+    def from_bytes(cls, *, mode, data):
+        """Convert data (raw :class:`bytes`) to a Software Heritage content entry
+
+        Args:
+            mode (int): a file mode (passed to :func:`mode_to_perms`)
+            data (bytes): raw contents of the file
+        """
+        ret = hashutil.hash_data(data)
+        ret['length'] = len(data)
+        ret['perms'] = mode_to_perms(mode)
+        ret['data'] = data
+
+        return cls(ret)
+
+    @classmethod
+    def from_symlink(cls, *, path, mode):
+        """Convert a symbolic link to a Software Heritage content entry"""
+        return cls.from_bytes(mode=mode, data=os.readlink(path))
+
+    @classmethod
+    def from_file(cls, *, path, data=False):
+        """Compute the Software Heritage content entry corresponding to an
+        on-disk file.
+
+        The returned content entry contains keys useful for both:
+
+        - loading the content in the archive (hashes, `length`)
+        - using the content as a directory entry in a directory
+
+        Args:
+            path (bytes): path to the file for which we're computing the
+                content entry
+            data (bool): add the file data to the entry
+        """
+        file_stat = os.lstat(path)
+        mode = file_stat.st_mode
+
+        if stat.S_ISLNK(mode):
+            # Symbolic link: return a file whose contents are the link target
+            return cls.from_symlink(path=path, mode=mode)
+        elif not stat.S_ISREG(mode):
+            # not a regular file: return the empty file instead
+            return cls.from_bytes(mode=mode, data=b'')
+
+        length = file_stat.st_size
+
+        if not data:
+            ret = hashutil.hash_path(path)
+        else:
+            chunks = []
+
+            def append_chunk(x, chunks=chunks):
+                chunks.append(x)
+
+            with open(path, 'rb') as fobj:
+                ret = hashutil.hash_file(fobj, length=length,
+                                         chunk_cb=append_chunk)
+
+            ret['data'] = b''.join(chunks)
+
+        ret['perms'] = mode_to_perms(mode)
+        ret['length'] = length
+
+        obj = cls(ret)
+        return obj
+
+    def __repr__(self):
+        return 'Content(id=%s)' % id_to_str(self.hash)
+
+    def compute_hash(self):
+        return self.data['sha1_git']
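How the constructors above chain together, as a minimal sketch (the path is
hypothetical; `data=True` keeps the raw bytes in the entry):

    from swh.model.from_disk import Content

    content = Content.from_file(path=b'/tmp/example', data=True)

    # contents hash like git blobs: the node hash is the sha1_git
    assert content.hash == content.data['sha1_git']
    assert content.data['length'] == len(content.data['data'])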
+
+
+def accept_all_directories(dirname, entries):
+    """Default filter for :func:`Directory.from_disk` accepting all
+    directories
+
+    Args:
+        dirname (bytes): directory name
+        entries (list): directory entries
+    """
+    return True
+
+
+def ignore_empty_directories(dirname, entries):
+    """Filter for :func:`Directory.from_disk` ignoring empty directories
+
+    Args:
+        dirname (bytes): directory name
+        entries (list): directory entries
+    Returns:
+        True if the directory is not empty, False if the directory is empty
+    """
+    return bool(entries)
+
+
+def ignore_named_directories(names, *, case_sensitive=True):
+    """Filter for :func:`Directory.from_disk` to ignore directories named one
+    of names.
+
+    Args:
+        names (list of bytes): names to ignore
+        case_sensitive (bool): whether to do the filtering in a case
+            sensitive way
+    Returns:
+        a directory filter for :func:`Directory.from_disk`
+    """
+    if not case_sensitive:
+        names = [name.lower() for name in names]
+
+    def named_filter(dirname, entries,
+                     names=names, case_sensitive=case_sensitive):
+        if case_sensitive:
+            return dirname not in names
+        else:
+            return dirname.lower() not in names
+
+    return named_filter
+
+
+class Directory(MerkleNode):
+    __slots__ = ['__entries']
+    type = 'directory'
+
+    @classmethod
+    def from_disk(cls, *, path, data=False,
+                  dir_filter=accept_all_directories):
+        """Compute the Software Heritage objects for a given directory tree
+
+        Args:
+            path (bytes): the directory to traverse
+            data (bool): whether to add the data to the content objects
+            dir_filter (function): a filter to ignore some directories by
+                name or contents. Takes two arguments: dirname and entries,
+                and returns True if the directory should be added, False if
+                the directory should be ignored.
+        """
+
+        top_path = path
+        dirs = {}
+
+        for root, dentries, fentries in os.walk(top_path, topdown=False):
+            entries = {}
+            # Join fentries and dentries in the same processing, as symbolic
+            # links to directories appear in dentries...
+            for name in fentries + dentries:
+                path = os.path.join(root, name)
+                if not os.path.isdir(path) or os.path.islink(path):
+                    content = Content.from_file(path=path, data=data)
+                    entries[name] = content
+                else:
+                    if dir_filter(name, dirs[path].entries):
+                        entries[name] = dirs[path]
+
+            dirs[root] = cls({'name': os.path.basename(root)})
+            dirs[root].update(entries)
+
+        return dirs[top_path]
+
+    def __init__(self, data=None):
+        super().__init__(data=data)
+        self.__entries = None
+
+    def invalidate_hash(self):
+        self.__entries = None
+        super().invalidate_hash()
+
+    @staticmethod
+    def child_to_directory_entry(name, child):
+        if isinstance(child, Directory):
+            return {
+                'type': 'dir',
+                'perms': DentryPerms.directory,
+                'target': child.hash,
+                'name': name,
+            }
+        elif isinstance(child, Content):
+            return {
+                'type': 'file',
+                'perms': child.data['perms'],
+                'target': child.hash,
+                'name': name,
+            }
+        else:
+            raise ValueError('unknown child')
+
+    def get_data(self, **kwargs):
+        return {
+            'id': self.hash,
+            'entries': self.entries,
+        }
+
+    @property
+    def entries(self):
+        if self.__entries is None:
+            self.__entries = [
+                self.child_to_directory_entry(name, child)
+                for name, child in self.items()
+            ]
+
+        return self.__entries
+
+    def compute_hash(self):
+        return id_to_bytes(directory_identifier({'entries': self.entries}))
+
+    def __getitem__(self, key):
+        if not isinstance(key, bytes):
+            raise ValueError('Can only get a bytes from Directory')
+        if b'/' not in key:
+            return super().__getitem__(key)
+        else:
+            key1, key2 = key.split(b'/', 1)
+            return super().__getitem__(key1)[key2]
+
+    def __setitem__(self, key, value):
+        if not isinstance(key, bytes):
+            raise ValueError('Can only set a bytes Directory entry')
+        if not isinstance(value, (Content, Directory)):
+            raise ValueError('Can only set a Directory entry to a Content or '
+                             'Directory')
+
+        if b'/' not in key:
+            super().__setitem__(key, value)
+        else:
+            key1, key2 = key.rsplit(b'/', 1)
+            self[key1][key2] = value
+
+    def __delitem__(self, key):
+        if not isinstance(key, bytes):
+            raise ValueError('Can only delete a bytes Directory entry')
+
+        if b'/' not in key:
+            super().__delitem__(key)
+        else:
+            key1, key2 = key.rsplit(b'/', 1)
+            del self[key1][key2]
+
+    def __repr__(self):
+        return 'Directory(%s, id=%s)' % (
+            self.data['name'],
+            id_to_str(self.hash) if self.hash else '?',
+        )
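Putting the module together, a minimal end-to-end sketch (the path and the
ignored directory names are illustrative):

    from swh.model.from_disk import Directory, ignore_named_directories

    directory = Directory.from_disk(
        path=b'/srv/src/myproject',
        data=True,  # keep file contents alongside the hashes
        dir_filter=ignore_named_directories([b'.git', b'.hg']),
    )

    # the root hashes like a git tree; collect() gathers every object
    # in the subtree, deduplicated by type and by hash
    print(directory.hash.hex())
    objects = directory.collect()
    print(sorted(objects))  # ['content', 'directory']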
diff --git a/swh/model/merkle.py b/swh/model/merkle.py
new file mode 100644
--- /dev/null
+++ b/swh/model/merkle.py
@@ -0,0 +1,272 @@
+# Copyright (C) 2017 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""Merkle tree data structure"""
+
+import abc
+import collections
+
+
+def deep_update(left, right):
+    """Recursively update the left mapping with deeply nested values from
+    the right mapping.
+
+    This function is useful to merge the results of several calls to
+    :func:`MerkleNode.collect`.
+
+    Arguments:
+        left: a mapping (modified by the update operation)
+        right: a mapping
+
+    Returns:
+        the left mapping, updated with nested values from the right mapping
+
+    Example:
+        >>> a = {
+        ...     'key1': {
+        ...         'key2': {
+        ...             'key3': 'value1/2/3',
+        ...         },
+        ...     },
+        ... }
+        >>> deep_update(a, {
+        ...     'key1': {
+        ...         'key2': {
+        ...             'key4': 'value1/2/4',
+        ...         },
+        ...     },
+        ... })
+        {'key1': {'key2': {'key3': 'value1/2/3', 'key4': 'value1/2/4'}}}
+        >>> deep_update(a, {
+        ...     'key1': {
+        ...         'key2': {
+        ...             'key3': 'newvalue1/2/3',
+        ...         },
+        ...     },
+        ... })
+        {'key1': {'key2': {'key3': 'newvalue1/2/3', 'key4': 'value1/2/4'}}}
+
+    """
+    for key, rvalue in right.items():
+        if isinstance(rvalue, collections.abc.Mapping):
+            new_lvalue = deep_update(left.get(key, {}), rvalue)
+            left[key] = new_lvalue
+        else:
+            left[key] = rvalue
+    return left
+
+
+class MerkleNode(dict, metaclass=abc.ABCMeta):
+    """Representation of a node in a Merkle Tree.
+
+    A (generalized) `Merkle Tree`_ is a tree in which every node is labeled
+    with a hash of its own data and the hash of its children.
+
+    .. _Merkle Tree: https://en.wikipedia.org/wiki/Merkle_tree
+
+    In pseudocode::
+
+        node.hash = hash(node.data
+                         + sum(child.hash for child in node.children))
+
+    This class efficiently implements the Merkle Tree data structure on top
+    of a Python :class:`dict`, minimizing hash computations and new data
+    collections when updating nodes.
+
+    Node data is stored in the :attr:`data` attribute, while (named) children
+    are stored as items of the underlying dictionary.
+
+    Addition, update and removal of objects are instrumented to automatically
+    invalidate the hashes of the current node as well as its registered
+    parents; they also reset the collection status of the objects so the
+    updated objects can be collected again.
+
+    The collection of updated data from the tree is implemented through the
+    :func:`collect` function and associated helpers.
+
+    Attributes:
+        data (dict): data associated to the current node
+        parents (list): known parents of the current node
+        collected (bool): whether the current node has been collected
+
+    """
+    __slots__ = ['parents', 'data', '__hash', 'collected']
+
+    type = None
+    """Type of the current node (used as a classifier for :func:`collect`)"""
+
+    def __init__(self, data=None):
+        super().__init__()
+        self.parents = []
+        self.data = data
+        self.__hash = None
+        self.collected = False
+
+    def invalidate_hash(self):
+        """Invalidate the cached hash of the current node."""
+        if not self.__hash:
+            return
+
+        self.__hash = None
+        self.collected = False
+        for parent in self.parents:
+            parent.invalidate_hash()
+
+    def update_hash(self, *, force=False):
+        """Recursively compute the hash of the current node.
+
+        Args:
+            force (bool): invalidate the cache and force the computation for
+                this node and all children.
+        """
+        if self.__hash and not force:
+            return self.__hash
+
+        if force:
+            self.invalidate_hash()
+
+        for child in self.values():
+            child.update_hash(force=force)
+
+        self.__hash = self.compute_hash()
+        return self.__hash
+
+    @property
+    def hash(self):
+        """The hash of the current node, as calculated by
+        :func:`compute_hash`.
+        """
+        return self.update_hash()
+
+    @abc.abstractmethod
+    def compute_hash(self):
+        """Compute the hash of the current node.
+
+        The hash should depend on the data of the node, as well as on hashes
+        of the children nodes.
+ """ + raise NotImplementedError('Must implement compute_hash method') + + def __setitem__(self, name, new_child): + """Add a child, invalidating the current hash""" + self.invalidate_hash() + + super().__setitem__(name, new_child) + + new_child.parents.append(self) + + def __delitem__(self, name): + """Remove a child, invalidating the current hash""" + if name in self: + self.invalidate_hash() + self[name].parents.remove(self) + super().__delitem__(name) + else: + raise KeyError(name) + + def update(self, new_children): + """Add several named children from a dictionary""" + if not new_children: + return + + self.invalidate_hash() + + for name, new_child in new_children.items(): + new_child.parents.append(self) + if name in self: + self[name].parents.remove(self) + + super().update(new_children) + + def get_data(self, **kwargs): + """Retrieve and format the collected data for the current node, for use by + :func:`collect`. + + Can be overridden, for instance when you want the collected data to + contain information about the child nodes. + + Arguments: + kwargs: allow subclasses to alter behaviour depending on how + :func:`collect` is called. + + Returns: + data formatted for :func:`collect` + """ + return self.data + + def collect_node(self, **kwargs): + """Collect the data for the current node, for use by :func:`collect`. + + Arguments: + kwargs: passed as-is to :func:`get_data`. + + Returns: + A :class:`dict` compatible with :func:`collect`. + """ + if not self.collected: + self.collected = True + return {self.type: {self.hash: self.get_data(**kwargs)}} + else: + return {} + + def collect(self, **kwargs): + """Collect the data for all nodes in the subtree rooted at `self`. + + The data is deduplicated by type and by hash. + + Arguments: + kwargs: passed as-is to :func:`get_data`. + + Returns: + A :class:`dict` with the following structure:: + + { + 'typeA': { + node1.hash: node1.get_data(), + node2.hash: node2.get_data(), + }, + 'typeB': { + node3.hash: node3.get_data(), + ... + }, + ... + } + """ + ret = self.collect_node(**kwargs) + for child in self.values(): + deep_update(ret, child.collect(**kwargs)) + + return ret + + def reset_collect(self): + """Recursively unmark collected nodes in the subtree rooted at `self`. + + This lets the caller use :func:`collect` again. + """ + self.collected = False + + for child in self.values(): + child.reset_collect() + + +class MerkleLeaf(MerkleNode): + """A leaf to a Merkle tree. + + A Merkle leaf is simply a Merkle node with children disabled. + """ + __slots__ = [] + + def __setitem__(self, name, child): + raise ValueError('%s is a leaf' % self.__class__.__name__) + + def __getitem__(self, name): + raise ValueError('%s is a leaf' % self.__class__.__name__) + + def __delitem__(self, name): + raise ValueError('%s is a leaf' % self.__class__.__name__) + + def update(self, new_children): + """Children update operation. 
diff --git a/swh/model/tests/test_from_disk.py b/swh/model/tests/test_from_disk.py
new file mode 100644
--- /dev/null
+++ b/swh/model/tests/test_from_disk.py
@@ -0,0 +1,268 @@
+# Copyright (C) 2017 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import os
+import tempfile
+import unittest
+
+from swh.model import from_disk
+from swh.model.from_disk import Content, Directory
+from swh.model.hashutil import hash_to_bytes
+
+
+class ModeToPerms(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+
+        perms = from_disk.DentryPerms
+
+        # Generate a full permissions map
+        self.perms_map = {}
+
+        # Symlinks
+        for i in range(0o120000, 0o127777 + 1):
+            self.perms_map[i] = perms.symlink
+
+        # Directories
+        for i in range(0o040000, 0o047777 + 1):
+            self.perms_map[i] = perms.directory
+
+        # Other file types: socket, regular file, block device, character
+        # device, fifo all map to regular files
+        for ft in [0o140000, 0o100000, 0o060000, 0o020000, 0o010000]:
+            for i in range(ft, ft + 0o7777 + 1):
+                if i & 0o111:
+                    # executable bits are set
+                    self.perms_map[i] = perms.executable_content
+                else:
+                    self.perms_map[i] = perms.content
+
+    def test_exhaustive_mode_to_perms(self):
+        for fmode, perm in self.perms_map.items():
+            self.assertEqual(perm, from_disk.mode_to_perms(fmode))
+
+
+class DataMixin:
+    maxDiff = None
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory(
+            prefix=b'swh.model.from_disk'
+        )
+        self.contents = {
+            b'file': {
+                'data': b'42\n',
+                'sha1': hash_to_bytes(
+                    '34973274ccef6ab4dfaaf86599792fa9c3fe4689'
+                ),
+                'sha256': hash_to_bytes(
+                    '084c799cd551dd1d8d5c5f9a5d593b2e'
+                    '931f5e36122ee5c793c1d08a19839cc0'
+                ),
+                'sha1_git': hash_to_bytes(
+                    'd81cc0710eb6cf9efd5b920a8453e1e07157b6cd'
+                ),
+                'blake2s256': hash_to_bytes(
+                    'd5fe1939576527e42cfd76a9455a2432'
+                    'fe7f56669564577dd93c4280e76d661d'
+                ),
+                'length': 3,
+                'mode': 0o100644,
+            },
+        }
+
+        self.symlinks = {
+            b'symlink': {
+                'data': b'target',
+                'blake2s256': hash_to_bytes(
+                    '595d221b30fdd8e10e2fdf18376e688e'
+                    '9f18d56fd9b6d1eb6a822f8c146c6da6'
+                ),
+                'sha1': hash_to_bytes(
+                    '0e8a3ad980ec179856012b7eecf4327e99cd44cd'
+                ),
+                'sha1_git': hash_to_bytes(
+                    '1de565933b05f74c75ff9a6520af5f9f8a5a2f1d'
+                ),
+                'sha256': hash_to_bytes(
+                    '34a04005bcaf206eec990bd9637d9fdb'
+                    '6725e0a0c0d4aebf003f17f4c956eb5c'
+                ),
+                'length': 6,
+            }
+        }
+
+        self.specials = {
+            b'fifo': os.mkfifo,
+            b'devnull': lambda path: os.mknod(path, device=os.makedev(1, 3)),
+        }
+
+        self.empty_content = {
+            'data': b'',
+            'length': 0,
+            'blake2s256': hash_to_bytes(
+                '69217a3079908094e11121d042354a7c'
+                '1f55b6482ca1a51e1b250dfd1ed0eef9'
+            ),
+            'sha1': hash_to_bytes(
+                'da39a3ee5e6b4b0d3255bfef95601890afd80709'
+            ),
+            'sha1_git': hash_to_bytes(
+                'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391'
+            ),
+            'sha256': hash_to_bytes(
+                'e3b0c44298fc1c149afbf4c8996fb924'
+                '27ae41e4649b934ca495991b7852b855'
+            ),
+        }
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def make_contents(self, directory):
+        for filename, content in self.contents.items():
+            path = os.path.join(directory, filename)
+            with open(path, 'wb') as f:
+                f.write(content['data'])
+            os.chmod(path, content['mode'])
+
+    def make_symlinks(self, directory):
+        for filename, symlink in self.symlinks.items():
+            path = os.path.join(directory, filename)
+            os.symlink(symlink['data'], path)
+
+    def make_specials(self, directory):
+        for filename, fn in self.specials.items():
+            path = os.path.join(directory, filename)
+            fn(path)
+
+
+class TestContent(DataMixin, unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+
+    def test_data_to_content(self):
+        for filename, content in self.contents.items():
+            mode = content.pop('mode')
+            conv_content = Content.from_bytes(mode=mode, data=content['data'])
+            structure = conv_content.data.copy()
+            self.assertEqual(structure.pop('perms'),
+                             from_disk.mode_to_perms(mode))
+            self.assertEqual(structure, content)
+
+
+class SymlinkToContent(DataMixin, unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.make_symlinks(self.tmpdir.name)
+
+    def test_symlink_to_content(self):
+        for filename, symlink in self.symlinks.items():
+            path = os.path.join(self.tmpdir.name, filename)
+            perms = 0o120000
+            conv_content = Content.from_symlink(path=path, mode=perms)
+            conv_content = conv_content.data.copy()
+            self.assertEqual(conv_content.pop('perms'), perms)
+            self.assertEqual(conv_content, symlink)
+
+
+class FileToContent(DataMixin, unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.make_contents(self.tmpdir.name)
+        self.make_symlinks(self.tmpdir.name)
+        self.make_specials(self.tmpdir.name)
+
+    def test_file_to_content(self):
+        for data in [False, True]:
+            for filename, symlink in self.symlinks.items():
+                path = os.path.join(self.tmpdir.name, filename)
+                perms = 0o120000
+                conv_content = Content.from_file(path=path, data=data)
+                conv_content = conv_content.data.copy()
+                self.assertEqual(conv_content.pop('perms'), perms)
+                if not data:
+                    conv_content['data'] = symlink['data']
+                self.assertEqual(conv_content, symlink)
+
+            for filename, content in self.contents.items():
+                content = content.copy()
+                path = os.path.join(self.tmpdir.name, filename)
+                perms = 0o100644
+                if content.pop('mode') & 0o111:
+                    perms = 0o100755
+                conv_content = Content.from_file(path=path, data=data)
+                conv_content = conv_content.data.copy()
+                self.assertEqual(conv_content.pop('perms'), perms)
+                if not data:
+                    conv_content['data'] = content['data']
+                self.assertEqual(conv_content, content)
+
+            for filename in self.specials:
+                path = os.path.join(self.tmpdir.name, filename)
+                perms = 0o100644
+                conv_content = Content.from_file(path=path, data=data)
+                conv_content = conv_content.data.copy()
+                self.assertEqual(conv_content.pop('perms'), perms)
+                if not data:
+                    conv_content['data'] = b''
+                self.assertEqual(conv_content, self.empty_content)
+
+
+class DirectoryToObjects(DataMixin, unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        contents = os.path.join(self.tmpdir.name, b'contents')
+        os.mkdir(contents)
+        self.make_contents(contents)
+        symlinks = os.path.join(self.tmpdir.name, b'symlinks')
+        os.mkdir(symlinks)
+        self.make_symlinks(symlinks)
+        specials = os.path.join(self.tmpdir.name, b'specials')
+        os.mkdir(specials)
+        self.make_specials(specials)
+        empties = os.path.join(self.tmpdir.name, b'empty1', b'empty2')
+        os.makedirs(empties)
+
+    def test_directory_to_objects(self):
+        objs = Directory.from_disk(path=self.tmpdir.name).collect()
+
+        self.assertIn('content', objs)
+        self.assertIn('directory', objs)
+
+        # 6 directories: the root, contents, symlinks, specials, empty1
+        # and empty1/empty2
+        self.assertEqual(len(objs['directory']), 6)
+        # the two special files deduplicate to a single (empty) content
+        self.assertEqual(len(objs['content']),
+                         len(self.contents)
+                         + len(self.symlinks)
+                         + 1)
+
+    def test_directory_to_objects_ignore_empty(self):
+        objs = Directory.from_disk(
+            path=self.tmpdir.name,
+            dir_filter=from_disk.ignore_empty_directories
+        ).collect()
+
+        self.assertIn('content', objs)
+        self.assertIn('directory', objs)
+
+        # 4 directories: the root, contents, symlinks and specials;
+        # the empty directories are filtered out
+        self.assertEqual(len(objs['directory']), 4)
+        self.assertEqual(len(objs['content']),
+                         len(self.contents)
+                         + len(self.symlinks)
+                         + 1)
+
+    def test_directory_to_objects_ignore_name(self):
+        objs = Directory.from_disk(
+            path=self.tmpdir.name,
+            dir_filter=from_disk.ignore_named_directories([b'symlinks'])
+        ).collect()
+
+        self.assertIn('content', objs)
+        self.assertIn('directory', objs)
+
+        self.assertEqual(len(objs['directory']), 5)
+        self.assertEqual(len(objs['content']),
+                         len(self.contents)
+                         + 1)
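The two dir_filter tests above also suggest how filters compose; a sketch of
a hypothetical filter combining both shipped behaviours:

    from swh.model import from_disk

    ignore_vcs = from_disk.ignore_named_directories([b'.git', b'.svn'])

    def ignore_vcs_and_empty(dirname, entries):
        # keep a directory only if both filters accept it
        return (ignore_vcs(dirname, entries)
                and from_disk.ignore_empty_directories(dirname, entries))

    # usage: Directory.from_disk(path=..., dir_filter=ignore_vcs_and_empty)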
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
new file mode 100644
--- /dev/null
+++ b/swh/model/tests/test_merkle.py
@@ -0,0 +1,229 @@
+# Copyright (C) 2017 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import unittest
+
+from swh.model import merkle
+
+
+class TestedMerkleNode(merkle.MerkleNode):
+    type = 'tested_merkle_node_type'
+
+    def __init__(self, data):
+        super().__init__(data)
+        self.compute_hash_called = 0
+
+    def compute_hash(self):
+        self.compute_hash_called += 1
+        child_data = [
+            child + b'=' + self[child].hash
+            for child in sorted(self)
+        ]
+
+        return (
+            b'hash('
+            + b', '.join([self.data['value']] + child_data)
+            + b')'
+        )
+
+
+class TestedMerkleLeaf(merkle.MerkleLeaf):
+    type = 'tested_merkle_leaf_type'
+
+    def __init__(self, data):
+        super().__init__(data)
+        self.compute_hash_called = 0
+
+    def compute_hash(self):
+        self.compute_hash_called += 1
+        return b'hash(' + self.data['value'] + b')'
+
+
+class TestMerkleLeaf(unittest.TestCase):
+    def setUp(self):
+        self.data = {'value': b'value'}
+        self.instance = TestedMerkleLeaf(self.data)
+
+    def test_hash(self):
+        self.assertEqual(self.instance.compute_hash_called, 0)
+        instance_hash = self.instance.hash
+        self.assertEqual(self.instance.compute_hash_called, 1)
+        instance_hash2 = self.instance.hash
+        self.assertEqual(self.instance.compute_hash_called, 1)
+        self.assertEqual(instance_hash, instance_hash2)
+
+    def test_data(self):
+        self.assertEqual(self.instance.get_data(), self.data)
+
+    def test_collect(self):
+        collected = self.instance.collect()
+        self.assertEqual(
+            collected, {
+                self.instance.type: {
+                    self.instance.hash: self.instance.get_data(),
+                },
+            },
+        )
+        collected2 = self.instance.collect()
+        self.assertEqual(collected2, {})
+        self.instance.reset_collect()
+        collected3 = self.instance.collect()
+        self.assertEqual(collected, collected3)
+
+    def test_leaf(self):
+        with self.assertRaisesRegex(ValueError, 'is a leaf'):
+            self.instance[b'key1'] = 'Test'
+
+        with self.assertRaisesRegex(ValueError, 'is a leaf'):
+            del self.instance[b'key1']
+
+        with self.assertRaisesRegex(ValueError, 'is a leaf'):
+            self.instance[b'key1']
+
+        with self.assertRaisesRegex(ValueError, 'is a leaf'):
+            self.instance.update(self.data)
+
+
+class TestMerkleNode(unittest.TestCase):
+    maxDiff = None
+
+    def setUp(self):
+        self.root = TestedMerkleNode({'value': b'root'})
+        self.nodes = {b'root': self.root}
+        for i in (b'a', b'b', b'c'):
+            value = b'root/' + i
+            node = TestedMerkleNode({
+                'value': value,
+            })
+            self.root[i] = node
+            self.nodes[value] = node
+            for j in (b'a', b'b', b'c'):
+                value2 = value + b'/' + j
+                node2 = TestedMerkleNode({
+                    'value': value2,
+                })
+                node[j] = node2
+                self.nodes[value2] = node2
+                for k in (b'a', b'b', b'c'):
+                    value3 = value2 + b'/' + k
+                    node3 = TestedMerkleNode({
+                        'value': value3,
+                    })
+                    node2[k] = node3
+                    self.nodes[value3] = node3
+
+    def test_hash(self):
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 0)
+
+        # Root hash will compute hash for all the nodes
+        hash = self.root.hash
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 1)
+            self.assertIn(node.data['value'], hash)
+
+        # Should use the cached value
+        hash2 = self.root.hash
+        self.assertEqual(hash, hash2)
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 1)
+
+        # Should still use the cached value
+        hash3 = self.root.update_hash(force=False)
+        self.assertEqual(hash, hash3)
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 1)
+
+        # Force update of the cached value for a deeply nested node
+        self.root[b'a'][b'b'].update_hash(force=True)
+        for key, node in self.nodes.items():
+            # update_hash rehashes all children
+            if key.startswith(b'root/a/b'):
+                self.assertEqual(node.compute_hash_called, 2)
+            else:
+                self.assertEqual(node.compute_hash_called, 1)
+
+        hash4 = self.root.hash
+        self.assertEqual(hash, hash4)
+        for key, node in self.nodes.items():
+            # update_hash also invalidates all parents
+            if key in (b'root', b'root/a') or key.startswith(b'root/a/b'):
+                self.assertEqual(node.compute_hash_called, 2)
+            else:
+                self.assertEqual(node.compute_hash_called, 1)
+
+    def test_collect(self):
+        collected = self.root.collect()
+        self.assertEqual(len(collected[self.root.type]), len(self.nodes))
+        for node in self.nodes.values():
+            self.assertTrue(node.collected)
+        collected2 = self.root.collect()
+        self.assertEqual(collected2, {})
+
+    def test_get(self):
+        for key in (b'a', b'b', b'c'):
+            self.assertEqual(self.root[key], self.nodes[b'root/' + key])
+
+        with self.assertRaisesRegex(KeyError, "b'nonexistent'"):
+            self.root[b'nonexistent']
+
+    def test_del(self):
+        hash_root = self.root.hash
+        hash_a = self.nodes[b'root/a'].hash
+        del self.root[b'a'][b'c']
+        hash_root2 = self.root.hash
+        hash_a2 = self.nodes[b'root/a'].hash
+
+        self.assertNotEqual(hash_root, hash_root2)
+        self.assertNotEqual(hash_a, hash_a2)
+
+        self.assertEqual(self.nodes[b'root/a/c'].parents, [])
+
+        with self.assertRaisesRegex(KeyError, "b'nonexistent'"):
+            del self.root[b'nonexistent']
+
+    def test_update(self):
+        hash_root = self.root.hash
+        hash_b = self.root[b'b'].hash
+        new_children = {
+            b'c': TestedMerkleNode({'value': b'root/b/new_c'}),
+            b'd': TestedMerkleNode({'value': b'root/b/d'}),
+        }
+
+        # collect all nodes
+        self.root.collect()
+
+        self.root[b'b'].update(new_children)
+
+        # Ensure everyone got reparented
+        self.assertEqual(new_children[b'c'].parents, [self.root[b'b']])
+        self.assertEqual(new_children[b'd'].parents, [self.root[b'b']])
+        self.assertEqual(self.nodes[b'root/b/c'].parents, [])
+
+        hash_root2 = self.root.hash
+        self.assertNotEqual(hash_root, hash_root2)
+        self.assertIn(b'root/b/new_c', hash_root2)
+        self.assertIn(b'root/b/d', hash_root2)
+
+        hash_b2 = self.root[b'b'].hash
+        self.assertNotEqual(hash_b, hash_b2)
+
+        for key, node in self.nodes.items():
+            if key in (b'root', b'root/b'):
+                self.assertEqual(node.compute_hash_called, 2)
+            else:
+                self.assertEqual(node.compute_hash_called, 1)
+
+        # Ensure we collected root, root/b, and both new children
+        collected_after_update = self.root.collect()
+        self.assertCountEqual(
+            collected_after_update[TestedMerkleNode.type],
+            [self.nodes[b'root'].hash, self.nodes[b'root/b'].hash,
+             new_children[b'c'].hash, new_children[b'd'].hash],
+        )
+
+        # Test that a no-op update does not invalidate anything
+        self.root[b'a'][b'b'].update({})
+        self.assertEqual(self.root.collect(), {})
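Finally, a sketch of how deep_update ties the two new modules together when
collecting several trees (the checkout paths are hypothetical):

    from swh.model.from_disk import Directory
    from swh.model.merkle import deep_update

    objects = {}
    for path in [b'/srv/checkout-1', b'/srv/checkout-2']:
        tree = Directory.from_disk(path=path)
        # merge the per-type {hash: data} mappings, deduplicating
        # objects shared between the two trees
        deep_update(objects, tree.collect())

    # objects['content'] and objects['directory'] now cover both trees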