diff --git a/swh/model/merkle.py b/swh/model/merkle.py
--- a/swh/model/merkle.py
+++ b/swh/model/merkle.py
@@ -8,7 +8,7 @@
 
 import abc
 import collections
-from typing import List, Optional
+from typing import Generator, List, Optional, Set
 
 
 def deep_update(left, right):
@@ -273,6 +273,30 @@
         for child in self.values():
             child.reset_collect()
 
+    def iter_tree(self) -> Generator['MerkleNode', None, None]:
+        """Yield this node and all of its descendants, recursively.
+
+        Nodes are deduplicated by hash: a node whose hash has already been
+        seen is skipped, along with its subtree.
+        """
+        yield from self._iter_tree(set())
+
+    def _iter_tree(
+            self, seen: Set[bytes]) -> Generator['MerkleNode', None, None]:
+        """Recursive helper for :meth:`iter_tree`.
+
+        Args:
+            seen: hashes of the nodes already yielded; updated in place and
+                shared by the whole traversal.
+        """
+        # Do not rebind ``seen`` (e.g. ``seen or set()``): an empty set is
+        # falsy, so that would silently drop the caller's set.
+        if self.hash not in seen:
+            seen.add(self.hash)
+            yield self
+            for child in self.values():
+                yield from child._iter_tree(seen=seen)
+
 
 class MerkleLeaf(MerkleNode):
     """A leaf to a Merkle tree.
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
--- a/swh/model/tests/test_merkle.py
+++ b/swh/model/tests/test_merkle.py
@@ -184,6 +184,10 @@
         collected2 = self.root.collect()
         self.assertEqual(collected2, {})
 
+    def test_iter_tree(self):
+        nodes = list(self.root.iter_tree())
+        self.assertCountEqual(nodes, self.nodes.values())
+
     def test_get(self):
         for key in (b'a', b'b', b'c'):
             self.assertEqual(self.root[key], self.nodes[b'root/' + key])