diff --git a/PKG-INFO b/PKG-INFO index 2df799d..4b35bba 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,46 +1,46 @@ Metadata-Version: 2.1 Name: swh.model -Version: 2.6.3 +Version: 2.6.4 Summary: Software Heritage data model Home-page: https://forge.softwareheritage.org/diffusion/DMOD/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-model Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-model/ Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/markdown Provides-Extra: cli Provides-Extra: testing-minimal Provides-Extra: testing License-File: LICENSE License-File: AUTHORS swh-model ========= Implementation of the Data model of the Software Heritage project, used to archive source code artifacts. This module defines the notion of SoftWare Heritage persistent IDentifiers (SWHIDs) and provides tools to compute them: ```sh $ swh-identify fork.c kmod.c sched/deadline.c swh:1:cnt:2e391c754ae730bd2d8520c2ab497c403220c6e3 fork.c swh:1:cnt:0277d1216f80ae1adeed84a686ed34c9b2931fc2 kmod.c swh:1:cnt:57b939c81bce5d06fa587df8915f05affbe22b82 sched/deadline.c $ swh-identify --no-filename /usr/src/linux/kernel/ swh:1:dir:f9f858a48d663b3809c9e2f336412717496202ab ``` diff --git a/swh.model.egg-info/PKG-INFO b/swh.model.egg-info/PKG-INFO index 2df799d..4b35bba 100644 --- a/swh.model.egg-info/PKG-INFO +++ b/swh.model.egg-info/PKG-INFO @@ -1,46 +1,46 @@ Metadata-Version: 2.1 Name: swh.model -Version: 2.6.3 +Version: 2.6.4 Summary: Software Heritage data model Home-page: https://forge.softwareheritage.org/diffusion/DMOD/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-model Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-model/ Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/markdown Provides-Extra: cli Provides-Extra: testing-minimal Provides-Extra: testing License-File: LICENSE License-File: AUTHORS swh-model ========= Implementation of the Data model of the Software Heritage project, used to archive source code artifacts. This module defines the notion of SoftWare Heritage persistent IDentifiers (SWHIDs) and provides tools to compute them: ```sh $ swh-identify fork.c kmod.c sched/deadline.c swh:1:cnt:2e391c754ae730bd2d8520c2ab497c403220c6e3 fork.c swh:1:cnt:0277d1216f80ae1adeed84a686ed34c9b2931fc2 kmod.c swh:1:cnt:57b939c81bce5d06fa587df8915f05affbe22b82 sched/deadline.c $ swh-identify --no-filename /usr/src/linux/kernel/ swh:1:dir:f9f858a48d663b3809c9e2f336412717496202ab ``` diff --git a/swh/model/hypothesis_strategies.py b/swh/model/hypothesis_strategies.py index cdd4d1a..ac912d4 100644 --- a/swh/model/hypothesis_strategies.py +++ b/swh/model/hypothesis_strategies.py @@ -1,510 +1,520 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime +import string from hypothesis import assume from hypothesis.extra.dateutil import timezones from hypothesis.strategies import ( binary, booleans, builds, characters, composite, datetimes, dictionaries, from_regex, integers, just, lists, none, one_of, sampled_from, sets, text, tuples, ) from .from_disk import DentryPerms from .identifiers import ( ExtendedObjectType, ExtendedSWHID, identifier_to_bytes, snapshot_identifier, ) from .model import ( BaseContent, Content, Directory, DirectoryEntry, MetadataAuthority, MetadataFetcher, ObjectType, Origin, OriginVisit, OriginVisitStatus, Person, RawExtrinsicMetadata, Release, Revision, RevisionType, SkippedContent, Snapshot, SnapshotBranch, TargetType, Timestamp, TimestampWithTimezone, ) pgsql_alphabet = characters( blacklist_categories=("Cs",), blacklist_characters=["\u0000"] ) # postgresql does not like these def optional(strategy): return one_of(none(), strategy) def pgsql_text(): return text(alphabet=pgsql_alphabet) def sha1_git(): return binary(min_size=20, max_size=20) def sha1(): return binary(min_size=20, max_size=20) @composite def extended_swhids(draw): object_type = draw(sampled_from(ExtendedObjectType)) object_id = draw(sha1_git()) return ExtendedSWHID(object_type=object_type, object_id=object_id) def aware_datetimes(): # datetimes in Software Heritage are not used for software artifacts # (which may be much older than 2000), but only for objects like scheduler # task runs, and origin visits, which were created by Software Heritage, # so at least in 2015. # We're forbidding old datetimes, because until 1956, many timezones had seconds # in their "UTC offsets" (see # ), which is not # encodable in ISO8601; and we need our datetimes to be ISO8601-encodable in the # RPC protocol min_value = datetime.datetime(2000, 1, 1, 0, 0, 0) return datetimes(min_value=min_value, timezones=timezones()) @composite -def urls(draw): +def iris(draw): protocol = draw(sampled_from(["git", "http", "https", "deb"])) - domain = draw(from_regex(r"\A([a-z]([a-z0-9-]*)\.){1,3}[a-z0-9]+\Z")) + domain = draw(from_regex(r"\A([a-z]([a-z0-9é🏛️-]*)\.){1,3}([a-z0-9é])+\Z")) return "%s://%s" % (protocol, domain) @composite def persons_d(draw): fullname = draw(binary()) email = draw(optional(binary())) name = draw(optional(binary())) assume(not (len(fullname) == 32 and email is None and name is None)) return dict(fullname=fullname, name=name, email=email) def persons(): return persons_d().map(Person.from_dict) def timestamps_d(): max_seconds = datetime.datetime.max.replace( tzinfo=datetime.timezone.utc ).timestamp() min_seconds = datetime.datetime.min.replace( tzinfo=datetime.timezone.utc ).timestamp() return builds( dict, seconds=integers(min_seconds, max_seconds), microseconds=integers(0, 1000000), ) def timestamps(): return timestamps_d().map(Timestamp.from_dict) @composite def timestamps_with_timezone_d( draw, timestamp=timestamps_d(), offset=integers(min_value=-14 * 60, max_value=14 * 60), negative_utc=booleans(), ): timestamp = draw(timestamp) offset = draw(offset) negative_utc = draw(negative_utc) assume(not (negative_utc and offset)) return dict(timestamp=timestamp, offset=offset, negative_utc=negative_utc) timestamps_with_timezone = timestamps_with_timezone_d().map( TimestampWithTimezone.from_dict ) def origins_d(): - return builds(dict, url=urls()) + return builds(dict, url=iris()) def origins(): return origins_d().map(Origin.from_dict) def origin_visits_d(): return builds( dict, visit=integers(1, 1000), - origin=urls(), + origin=iris(), date=aware_datetimes(), type=pgsql_text(), ) def origin_visits(): return origin_visits_d().map(OriginVisit.from_dict) def metadata_dicts(): return dictionaries(pgsql_text(), pgsql_text()) def origin_visit_statuses_d(): return builds( dict, visit=integers(1, 1000), - origin=urls(), + origin=iris(), type=optional(sampled_from(["git", "svn", "pypi", "debian"])), status=sampled_from( ["created", "ongoing", "full", "partial", "not_found", "failed"] ), date=aware_datetimes(), snapshot=optional(sha1_git()), metadata=optional(metadata_dicts()), ) def origin_visit_statuses(): return origin_visit_statuses_d().map(OriginVisitStatus.from_dict) @composite def releases_d(draw): target_type = sampled_from([x.value for x in ObjectType]) name = binary() message = optional(binary()) synthetic = booleans() target = sha1_git() metadata = optional(revision_metadata()) return draw( one_of( builds( dict, name=name, message=message, synthetic=synthetic, author=none(), date=none(), target=target, target_type=target_type, metadata=metadata, ), builds( dict, name=name, message=message, synthetic=synthetic, date=timestamps_with_timezone_d(), author=persons_d(), target=target, target_type=target_type, metadata=metadata, ), ) ) def releases(): return releases_d().map(Release.from_dict) revision_metadata = metadata_dicts def extra_headers(): return lists( tuples(binary(min_size=0, max_size=50), binary(min_size=0, max_size=500)) ).map(tuple) def revisions_d(): return builds( dict, message=optional(binary()), synthetic=booleans(), author=persons_d(), committer=persons_d(), date=timestamps_with_timezone_d(), committer_date=timestamps_with_timezone_d(), parents=tuples(sha1_git()), directory=sha1_git(), type=sampled_from([x.value for x in RevisionType]), metadata=optional(revision_metadata()), extra_headers=extra_headers(), ) # TODO: metadata['extra_headers'] can have binary keys and values def revisions(): return revisions_d().map(Revision.from_dict) def directory_entries_d(): return builds( dict, name=binary(), target=sha1_git(), type=sampled_from(["file", "dir", "rev"]), perms=sampled_from([perm.value for perm in DentryPerms]), ) def directory_entries(): return directory_entries_d().map(DirectoryEntry) def directories_d(): return builds(dict, entries=tuples(directory_entries_d())) def directories(): return directories_d().map(Directory.from_dict) def contents_d(): return one_of(present_contents_d(), skipped_contents_d()) def contents(): return one_of(present_contents(), skipped_contents()) def present_contents_d(): return builds( dict, data=binary(max_size=4096), ctime=optional(aware_datetimes()), status=one_of(just("visible"), just("hidden")), ) def present_contents(): return present_contents_d().map(lambda d: Content.from_data(**d)) @composite def skipped_contents_d(draw): result = BaseContent._hash_data(draw(binary(max_size=4096))) result.pop("data") nullify_attrs = draw( sets(sampled_from(["sha1", "sha1_git", "sha256", "blake2s256"])) ) for k in nullify_attrs: result[k] = None result["reason"] = draw(pgsql_text()) result["status"] = "absent" result["ctime"] = draw(optional(aware_datetimes())) return result def skipped_contents(): return skipped_contents_d().map(SkippedContent.from_dict) def branch_names(): return binary(min_size=1) def branch_targets_object_d(): return builds( dict, target=sha1_git(), target_type=sampled_from( [x.value for x in TargetType if x.value not in ("alias",)] ), ) def branch_targets_alias_d(): return builds( dict, target=sha1_git(), target_type=just("alias") ) # TargetType.ALIAS.value)) def branch_targets_d(*, only_objects=False): if only_objects: return branch_targets_object_d() else: return one_of(branch_targets_alias_d(), branch_targets_object_d()) def branch_targets(*, only_objects=False): return builds(SnapshotBranch.from_dict, branch_targets_d(only_objects=only_objects)) @composite def snapshots_d(draw, *, min_size=0, max_size=100, only_objects=False): branches = draw( dictionaries( keys=branch_names(), values=optional(branch_targets_d(only_objects=only_objects)), min_size=min_size, max_size=max_size, ) ) if not only_objects: # Make sure aliases point to actual branches unresolved_aliases = { branch: target["target"] for branch, target in branches.items() if ( target and target["target_type"] == "alias" and target["target"] not in branches ) } for alias_name, alias_target in unresolved_aliases.items(): # Override alias branch with one pointing to a real object # if max_size constraint is reached alias = alias_target if len(branches) < max_size else alias_name branches[alias] = draw(branch_targets_d(only_objects=True)) # Ensure no cycles between aliases while True: try: id_ = snapshot_identifier( { "branches": { name: branch or None for (name, branch) in branches.items() } } ) except ValueError as e: for (source, target) in e.args[1]: branches[source] = draw(branch_targets_d(only_objects=True)) else: break return dict(id=identifier_to_bytes(id_), branches=branches) def snapshots(*, min_size=0, max_size=100, only_objects=False): return snapshots_d( min_size=min_size, max_size=max_size, only_objects=only_objects ).map(Snapshot.from_dict) def metadata_authorities(): - return builds(MetadataAuthority, url=urls(), metadata=just(None)) + return builds(MetadataAuthority, url=iris(), metadata=just(None)) def metadata_fetchers(): - return builds(MetadataFetcher, metadata=just(None)) + return builds( + MetadataFetcher, + name=text(min_size=1, alphabet=string.printable), + version=text( + min_size=1, + alphabet=string.ascii_letters + string.digits + string.punctuation, + ), + metadata=just(None), + ) def raw_extrinsic_metadata(): return builds( RawExtrinsicMetadata, target=extended_swhids(), discovery_date=aware_datetimes(), authority=metadata_authorities(), fetcher=metadata_fetchers(), + format=text(min_size=1, alphabet=string.printable), ) def raw_extrinsic_metadata_d(): return raw_extrinsic_metadata().map(RawExtrinsicMetadata.to_dict) def objects(blacklist_types=("origin_visit_status",), split_content=False): """generates a random couple (type, obj) which obj is an instance of the Model class corresponding to obj_type. `blacklist_types` is a list of obj_type to exclude from the strategy. If `split_content` is True, generates Content and SkippedContent under different obj_type, resp. "content" and "skipped_content". """ strategies = [ ("origin", origins), ("origin_visit", origin_visits), ("origin_visit_status", origin_visit_statuses), ("snapshot", snapshots), ("release", releases), ("revision", revisions), ("directory", directories), ("raw_extrinsic_metadata", raw_extrinsic_metadata), ] if split_content: strategies.append(("content", present_contents)) strategies.append(("skipped_content", skipped_contents)) else: strategies.append(("content", contents)) args = [ obj_gen().map(lambda x, obj_type=obj_type: (obj_type, x)) for (obj_type, obj_gen) in strategies if obj_type not in blacklist_types ] return one_of(*args) def object_dicts(blacklist_types=("origin_visit_status",), split_content=False): """generates a random couple (type, dict) which dict is suitable for .from_dict() factory methods. `blacklist_types` is a list of obj_type to exclude from the strategy. If `split_content` is True, generates Content and SkippedContent under different obj_type, resp. "content" and "skipped_content". """ strategies = [ ("origin", origins_d), ("origin_visit", origin_visits_d), ("origin_visit_status", origin_visit_statuses_d), ("snapshot", snapshots_d), ("release", releases_d), ("revision", revisions_d), ("directory", directories_d), ("raw_extrinsic_metadata", raw_extrinsic_metadata_d), ] if split_content: strategies.append(("content", present_contents_d)) strategies.append(("skipped_content", skipped_contents_d)) else: strategies.append(("content", contents_d)) args = [ obj_gen().map(lambda x, obj_type=obj_type: (obj_type, x)) for (obj_type, obj_gen) in strategies if obj_type not in blacklist_types ] return one_of(*args) diff --git a/swh/model/merkle.py b/swh/model/merkle.py index 098c872..8934ad1 100644 --- a/swh/model/merkle.py +++ b/swh/model/merkle.py @@ -1,313 +1,315 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """Merkle tree data structure""" import abc from collections.abc import Mapping from typing import Dict, Iterator, List, Set def deep_update(left, right): """Recursively update the left mapping with deeply nested values from the right mapping. This function is useful to merge the results of several calls to :func:`MerkleNode.collect`. Arguments: left: a mapping (modified by the update operation) right: a mapping Returns: the left mapping, updated with nested values from the right mapping Example: >>> a = { ... 'key1': { ... 'key2': { ... 'key3': 'value1/2/3', ... }, ... }, ... } >>> deep_update(a, { ... 'key1': { ... 'key2': { ... 'key4': 'value1/2/4', ... }, ... }, ... }) == { ... 'key1': { ... 'key2': { ... 'key3': 'value1/2/3', ... 'key4': 'value1/2/4', ... }, ... }, ... } True >>> deep_update(a, { ... 'key1': { ... 'key2': { ... 'key3': 'newvalue1/2/3', ... }, ... }, ... }) == { ... 'key1': { ... 'key2': { ... 'key3': 'newvalue1/2/3', ... 'key4': 'value1/2/4', ... }, ... }, ... } True """ for key, rvalue in right.items(): if isinstance(rvalue, Mapping): new_lvalue = deep_update(left.get(key, {}), rvalue) left[key] = new_lvalue else: left[key] = rvalue return left class MerkleNode(dict, metaclass=abc.ABCMeta): """Representation of a node in a Merkle Tree. A (generalized) `Merkle Tree`_ is a tree in which every node is labeled with a hash of its own data and the hash of its children. .. _Merkle Tree: https://en.wikipedia.org/wiki/Merkle_tree In pseudocode:: node.hash = hash(node.data + sum(child.hash for child in node.children)) This class efficiently implements the Merkle Tree data structure on top of a Python :class:`dict`, minimizing hash computations and new data collections when updating nodes. Node data is stored in the :attr:`data` attribute, while (named) children are stored as items of the underlying dictionary. Addition, update and removal of objects are instrumented to automatically invalidate the hashes of the current node as well as its registered parents; It also resets the collection status of the objects so the updated objects can be collected. The collection of updated data from the tree is implemented through the :func:`collect` function and associated helpers. """ __slots__ = ["parents", "data", "__hash", "collected"] data: Dict """data associated to the current node""" parents: List """known parents of the current node""" collected: bool """whether the current node has been collected""" def __init__(self, data=None): super().__init__() self.parents = [] self.data = data self.__hash = None self.collected = False def __eq__(self, other): return ( isinstance(other, MerkleNode) and super().__eq__(other) and self.data == other.data ) def __ne__(self, other): return not self.__eq__(other) def invalidate_hash(self): """Invalidate the cached hash of the current node.""" if not self.__hash: return self.__hash = None self.collected = False for parent in self.parents: parent.invalidate_hash() def update_hash(self, *, force=False): """Recursively compute the hash of the current node. Args: force (bool): invalidate the cache and force the computation for this node and all children. """ if self.__hash and not force: return self.__hash if force: self.invalidate_hash() for child in self.values(): child.update_hash(force=force) self.__hash = self.compute_hash() return self.__hash @property def hash(self): """The hash of the current node, as calculated by :func:`compute_hash`. """ return self.update_hash() @abc.abstractmethod def compute_hash(self): """Compute the hash of the current node. The hash should depend on the data of the node, as well as on hashes of the children nodes. """ raise NotImplementedError("Must implement compute_hash method") def __setitem__(self, name, new_child): """Add a child, invalidating the current hash""" self.invalidate_hash() super().__setitem__(name, new_child) new_child.parents.append(self) def __delitem__(self, name): """Remove a child, invalidating the current hash""" if name in self: self.invalidate_hash() self[name].parents.remove(self) super().__delitem__(name) else: raise KeyError(name) def update(self, new_children): """Add several named children from a dictionary""" if not new_children: return self.invalidate_hash() for name, new_child in new_children.items(): new_child.parents.append(self) if name in self: self[name].parents.remove(self) super().update(new_children) def get_data(self, **kwargs): """Retrieve and format the collected data for the current node, for use by :func:`collect`. Can be overridden, for instance when you want the collected data to contain information about the child nodes. Arguments: kwargs: allow subclasses to alter behaviour depending on how :func:`collect` is called. Returns: data formatted for :func:`collect` """ return self.data def collect_node(self, **kwargs): """Collect the data for the current node, for use by :func:`collect`. Arguments: kwargs: passed as-is to :func:`get_data`. Returns: A :class:`dict` compatible with :func:`collect`. """ if not self.collected: self.collected = True return {self.object_type: {self.hash: self.get_data(**kwargs)}} else: return {} def collect(self, **kwargs): """Collect the data for all nodes in the subtree rooted at `self`. The data is deduplicated by type and by hash. Arguments: kwargs: passed as-is to :func:`get_data`. Returns: A :class:`dict` with the following structure:: { 'typeA': { node1.hash: node1.get_data(), node2.hash: node2.get_data(), }, 'typeB': { node3.hash: node3.get_data(), ... }, ... } """ ret = self.collect_node(**kwargs) for child in self.values(): deep_update(ret, child.collect(**kwargs)) return ret def reset_collect(self): """Recursively unmark collected nodes in the subtree rooted at `self`. This lets the caller use :func:`collect` again. """ self.collected = False for child in self.values(): child.reset_collect() - def iter_tree(self) -> Iterator["MerkleNode"]: - """Yields all children nodes, recursively. Common nodes are - deduplicated. + def iter_tree(self, dedup=True) -> Iterator["MerkleNode"]: + """Yields all children nodes, recursively. Common nodes are deduplicated + by default (deduplication can be turned off setting the given argument + 'dedup' to False). """ - yield from self._iter_tree(set()) + yield from self._iter_tree(set(), dedup) - def _iter_tree(self, seen: Set[bytes]) -> Iterator["MerkleNode"]: + def _iter_tree(self, seen: Set[bytes], dedup) -> Iterator["MerkleNode"]: if self.hash not in seen: - seen.add(self.hash) + if dedup: + seen.add(self.hash) yield self for child in self.values(): - yield from child._iter_tree(seen=seen) + yield from child._iter_tree(seen=seen, dedup=dedup) class MerkleLeaf(MerkleNode): """A leaf to a Merkle tree. A Merkle leaf is simply a Merkle node with children disabled. """ __slots__ = [] # type: List[str] def __setitem__(self, name, child): raise ValueError("%s is a leaf" % self.__class__.__name__) def __getitem__(self, name): raise ValueError("%s is a leaf" % self.__class__.__name__) def __delitem__(self, name): raise ValueError("%s is a leaf" % self.__class__.__name__) def update(self, new_children): """Children update operation. Disabled for leaves.""" raise ValueError("%s is a leaf" % self.__class__.__name__) diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py index 65992f4..32de872 100644 --- a/swh/model/tests/test_merkle.py +++ b/swh/model/tests/test_merkle.py @@ -1,247 +1,255 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import unittest from swh.model import merkle class MerkleTestNode(merkle.MerkleNode): object_type = "tested_merkle_node_type" def __init__(self, data): super().__init__(data) self.compute_hash_called = 0 def compute_hash(self): self.compute_hash_called += 1 child_data = [child + b"=" + self[child].hash for child in sorted(self)] return b"hash(" + b", ".join([self.data["value"]] + child_data) + b")" class MerkleTestLeaf(merkle.MerkleLeaf): object_type = "tested_merkle_leaf_type" def __init__(self, data): super().__init__(data) self.compute_hash_called = 0 def compute_hash(self): self.compute_hash_called += 1 return b"hash(" + self.data["value"] + b")" class TestMerkleLeaf(unittest.TestCase): def setUp(self): self.data = {"value": b"value"} self.instance = MerkleTestLeaf(self.data) def test_equality(self): leaf1 = MerkleTestLeaf(self.data) leaf2 = MerkleTestLeaf(self.data) leaf3 = MerkleTestLeaf({}) self.assertEqual(leaf1, leaf2) self.assertNotEqual(leaf1, leaf3) def test_hash(self): self.assertEqual(self.instance.compute_hash_called, 0) instance_hash = self.instance.hash self.assertEqual(self.instance.compute_hash_called, 1) instance_hash2 = self.instance.hash self.assertEqual(self.instance.compute_hash_called, 1) self.assertEqual(instance_hash, instance_hash2) def test_data(self): self.assertEqual(self.instance.get_data(), self.data) def test_collect(self): collected = self.instance.collect() self.assertEqual( collected, { self.instance.object_type: { self.instance.hash: self.instance.get_data(), }, }, ) collected2 = self.instance.collect() self.assertEqual(collected2, {}) self.instance.reset_collect() collected3 = self.instance.collect() self.assertEqual(collected, collected3) def test_leaf(self): with self.assertRaisesRegex(ValueError, "is a leaf"): self.instance[b"key1"] = "Test" with self.assertRaisesRegex(ValueError, "is a leaf"): del self.instance[b"key1"] with self.assertRaisesRegex(ValueError, "is a leaf"): self.instance[b"key1"] with self.assertRaisesRegex(ValueError, "is a leaf"): self.instance.update(self.data) class TestMerkleNode(unittest.TestCase): maxDiff = None def setUp(self): self.root = MerkleTestNode({"value": b"root"}) self.nodes = {b"root": self.root} for i in (b"a", b"b", b"c"): value = b"root/" + i node = MerkleTestNode({"value": value,}) self.root[i] = node self.nodes[value] = node for j in (b"a", b"b", b"c"): value2 = value + b"/" + j node2 = MerkleTestNode({"value": value2,}) node[j] = node2 self.nodes[value2] = node2 for k in (b"a", b"b", b"c"): value3 = value2 + b"/" + j node3 = MerkleTestNode({"value": value3,}) node2[j] = node3 self.nodes[value3] = node3 def test_equality(self): node1 = merkle.MerkleNode({"foo": b"bar"}) node2 = merkle.MerkleNode({"foo": b"bar"}) node3 = merkle.MerkleNode({}) self.assertEqual(node1, node2) self.assertNotEqual(node1, node3, node1 == node3) node1["foo"] = node3 self.assertNotEqual(node1, node2) node2["foo"] = node3 self.assertEqual(node1, node2) def test_hash(self): for node in self.nodes.values(): self.assertEqual(node.compute_hash_called, 0) # Root hash will compute hash for all the nodes hash = self.root.hash for node in self.nodes.values(): self.assertEqual(node.compute_hash_called, 1) self.assertIn(node.data["value"], hash) # Should use the cached value hash2 = self.root.hash self.assertEqual(hash, hash2) for node in self.nodes.values(): self.assertEqual(node.compute_hash_called, 1) # Should still use the cached value hash3 = self.root.update_hash(force=False) self.assertEqual(hash, hash3) for node in self.nodes.values(): self.assertEqual(node.compute_hash_called, 1) # Force update of the cached value for a deeply nested node self.root[b"a"][b"b"].update_hash(force=True) for key, node in self.nodes.items(): # update_hash rehashes all children if key.startswith(b"root/a/b"): self.assertEqual(node.compute_hash_called, 2) else: self.assertEqual(node.compute_hash_called, 1) hash4 = self.root.hash self.assertEqual(hash, hash4) for key, node in self.nodes.items(): # update_hash also invalidates all parents if key in (b"root", b"root/a") or key.startswith(b"root/a/b"): self.assertEqual(node.compute_hash_called, 2) else: self.assertEqual(node.compute_hash_called, 1) def test_collect(self): collected = self.root.collect() self.assertEqual(len(collected[self.root.object_type]), len(self.nodes)) for node in self.nodes.values(): self.assertTrue(node.collected) collected2 = self.root.collect() self.assertEqual(collected2, {}) - def test_iter_tree(self): + def test_iter_tree_with_deduplication(self): nodes = list(self.root.iter_tree()) self.assertCountEqual(nodes, self.nodes.values()) + def test_iter_tree_without_deduplication(self): + # duplicate existing hash in merkle tree + self.root[b"d"] = MerkleTestNode({"value": b"root/c/c/c"}) + nodes_dedup = list(self.root.iter_tree()) + nodes = list(self.root.iter_tree(dedup=False)) + assert nodes != nodes_dedup + assert len(nodes) == len(nodes_dedup) + 1 + def test_get(self): for key in (b"a", b"b", b"c"): self.assertEqual(self.root[key], self.nodes[b"root/" + key]) with self.assertRaisesRegex(KeyError, "b'nonexistent'"): self.root[b"nonexistent"] def test_del(self): hash_root = self.root.hash hash_a = self.nodes[b"root/a"].hash del self.root[b"a"][b"c"] hash_root2 = self.root.hash hash_a2 = self.nodes[b"root/a"].hash self.assertNotEqual(hash_root, hash_root2) self.assertNotEqual(hash_a, hash_a2) self.assertEqual(self.nodes[b"root/a/c"].parents, []) with self.assertRaisesRegex(KeyError, "b'nonexistent'"): del self.root[b"nonexistent"] def test_update(self): hash_root = self.root.hash hash_b = self.root[b"b"].hash new_children = { b"c": MerkleTestNode({"value": b"root/b/new_c"}), b"d": MerkleTestNode({"value": b"root/b/d"}), } # collect all nodes self.root.collect() self.root[b"b"].update(new_children) # Ensure everyone got reparented self.assertEqual(new_children[b"c"].parents, [self.root[b"b"]]) self.assertEqual(new_children[b"d"].parents, [self.root[b"b"]]) self.assertEqual(self.nodes[b"root/b/c"].parents, []) hash_root2 = self.root.hash self.assertNotEqual(hash_root, hash_root2) self.assertIn(b"root/b/new_c", hash_root2) self.assertIn(b"root/b/d", hash_root2) hash_b2 = self.root[b"b"].hash self.assertNotEqual(hash_b, hash_b2) for key, node in self.nodes.items(): if key in (b"root", b"root/b"): self.assertEqual(node.compute_hash_called, 2) else: self.assertEqual(node.compute_hash_called, 1) # Ensure we collected root, root/b, and both new children collected_after_update = self.root.collect() self.assertCountEqual( collected_after_update[MerkleTestNode.object_type], [ self.nodes[b"root"].hash, self.nodes[b"root/b"].hash, new_children[b"c"].hash, new_children[b"d"].hash, ], ) # test that noop updates doesn't invalidate anything self.root[b"a"][b"b"].update({}) self.assertEqual(self.root.collect(), {})