diff --git a/swh/model/model.py b/swh/model/model.py
--- a/swh/model/model.py
+++ b/swh/model/model.py
@@ -22,7 +22,7 @@
 from typing import Any, Dict, Iterable, Optional, Tuple, TypeVar, Union
 
 import attr
-from attrs_strict import type_validator
+from attrs_strict import AttributeTypeError
 import dateutil.parser
 import iso8601
 from typing_extensions import Final
@@ -83,6 +83,81 @@
     return value
 
 
+def _check_type(type_, value):
+    """Returns whether ``value`` has type ``type_``, using exact type equality.
+
+    Unlike ``isinstance()``, instances of subclasses are rejected (eg. ``True``
+    is not accepted where an ``int`` is expected). ``type_`` may be ``Any``,
+    ``object``, a non-generic class, or a ``Union``, ``Tuple`` or
+    ``ImmutableDict`` generic alias, whose items are checked recursively.
+
+    Raises:
+        NotImplementedError: if ``type_`` is any other generic alias (eg.
+            ``List[int]``) and ``value`` is an instance of its origin class.
+    """
+    if type_ is object or type_ is Any:
+        return True
+
+    origin = getattr(type_, "__origin__", None)
+
+    # Non-generic type, check it directly
+    if origin is None:
+        return type(value) == type_
+
+    # Check the type of the value itself
+    if origin is not Union and type(value) != origin:
+        return False
+
+    # Then, if it's a container, check its items.
+    if origin is tuple:
+        args = type_.__args__
+        if args == ((),):
+            # Older Pythons spell Tuple[()].__args__ as ((),) rather than ()
+            args = ()
+        if len(args) == 2 and args[1] is Ellipsis:
+            # Infinite tuple
+            return all(_check_type(args[0], item) for item in value)
+        else:
+            # Finite tuple
+            if len(args) != len(value):
+                return False
+
+            return all(
+                _check_type(item_type, item) for (item_type, item) in zip(args, value)
+            )
+    elif origin is Union:
+        args = type_.__args__
+        return any(_check_type(variant, value) for variant in args)
+    elif origin is ImmutableDict:
+        (key_type, value_type) = type_.__args__
+        # Distinct loop names avoid shadowing ``value`` (the dict being checked)
+        return all(
+            _check_type(key_type, item_key) and _check_type(value_type, item_value)
+            for (item_key, item_value) in value.items()
+        )
+    else:
+        # No need to check dict or list, because they are converted to ImmutableDict
+        # and tuple respectively.
+        raise NotImplementedError(f"Type-checking {type_}")
+
+
+def type_validator():
+    """Like attrs_strict.type_validator(), but stricter.
+
+    It is an attrs validator, which checks attributes have the specified type,
+    using type equality instead of ``isinstance()``, for improved performance.
+
+    The validator raises ``attrs_strict.AttributeTypeError`` when the value
+    does not match the attribute's declared ``type``.
+    """
+
+    def validator(instance, attribute, value):
+        if not _check_type(attribute.type, value):
+            raise AttributeTypeError(value, attribute)
+
+    return validator
+
+
 ModelType = TypeVar("ModelType", bound="BaseModel")
 
 
@@ -686,7 +761,7 @@
     name = attr.ib(type=bytes, validator=type_validator())
     type = attr.ib(type=str, validator=attr.validators.in_(["file", "dir", "rev"]))
     target = attr.ib(type=Sha1Git, validator=type_validator())
-    perms = attr.ib(type=int, validator=type_validator())
+    perms = attr.ib(type=int, validator=type_validator(), converter=int)
     """Usually one of the values of `swh.model.from_disk.DentryPerms`."""
 
 
diff --git a/swh/model/tests/test_model.py b/swh/model/tests/test_model.py
--- a/swh/model/tests/test_model.py
+++ b/swh/model/tests/test_model.py
@@ -5,6 +5,7 @@
 
 import copy
 import datetime
+from typing import Any, List, Optional, Tuple, Union
 
 import attr
 from attrs_strict import AttributeTypeError
@@ -12,6 +13,8 @@
 from hypothesis.strategies import binary
 import pytest
 
+from swh.model.collections import ImmutableDict
+from swh.model.from_disk import DentryPerms
 from swh.model.hashutil import MultiHash, hash_to_bytes
 import swh.model.hypothesis_strategies as strategies
 from swh.model.model import (
@@ -31,8 +34,10 @@
     Revision,
     SkippedContent,
     Snapshot,
+    TargetType,
     Timestamp,
     TimestampWithTimezone,
+    type_validator,
 )
 from swh.model.swhids import CoreSWHID, ExtendedSWHID, ObjectType
 from swh.model.tests.swh_model_data import TEST_OBJECTS
@@ -69,6 +74,157 @@
     assert obj_as_dict == type(obj).from_dict(obj_as_dict).to_dict()
 
 
+# List of (type, valid_values, invalid_values)
+_TYPE_VALIDATOR_PARAMETERS: List[Tuple[Any, List[Any], List[Any]]] = [
+    # base types:
+    (
+        bool,
+        [True, False],
+        [-1, 0, 1, 42, 1000, None, "123", 0.0, (), ("foo",), ImmutableDict()],
+    ),
+    (
+        int,
+        [-1, 0, 1, 42, 1000],
+        [True, False, None, "123", 0.0, (), ImmutableDict(), DentryPerms.directory],
+    ),
+    (
+        float,
+        [-1.0, 0.0, 1.0, float("infinity"), float("NaN")],
+        [True, False, None, 1, "1.2", (), ImmutableDict()],
+    ),
+    (
+        bytes,
+        [b"", b"123"],
+        [None, bytearray(b"\x12\x34"), "123", 0, 123, (), (1, 2, 3), ImmutableDict()],
+    ),
+    (str, ["", "123"], [None, b"123", b"", 0, (), (1, 2, 3), ImmutableDict()]),
+    # unions:
+    (
+        Optional[int],
+        [None, -1, 0, 1, 42, 1000],
+        ["123", 0.0, (), ImmutableDict(), DentryPerms.directory],
+    ),
+    (
+        Optional[bytes],
+        [None, b"", b"123"],
+        ["123", "", 0, (), (1, 2, 3), ImmutableDict()],
+    ),
+    (
+        Union[str, bytes],
+        ["", "123", b"123", b""],
+        [None, 0, (), (1, 2, 3), ImmutableDict()],
+    ),
+    (
+        Union[str, bytes, None],
+        ["", "123", b"123", b"", None],
+        [0, (), (1, 2, 3), ImmutableDict()],
+    ),
+    # tuples
+    (Tuple[()], [()], [("foo",), (42,), None]),
+    (
+        Tuple[str, str],
+        [("foo", "bar"), ("", "")],
+        [("foo",), ("foo", "bar", "baz"), ("foo", 42), (42, "foo")],
+    ),
+    (
+        Tuple[str, ...],
+        [("foo",), ("foo", "bar"), ("", ""), ("foo", "bar", "baz")],
+        [("foo", 42), (42, "foo")],
+    ),
+    # composite generic:
+    (
+        Tuple[Union[str, int], Union[str, int]],
+        [("foo", "foo"), ("foo", 42), (42, "foo"), (42, 42)],
+        [("foo", b"bar"), (b"bar", "foo")],
+    ),
+    (
+        Union[Tuple[str, str], Tuple[int, int]],
+        [("foo", "foo"), (42, 42)],
+        [("foo", b"bar"), (b"bar", "foo"), ("foo", 42), (42, "foo")],
+    ),
+    (
+        Tuple[Tuple[bytes, bytes], ...],
+        [(), ((b"foo", b"bar"),), ((b"foo", b"bar"), (b"baz", b"qux"))],
+        [((b"foo", "bar"),), ((b"foo", b"bar"), ("baz", b"qux"))],
+    ),
+    # standard types:
+    (
+        datetime.datetime,
+        [datetime.datetime.now(), datetime.datetime.now(tz=datetime.timezone.utc)],
+        [None, 123],
+    ),
+    # ImmutableDict
+    (
+        ImmutableDict[str, int],
+        [
+            ImmutableDict(),
+            ImmutableDict({"foo": 42}),
+            ImmutableDict({"foo": 42, "bar": 123}),
+        ],
+        [ImmutableDict({"foo": "bar"}), ImmutableDict({42: 123})],
+    ),
+    # Any:
+    (Any, [-1, 0, 1, 42, 1000, None, "123", 0.0, (), ImmutableDict()], [],),
+    (
+        ImmutableDict[Any, int],
+        [
+            ImmutableDict(),
+            ImmutableDict({"foo": 42}),
+            ImmutableDict({"foo": 42, "bar": 123}),
+            ImmutableDict({42: 123}),
+        ],
+        [ImmutableDict({"foo": "bar"})],
+    ),
+    (
+        ImmutableDict[str, Any],
+        [
+            ImmutableDict(),
+            ImmutableDict({"foo": 42}),
+            ImmutableDict({"foo": "bar"}),
+            ImmutableDict({"foo": 42, "bar": 123}),
+        ],
+        [ImmutableDict({42: 123})],
+    ),
+    # attr objects:
+    (
+        Timestamp,
+        [Timestamp(seconds=123, microseconds=0)],
+        [None, "2021-09-28T11:27:59", 123],
+    ),
+    # enums:
+    (
+        TargetType,
+        [TargetType.CONTENT, TargetType.ALIAS],
+        ["content", "alias", 123, None],
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    "type_,value",
+    [
+        (type_, value)
+        for (type_, values, _) in _TYPE_VALIDATOR_PARAMETERS
+        for value in values
+    ],
+)
+def test_type_validator_valid(type_, value):
+    type_validator()(None, attr.ib(type=type_), value)
+
+
+@pytest.mark.parametrize(
+    "type_,value",
+    [
+        (type_, value)
+        for (type_, _, values) in _TYPE_VALIDATOR_PARAMETERS
+        for value in values
+    ],
+)
+def test_type_validator_invalid(type_, value):
+    with pytest.raises(AttributeTypeError):
+        type_validator()(None, attr.ib(type=type_), value)
+
+
 @pytest.mark.parametrize("object_type, objects", TEST_OBJECTS.items())
 def test_swh_model_todict_fromdict(object_type, objects):
     """checks model objects in swh_model_data are in correct shape"""