diff --git a/swh/model/from_disk.py b/swh/model/from_disk.py --- a/swh/model/from_disk.py +++ b/swh/model/from_disk.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017-2020 The Software Heritage developers +# Copyright (C) 2017-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information diff --git a/swh/model/merkle.py b/swh/model/merkle.py --- a/swh/model/merkle.py +++ b/swh/model/merkle.py @@ -1,76 +1,14 @@ -# Copyright (C) 2017-2020 The Software Heritage developers +# Copyright (C) 2017-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """Merkle tree data structure""" -import abc -from collections.abc import Mapping -from typing import Dict, Iterator, List, Set - - -def deep_update(left, right): - """Recursively update the left mapping with deeply nested values from the right - mapping. - - This function is useful to merge the results of several calls to - :func:`MerkleNode.collect`. - - Arguments: - left: a mapping (modified by the update operation) - right: a mapping - - Returns: - the left mapping, updated with nested values from the right mapping - - Example: - >>> a = { - ... 'key1': { - ... 'key2': { - ... 'key3': 'value1/2/3', - ... }, - ... }, - ... } - >>> deep_update(a, { - ... 'key1': { - ... 'key2': { - ... 'key4': 'value1/2/4', - ... }, - ... }, - ... }) == { - ... 'key1': { - ... 'key2': { - ... 'key3': 'value1/2/3', - ... 'key4': 'value1/2/4', - ... }, - ... }, - ... } - True - >>> deep_update(a, { - ... 'key1': { - ... 'key2': { - ... 'key3': 'newvalue1/2/3', - ... }, - ... }, - ... }) == { - ... 'key1': { - ... 'key2': { - ... 'key3': 'newvalue1/2/3', - ... 'key4': 'value1/2/4', - ... }, - ... }, - ... 
} - True +from __future__ import annotations - """ - for key, rvalue in right.items(): - if isinstance(rvalue, Mapping): - new_lvalue = deep_update(left.get(key, {}), rvalue) - left[key] = new_lvalue - else: - left[key] = rvalue - return left +import abc +from typing import Any, Dict, Iterator, List, Set class MerkleNode(dict, metaclass=abc.ABCMeta): @@ -141,7 +79,7 @@ for parent in self.parents: parent.invalidate_hash() - def update_hash(self, *, force=False): + def update_hash(self, *, force=False) -> Any: """Recursively compute the hash of the current node. Args: @@ -161,14 +99,17 @@ return self.__hash @property - def hash(self): + def hash(self) -> Any: """The hash of the current node, as calculated by :func:`compute_hash`. """ return self.update_hash() + def __hash__(self): + return hash(self.hash) + @abc.abstractmethod - def compute_hash(self): + def compute_hash(self) -> Any: """Compute the hash of the current node. The hash should depend on the data of the node, as well as on hashes @@ -223,47 +164,24 @@ """ return self.data - def collect_node(self, **kwargs): - """Collect the data for the current node, for use by :func:`collect`. - - Arguments: - kwargs: passed as-is to :func:`get_data`. - - Returns: - A :class:`dict` compatible with :func:`collect`. - """ + def collect_node(self) -> Set[MerkleNode]: + """Collect the current node if not yet collected, for use by :func:`collect`.""" if not self.collected: self.collected = True - return {self.object_type: {self.hash: self.get_data(**kwargs)}} + return {self} else: - return {} + return set() - def collect(self, **kwargs): - """Collect the data for all nodes in the subtree rooted at `self`. - - The data is deduplicated by type and by hash. - - Arguments: - kwargs: passed as-is to :func:`get_data`. + def collect(self) -> Set[MerkleNode]: + """Collect the added and modified nodes in the subtree rooted at `self` + since the last collect operation. 
Returns: - A :class:`dict` with the following structure:: - - { - 'typeA': { - node1.hash: node1.get_data(), - node2.hash: node2.get_data(), - }, - 'typeB': { - node3.hash: node3.get_data(), - ... - }, - ... - } + A :class:`set` of collected nodes """ - ret = self.collect_node(**kwargs) + ret = self.collect_node() for child in self.values(): - deep_update(ret, child.collect(**kwargs)) + ret.update(child.collect()) return ret @@ -277,14 +195,14 @@ for child in self.values(): child.reset_collect() - def iter_tree(self, dedup=True) -> Iterator["MerkleNode"]: + def iter_tree(self, dedup=True) -> Iterator[MerkleNode]: """Yields all children nodes, recursively. Common nodes are deduplicated by default (deduplication can be turned off setting the given argument 'dedup' to False). """ yield from self._iter_tree(set(), dedup) - def _iter_tree(self, seen: Set[bytes], dedup) -> Iterator["MerkleNode"]: + def _iter_tree(self, seen: Set[bytes], dedup) -> Iterator[MerkleNode]: if self.hash not in seen: if dedup: seen.add(self.hash) diff --git a/swh/model/tests/test_from_disk.py b/swh/model/tests/test_from_disk.py --- a/swh/model/tests/test_from_disk.py +++ b/swh/model/tests/test_from_disk.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017-2020 The Software Heritage developers +# Copyright (C) 2017-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -715,6 +715,21 @@ empties = os.path.join(self.tmpdir_name, b"empty1", b"empty2") os.makedirs(empties) + def check_collect( + self, directory, expected_directory_count, expected_content_count + ): + objs = directory.collect() + contents = [] + directories = [] + for obj in objs: + if isinstance(obj, Content): + contents.append(obj) + elif isinstance(obj, Directory): + directories.append(obj) + + self.assertEqual(len(directories), expected_directory_count) + 
self.assertEqual(len(contents), expected_content_count) + def test_directory_to_objects(self): directory = Directory.from_disk(path=self.tmpdir_name) @@ -743,13 +758,10 @@ with self.assertRaisesRegex(KeyError, "b'nonexistentdir'"): directory[b"nonexistentdir/file"] - objs = directory.collect() - - self.assertCountEqual(["content", "directory"], objs) - - self.assertEqual(len(objs["directory"]), 6) - self.assertEqual( - len(objs["content"]), len(self.contents) + len(self.symlinks) + 1 + self.check_collect( + directory, + expected_directory_count=6, + expected_content_count=len(self.contents) + len(self.symlinks) + 1, ) def test_directory_to_objects_ignore_empty(self): @@ -775,13 +787,10 @@ with self.assertRaisesRegex(KeyError, "b'empty1'"): directory[b"empty1/empty2"] - objs = directory.collect() - - self.assertCountEqual(["content", "directory"], objs) - - self.assertEqual(len(objs["directory"]), 4) - self.assertEqual( - len(objs["content"]), len(self.contents) + len(self.symlinks) + 1 + self.check_collect( + directory, + expected_directory_count=4, + expected_content_count=len(self.contents) + len(self.symlinks) + 1, ) def test_directory_to_objects_ignore_name(self): @@ -806,12 +815,11 @@ with self.assertRaisesRegex(KeyError, "b'symlinks'"): directory[b"symlinks"] - objs = directory.collect() - - self.assertCountEqual(["content", "directory"], objs) - - self.assertEqual(len(objs["directory"]), 5) - self.assertEqual(len(objs["content"]), len(self.contents) + 1) + self.check_collect( + directory, + expected_directory_count=5, + expected_content_count=len(self.contents) + 1, + ) def test_directory_to_objects_ignore_name_case(self): directory = Directory.from_disk( @@ -837,12 +845,11 @@ with self.assertRaisesRegex(KeyError, "b'symlinks'"): directory[b"symlinks"] - objs = directory.collect() - - self.assertCountEqual(["content", "directory"], objs) - - self.assertEqual(len(objs["directory"]), 5) - self.assertEqual(len(objs["content"]), len(self.contents) + 1) + 
self.check_collect( + directory, + expected_directory_count=5, + expected_content_count=len(self.contents) + 1, + ) def test_directory_entry_order(self): with tempfile.TemporaryDirectory() as dirname: diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py --- a/swh/model/tests/test_merkle.py +++ b/swh/model/tests/test_merkle.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017-2020 The Software Heritage developers +# Copyright (C) 2017-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -15,11 +15,10 @@ super().__init__(data) self.compute_hash_called = 0 - def compute_hash(self): + def compute_hash(self) -> bytes: self.compute_hash_called += 1 child_data = [child + b"=" + self[child].hash for child in sorted(self)] - - return b"hash(" + b", ".join([self.data["value"]] + child_data) + b")" + return b"hash(" + b", ".join([self.data.get("value", b"")] + child_data) + b")" class MerkleTestLeaf(merkle.MerkleLeaf): @@ -31,7 +30,7 @@ def compute_hash(self): self.compute_hash_called += 1 - return b"hash(" + self.data["value"] + b")" + return b"hash(" + self.data.get("value", b"") + b")" class TestMerkleLeaf(unittest.TestCase): @@ -62,14 +61,10 @@ collected = self.instance.collect() self.assertEqual( collected, - { - self.instance.object_type: { - self.instance.hash: self.instance.get_data(), - }, - }, + {self.instance}, ) collected2 = self.instance.collect() - self.assertEqual(collected2, {}) + self.assertEqual(collected2, set()) self.instance.reset_collect() collected3 = self.instance.collect() self.assertEqual(collected, collected3) @@ -123,17 +118,17 @@ self.nodes[value3] = node3 def test_equality(self): - node1 = merkle.MerkleNode({"foo": b"bar"}) - node2 = merkle.MerkleNode({"foo": b"bar"}) - node3 = merkle.MerkleNode({}) + node1 = MerkleTestNode({"value": b"bar"}) + node2 = 
MerkleTestNode({"value": b"bar"}) + node3 = MerkleTestNode({}) self.assertEqual(node1, node2) self.assertNotEqual(node1, node3, node1 == node3) - node1["foo"] = node3 + node1[b"a"] = node3 self.assertNotEqual(node1, node2) - node2["foo"] = node3 + node2[b"a"] = node3 self.assertEqual(node1, node2) def test_hash(self): @@ -178,11 +173,11 @@ def test_collect(self): collected = self.root.collect() - self.assertEqual(len(collected[self.root.object_type]), len(self.nodes)) + self.assertEqual(collected, set(self.nodes.values())) for node in self.nodes.values(): self.assertTrue(node.collected) collected2 = self.root.collect() - self.assertEqual(collected2, {}) + self.assertEqual(collected2, set()) def test_iter_tree_with_deduplication(self): nodes = list(self.root.iter_tree()) @@ -252,16 +247,16 @@ # Ensure we collected root, root/b, and both new children collected_after_update = self.root.collect() - self.assertCountEqual( - collected_after_update[MerkleTestNode.object_type], - [ - self.nodes[b"root"].hash, - self.nodes[b"root/b"].hash, - new_children[b"c"].hash, - new_children[b"d"].hash, - ], + self.assertEqual( + collected_after_update, + { + self.nodes[b"root"], + self.nodes[b"root/b"], + new_children[b"c"], + new_children[b"d"], + }, ) # test that noop updates doesn't invalidate anything self.root[b"a"][b"b"].update({}) - self.assertEqual(self.root.collect(), {}) + self.assertEqual(self.root.collect(), set())