diff --git a/swh/model/merkle.py b/swh/model/merkle.py
--- a/swh/model/merkle.py
+++ b/swh/model/merkle.py
@@ -277,18 +277,25 @@
         for child in self.values():
             child.reset_collect()
 
-    def iter_tree(self) -> Iterator["MerkleNode"]:
-        """Yields all children nodes, recursively. Common nodes are
-        deduplicated.
+    def iter_tree(self, dedup: bool = True) -> Iterator["MerkleNode"]:
+        """Yields this node and all its children nodes, recursively.
+
+        Args:
+            dedup: if True (default), a node whose hash was already yielded is
+              skipped together with its subtree; if False, every node is yielded
+              even when several nodes share the same hash.
         """
-        yield from self._iter_tree(set())
+        yield from self._iter_tree(seen=set(), dedup=dedup)
 
-    def _iter_tree(self, seen: Set[bytes]) -> Iterator["MerkleNode"]:
-        if self.hash not in seen:
-            seen.add(self.hash)
-            yield self
-            for child in self.values():
-                yield from child._iter_tree(seen=seen)
+    def _iter_tree(self, seen: Set[bytes], dedup: bool = True) -> Iterator["MerkleNode"]:
+        if dedup:
+            # hash already yielded: skip this node and its whole subtree
+            if self.hash in seen:
+                return
+            seen.add(self.hash)
+        yield self
+        for child in self.values():
+            yield from child._iter_tree(seen=seen, dedup=dedup)
 
 
 class MerkleLeaf(MerkleNode):
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
--- a/swh/model/tests/test_merkle.py
+++ b/swh/model/tests/test_merkle.py
@@ -176,6 +176,18 @@
         nodes = list(self.root.iter_tree())
         self.assertCountEqual(nodes, self.nodes.values())
 
+    def test_iter_tree_without_deduplication(self):
+        # add a node whose hash duplicates the existing root/c/c/c node
+        duplicate = MerkleTestNode({"value": b"root/c/c/c"})
+        self.root[b"d"] = duplicate
+        nodes_dedup = list(self.root.iter_tree())
+        nodes = list(self.root.iter_tree(dedup=False))
+        # the duplicate is yielded in addition when deduplication is off
+        self.assertEqual(len(nodes), len(nodes_dedup) + 1)
+        # compare by identity: the duplicate has the same content as root/c/c/c
+        self.assertTrue(any(node is duplicate for node in nodes))
+        self.assertFalse(any(node is duplicate for node in nodes_dedup))
+
     def test_get(self):
         for key in (b"a", b"b", b"c"):
             self.assertEqual(self.root[key], self.nodes[b"root/" + key])