diff --git a/swh/model/collections.py b/swh/model/collections.py
new file mode 100644
--- /dev/null
+++ b/swh/model/collections.py
@@ -0,0 +1,69 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from collections.abc import Mapping
+from typing import Dict, Generic, Iterable, Optional, Tuple, TypeVar, Union
+
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+
+class ImmutableDict(Mapping, Generic[KT, VT]):
+    """A read-only mapping that stores its items in a tuple.
+
+    Items keep the order in which they were given. Lookups scan the items
+    linearly, so this is meant for small mappings. An ImmutableDict is
+    hashable as long as all its keys and values are hashable."""
+
+    data: Tuple[Tuple[KT, VT], ...]
+
+    def __init__(
+        self, data: Union[Iterable[Tuple[KT, VT]], Mapping, Dict[KT, VT]] = ()
+    ):
+        """Builds the mapping from another mapping (such as a dict or an
+        ImmutableDict) or from an iterable of ``(key, value)`` pairs."""
+        if isinstance(data, Mapping):
+            # iterating over a Mapping only yields its keys, so read its items
+            self.data = tuple(data.items())
+        else:
+            self.data = tuple(data)
+
+    def __repr__(self):
+        return f"{type(self).__name__}({list(self.data)!r})"
+
+    def __getitem__(self, key):
+        for (k, v) in self.data:
+            if k == key:
+                return v
+        raise KeyError(key)
+
+    def __iter__(self):
+        for (k, v) in self.data:
+            yield k
+
+    def __len__(self):
+        return len(self.data)
+
+    def items(self):
+        yield from self.data
+
+    def __hash__(self):
+        # Mapping.__eq__ compares dict(self.items()), which ignores item
+        # order: hash the same dict view so equal instances hash equally.
+        return hash(frozenset(dict(self.data).items()))
+
+    def copy_pop(self, popped_key) -> Tuple[Optional[VT], "ImmutableDict[KT, VT]"]:
+        """Returns a copy of this ImmutableDict without the given key,
+        as well as the value associated to the key (None if the key is
+        not present)."""
+        popped_value = None
+        new_items = []
+        for (key, value) in self.data:
+            if key == popped_key:
+                popped_value = value
+            else:
+                new_items.append((key, value))
+
+        return (popped_value, ImmutableDict(new_items))
diff --git a/swh/model/model.py b/swh/model/model.py
--- a/swh/model/model.py
+++ b/swh/model/model.py
@@ -6,7 +6,6 @@
 import datetime
 
 from abc import ABCMeta, abstractmethod
-from copy import deepcopy
 from enum import Enum
 from hashlib import sha256
 from typing import Any, Dict, Iterable, Optional, Tuple, TypeVar, Union
@@ -17,6 +16,8 @@
 import dateutil.parser
 import iso8601
 
+from .collections import ImmutableDict
+from .hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, MultiHash
 from .identifiers import (
     normalize_timestamp,
     directory_identifier,
@@ -25,7 +26,6 @@
     snapshot_identifier,
     SWHID,
 )
-from .hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, MultiHash
 
 
 class MissingData(Exception):
@@ -41,13 +41,27 @@
 Sha1Git = bytes
 
 
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+
+def freeze_optional_dict(
+    d: Union[None, Dict[KT, VT], ImmutableDict[KT, VT]]  # type: ignore
+) -> Optional[ImmutableDict[KT, VT]]:
+    """Converts d to an ImmutableDict if it is a dict, else returns it unchanged."""
+    if isinstance(d, dict):
+        return ImmutableDict(d)
+    else:
+        return d
+
+
 def dictify(value):
     "Helper function used by BaseModel.to_dict()"
     if isinstance(value, BaseModel):
         return value.to_dict()
     elif isinstance(value, Enum):
         return value.value
-    elif isinstance(value, dict):
+    elif isinstance(value, (dict, ImmutableDict)):
         return {k: dictify(v) for k, v in value.items()}
     elif isinstance(value, tuple):
         return tuple(dictify(v) for v in value)
@@ -277,7 +291,10 @@
     )
     snapshot = attr.ib(type=Optional[Sha1Git], validator=type_validator())
     metadata = attr.ib(
-        type=Optional[Dict[str, object]], validator=type_validator(), default=None
+        type=Optional[ImmutableDict[str, object]],
+        validator=type_validator(),
+        converter=freeze_optional_dict,
+        default=None,
     )
 
 
@@ -332,7 +349,9 @@
     object_type: Final = "snapshot"
 
     branches = attr.ib(
-        type=Dict[bytes, Optional[SnapshotBranch]], validator=type_validator()
+        type=ImmutableDict[bytes, Optional[SnapshotBranch]],
+        validator=type_validator(),
+        converter=freeze_optional_dict,
     )
     id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"")
 
@@ -344,10 +363,10 @@
     def from_dict(cls, d):
         d = d.copy()
         return cls(
-            branches={
-                name: SnapshotBranch.from_dict(branch) if branch else None
+            branches=ImmutableDict(
+                (name, SnapshotBranch.from_dict(branch) if branch else None)
                 for (name, branch) in d.pop("branches").items()
-            },
+            ),
             **d,
         )
 
@@ -366,7 +385,10 @@
         type=Optional[TimestampWithTimezone], validator=type_validator(), default=None
     )
     metadata = attr.ib(
-        type=Optional[Dict[str, object]], validator=type_validator(), default=None
+        type=Optional[ImmutableDict[str, object]],
+        validator=type_validator(),
+        converter=freeze_optional_dict,
+        default=None,
     )
     id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"")
 
@@ -431,7 +453,10 @@
     directory = attr.ib(type=Sha1Git, validator=type_validator())
     synthetic = attr.ib(type=bool, validator=type_validator())
     metadata = attr.ib(
-        type=Optional[Dict[str, object]], validator=type_validator(), default=None
+        type=Optional[ImmutableDict[str, object]],
+        validator=type_validator(),
+        converter=freeze_optional_dict,
+        default=None,
     )
     parents = attr.ib(type=Tuple[Sha1Git, ...], validator=type_validator(), default=())
     id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"")
@@ -447,12 +472,11 @@
-        # ensure metadata is a deep copy of whatever was given, and if needed
-        # extract extra_headers from there
+        # metadata is already an ImmutableDict here (see freeze_optional_dict),
+        # so it needs no copy; if needed, extract extra_headers from it
         if self.metadata:
-            metadata = deepcopy(self.metadata)
+            metadata = self.metadata
             if not self.extra_headers and "extra_headers" in metadata:
+                (extra_headers, metadata) = metadata.copy_pop("extra_headers")
                 object.__setattr__(
-                    self,
-                    "extra_headers",
-                    tuplify_extra_headers(metadata.pop("extra_headers")),
+                    self, "extra_headers", tuplify_extra_headers(extra_headers),
                 )
                 attr.validate(self)
             object.__setattr__(self, "metadata", metadata)
@@ -713,7 +737,10 @@
     type = attr.ib(type=MetadataAuthorityType, validator=type_validator())
     url = attr.ib(type=str, validator=type_validator())
     metadata = attr.ib(
-        type=Optional[Dict[str, Any]], default=None, validator=type_validator()
+        type=Optional[ImmutableDict[str, Any]],
+        default=None,
+        validator=type_validator(),
+        converter=freeze_optional_dict,
     )
 
 
@@ -725,7 +752,10 @@
     name = attr.ib(type=str, validator=type_validator())
     version = attr.ib(type=str, validator=type_validator())
     metadata = attr.ib(
-        type=Optional[Dict[str, Any]], default=None, validator=type_validator()
+        type=Optional[ImmutableDict[str, Any]],
+        default=None,
+        validator=type_validator(),
+        converter=freeze_optional_dict,
     )
 
 
diff --git a/swh/model/tests/test_collections.py b/swh/model/tests/test_collections.py
new file mode 100644
--- /dev/null
+++ b/swh/model/tests/test_collections.py
@@ -0,0 +1,72 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+from swh.model.collections import ImmutableDict
+
+
+def test_immutabledict_empty():
+    d = ImmutableDict()
+
+    assert d == {}
+    assert d != {"foo": "bar"}
+
+    assert list(d) == []
+    assert list(d.items()) == []
+
+
+def test_immutabledict_one_item():
+    d = ImmutableDict({"foo": "bar"})
+
+    assert d == {"foo": "bar"}
+    assert d != {}
+
+    assert d["foo"] == "bar"
+    with pytest.raises(KeyError, match="bar"):
+        d["bar"]
+
+    assert list(d) == ["foo"]
+    assert list(d.items()) == [("foo", "bar")]
+
+
+def test_immutabledict_immutable():
+    d = ImmutableDict({"foo": "bar"})
+
+    with pytest.raises(TypeError, match="item assignment"):
+        d["bar"] = "baz"
+
+    with pytest.raises(TypeError, match="item deletion"):
+        del d["foo"]
+
+
+def test_immutabledict_copy_pop():
+    d = ImmutableDict({"foo": "bar", "baz": "qux"})
+
+    assert d.copy_pop("foo") == ("bar", ImmutableDict({"baz": "qux"}))
+
+    assert d.copy_pop("not a key") == (None, d)
+
+
+def test_immutabledict_from_immutabledict():
+    d = ImmutableDict({"foo": "bar", "baz": "qux"})
+
+    assert ImmutableDict(d) == d
+    assert list(ImmutableDict(d).items()) == [("foo", "bar"), ("baz", "qux")]
+
+
+def test_immutabledict_hash():
+    d1 = ImmutableDict({"foo": "bar", "baz": "qux"})
+    d2 = ImmutableDict({"baz": "qux", "foo": "bar"})
+
+    assert d1 == d2
+    assert hash(d1) == hash(d2)
+    assert len({ImmutableDict(), ImmutableDict({}), ImmutableDict([])}) == 1
+
+
+def test_immutabledict_repr():
+    d = ImmutableDict({"foo": "bar"})
+
+    assert repr(d) == "ImmutableDict([('foo', 'bar')])"