diff --git a/swh/model/merkle.py b/swh/model/merkle.py
--- a/swh/model/merkle.py
+++ b/swh/model/merkle.py
@@ -120,6 +120,26 @@
         self.__hash = None
         self.collected = False
 
+    def __eq__(self, other):
+        """Return whether `other` is an equal Merkle node.
+
+        Nodes are equal when `other` is a MerkleNode, the parent class
+        comparison succeeds, and both nodes carry equal `data`.
+        """
+        return (
+            isinstance(other, MerkleNode)
+            and super().__eq__(other)
+            and self.data == other.data
+        )
+
+    def __ne__(self, other):
+        """Return the negation of `__eq__`.
+
+        NOTE(review): presumably needed because the parent class has its
+        own `__ne__` that would ignore `data` -- confirm against the base.
+        """
+        return not self.__eq__(other)
+
     def invalidate_hash(self):
         """Invalidate the cached hash of the current node."""
         if not self.__hash:
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
--- a/swh/model/tests/test_merkle.py
+++ b/swh/model/tests/test_merkle.py
@@ -46,6 +46,15 @@
         self.data = {'value': b'value'}
         self.instance = MerkleTestLeaf(self.data)
 
+    def test_equality(self):
+        """Leaves compare equal when built from equal data."""
+        leaf1 = MerkleTestLeaf(self.data)
+        leaf2 = MerkleTestLeaf(self.data)
+        leaf3 = MerkleTestLeaf({})
+
+        self.assertEqual(leaf1, leaf2)
+        self.assertNotEqual(leaf1, leaf3)
+
     def test_hash(self):
         self.assertEqual(self.instance.compute_hash_called, 0)
         instance_hash = self.instance.hash
@@ -114,6 +123,22 @@
                     node2[j] = node3
                     self.nodes[value3] = node3
 
+    def test_equality(self):
+        """Nodes compare equal only when data and items both match."""
+        node1 = merkle.MerkleNode({'foo': b'bar'})
+        node2 = merkle.MerkleNode({'foo': b'bar'})
+        node3 = merkle.MerkleNode({})
+
+        self.assertEqual(node1, node2)
+        self.assertNotEqual(node1, node3)
+
+        # An item set on only one of the nodes makes the two differ.
+        node1['foo'] = node3
+        self.assertNotEqual(node1, node2)
+
+        node2['foo'] = node3
+        self.assertEqual(node1, node2)
+
     def test_hash(self):
         for node in self.nodes.values():
             self.assertEqual(node.compute_hash_called, 0)