diff --git a/swh/model/collections.py b/swh/model/collections.py index 2724f85..495b43c 100644 --- a/swh/model/collections.py +++ b/swh/model/collections.py @@ -1,49 +1,56 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections.abc import Mapping from typing import Dict, Generic, Iterable, Optional, Tuple, TypeVar, Union KT = TypeVar("KT") VT = TypeVar("VT") class ImmutableDict(Mapping, Generic[KT, VT]): data: Tuple[Tuple[KT, VT], ...] - def __init__(self, data: Union[Iterable[Tuple[KT, VT]], Dict[KT, VT]] = {}): + def __init__( + self, + data: Union[ + Iterable[Tuple[KT, VT]], "ImmutableDict[KT, VT]", Dict[KT, VT] + ] = {}, + ): if isinstance(data, dict): self.data = tuple(item for item in data.items()) + elif isinstance(data, ImmutableDict): + self.data = data.data else: self.data = tuple(data) def __getitem__(self, key): for (k, v) in self.data: if k == key: return v raise KeyError(key) def __iter__(self): for (k, v) in self.data: yield k def __len__(self): return len(self.data) def items(self): yield from self.data def copy_pop(self, popped_key) -> Tuple[Optional[VT], "ImmutableDict[KT, VT]"]: """Returns a copy of this ImmutableDict without the given key, as well as the value associated to the key.""" popped_value = None new_items = [] for (key, value) in self.data: if key == popped_key: popped_value = value else: new_items.append((key, value)) return (popped_value, ImmutableDict(new_items)) diff --git a/swh/model/tests/test_collections.py b/swh/model/tests/test_collections.py index c7b44cb..b042c59 100644 --- a/swh/model/tests/test_collections.py +++ b/swh/model/tests/test_collections.py @@ -1,50 +1,66 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.model.collections import ImmutableDict def test_immutabledict_empty(): d = ImmutableDict() assert d == {} assert d != {"foo": "bar"} assert list(d) == [] assert list(d.items()) == [] def test_immutabledict_one_item(): d = ImmutableDict({"foo": "bar"}) assert d == {"foo": "bar"} assert d != {} assert d["foo"] == "bar" with pytest.raises(KeyError, match="bar"): d["bar"] assert list(d) == ["foo"] assert list(d.items()) == [("foo", "bar")] +def test_immutabledict_from_iterable(): + d1 = ImmutableDict() + d2 = ImmutableDict({"foo": "bar"}) + + assert ImmutableDict([]) == d1 + assert ImmutableDict([("foo", "bar")]) == d2 + + +def test_immutabledict_from_immutabledict(): + d1 = ImmutableDict() + d2 = ImmutableDict({"foo": "bar"}) + + assert ImmutableDict(d1) == d1 + assert ImmutableDict(d2) == d2 + + def test_immutabledict_immutable(): d = ImmutableDict({"foo": "bar"}) with pytest.raises(TypeError, match="item assignment"): d["bar"] = "baz" with pytest.raises(TypeError, match="item deletion"): del d["foo"] def test_immutabledict_copy_pop(): d = ImmutableDict({"foo": "bar", "baz": "qux"}) assert d.copy_pop("foo") == ("bar", ImmutableDict({"baz": "qux"})) assert d.copy_pop("not a key") == (None, d)