diff --git a/swh/model/model.py b/swh/model/model.py
--- a/swh/model/model.py
+++ b/swh/model/model.py
@@ -23,6 +23,8 @@
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
 
 import attr
+from attr._make import _AndValidator
+from attr.validators import and_
 from attrs_strict import AttributeTypeError
 import dateutil.parser
 import iso8601
@@ -147,18 +149,173 @@
         raise NotImplementedError(f"Type-checking {type_}")
 
 
-def type_validator():
-    """Like attrs_strict.type_validator(), but stricter.
+def generic_type_validator(instance, attribute, value):
+    """Validate ``value`` against ``attribute.type``, whatever that type is."""
+    if not _check_type(attribute.type, value):
+        raise AttributeTypeError(value, attribute)
 
-    It is an attrs validator, which checks attributes have the specified type,
-    using type equality instead of ``isinstance()``, for improved performance
-    """
 
-    def validator(instance, attribute, value):
-        if not _check_type(attribute.type, value):
-            raise AttributeTypeError(value, attribute)
+def _true_validator(instance, attribute, value, expected_type=None):
+    pass
+
+
+def _none_validator(instance, attribute, value, expected_type=None):
+    if value is not None:
+        raise AttributeTypeError(value, attribute)
+
+
+def _origin_type_validator(instance, attribute, value, expected_type=None):
+    # This is functionally equivalent to using just this:
+    #   return isinstance(value, expected_type)
+    # but using type equality before isinstance allows very quick checks
+    # when the exact class is used (which is the overwhelming majority of cases)
+    # while still allowing subclasses to be used.
+    if expected_type is None:
+        expected_type = attribute.type
+    if not (type(value) == expected_type or isinstance(value, expected_type)):
+        raise AttributeTypeError(value, attribute)
+
+
+def _tuple_infinite_validator(instance, attribute, value, expected_type=None):
+    type_ = type(value)
+    if type_ != tuple and not isinstance(value, tuple):
+        raise AttributeTypeError(value, attribute)
+    if expected_type is None:
+        expected_type = attribute.type
+    args = expected_type.__args__
+    # assert len(args) == 2 and args[1] is Ellipsis
+    expected_value_type = args[0]
+    validator = optimized_validator(expected_value_type)
+    try:
+        for item in value:
+            validator(instance, attribute, item, expected_type=expected_value_type)
+    except AttributeTypeError:
+        raise AttributeTypeError(value, attribute) from None
+
+
+def _tuple_bytes_bytes_validator(instance, attribute, value, expected_type=None):
+    type_ = type(value)
+    if type_ != tuple and not isinstance(value, tuple):
+        raise AttributeTypeError(value, attribute)
+    if len(value) != 2:
+        raise AttributeTypeError(value, attribute)
+    if not (isinstance(value[0], bytes) and isinstance(value[1], bytes)):
+        raise AttributeTypeError(value, attribute)
+
+
+def _tuple_finite_validator(instance, attribute, value, expected_type=None):
+    # might be useful to optimise the sub-validator tuple, in practice, we only
+    # have [bytes, bytes]
+    type_ = type(value)
+    if type_ != tuple and not isinstance(value, tuple):
+        raise AttributeTypeError(value, attribute)
+    if expected_type is None:
+        expected_type = attribute.type
+    args = expected_type.__args__
+
+    # assert len(args) != 2 or args[1] is not Ellipsis
+    if len(args) != len(value):
+        raise AttributeTypeError(value, attribute)
+    try:
+        for item_type, item in zip(args, value):
+            validator = optimized_validator(item_type)
+            validator(instance, attribute, item, expected_type=item_type)
+    except AttributeTypeError:
+        raise AttributeTypeError(value, attribute) from None
+
+
+def _immutable_dict_validator(instance, attribute, value, expected_type=None):
+    value_type = type(value)
+    if value_type != ImmutableDict and not isinstance(value, ImmutableDict):
+        raise AttributeTypeError(value, attribute)
+
+    if expected_type is None:
+        expected_type = attribute.type
+    (expected_key_type, expected_value_type) = expected_type.__args__
+
+    key_validator = optimized_validator(expected_key_type)
+    value_validator = optimized_validator(expected_value_type)
+
+    try:
+        for (item_key, item_value) in value.items():
+            key_validator(
+                instance, attribute, item_key, expected_type=expected_key_type
+            )
+            value_validator(
+                instance, attribute, item_value, expected_type=expected_value_type
+            )
+    except AttributeTypeError:
+        raise AttributeTypeError(value, attribute) from None
+
+
+def optimized_validator(type_):
+    """Return a validator for ``type_`` behaving like generic_type_validator."""
+    if type_ is object or type_ is Any:
+        return _true_validator
+
+    if type_ is None:
+        return _none_validator
 
-    return validator
+    origin = getattr(type_, "__origin__", None)
+
+    # Non-generic type, check it directly
+    if origin is None:
+        return _origin_type_validator
+
+    # Then, if it's a container, check its items.
+    if origin is tuple:
+        args = type_.__args__
+        if len(args) == 2 and args[1] is Ellipsis:
+            # Infinite tuple
+            return _tuple_infinite_validator
+        elif args == (bytes, bytes):
+            return _tuple_bytes_bytes_validator
+        else:
+            return _tuple_finite_validator
+    elif origin is Union:
+        args = type_.__args__
+        all_validators = tuple((optimized_validator(t), t) for t in args)
+
+        def union_validator(instance, attribute, value, expected_type=None):
+            for (validator, member_type) in all_validators:
+                try:
+                    validator(instance, attribute, value, expected_type=member_type)
+                except AttributeTypeError:
+                    pass
+                else:
+                    break
+            else:
+                raise AttributeTypeError(value, attribute)
+
+        return union_validator
+    elif origin is ImmutableDict:
+        return _immutable_dict_validator
+    # No need to check dict or list, because they are converted to ImmutableDict
+    # and tuple respectively.
+    raise NotImplementedError(f"Type-checking {type_}")
+
+
+def optimize_all_validators(cls, old_fields):
+    """attrs field_transformer replacing generic_type_validator by faster ones."""
+    new_fields = []
+    for f in old_fields:
+        validator = f.validator
+        if validator is generic_type_validator:
+            validator = optimized_validator(f.type)
+        elif isinstance(validator, _AndValidator):
+            new_and = []
+            for v in validator._validators:
+                if v is generic_type_validator:
+                    v = optimized_validator(f.type)
+                new_and.append(v)
+            validator = and_(*new_and)
+        else:
+            validator = None
+
+        if validator is not None:
+            f = f.evolve(validator=validator)
+        new_fields.append(f)
+    return new_fields
 
 
 ModelType = TypeVar("ModelType", bound="BaseModel")
@@ -285,15 +442,15 @@
             )
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class Person(BaseModel):
     """Represents the author/committer of a revision or release."""
 
     object_type: Final = "person"
 
-    fullname = attr.ib(type=bytes, validator=type_validator())
-    name = attr.ib(type=Optional[bytes], validator=type_validator(), eq=False)
-    email = attr.ib(type=Optional[bytes], validator=type_validator(), eq=False)
+    fullname = attr.ib(type=bytes, validator=generic_type_validator)
+    name = attr.ib(type=Optional[bytes], validator=generic_type_validator, eq=False)
+    email = attr.ib(type=Optional[bytes], validator=generic_type_validator, eq=False)
 
     @classmethod
     def from_fullname(cls, fullname: bytes):
@@ -367,14 +524,14 @@
         return super().from_dict(d)
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class Timestamp(BaseModel):
     """Represents a naive timestamp from a VCS."""
 
     object_type: Final = "timestamp"
 
-    seconds = attr.ib(type=int, validator=type_validator())
-    microseconds = attr.ib(type=int, validator=type_validator())
+    seconds = attr.ib(type=int, validator=generic_type_validator)
+    microseconds = attr.ib(type=int, validator=generic_type_validator)
 
     @seconds.validator
     def check_seconds(self, attribute, value):
@@ -389,15 +546,15 @@
             raise ValueError("Microseconds must be in [0, 1000000[.")
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class TimestampWithTimezone(BaseModel):
     """Represents a TZ-aware timestamp from a VCS."""
 
     object_type: Final = "timestamp_with_timezone"
 
-    timestamp = attr.ib(type=Timestamp, validator=type_validator())
+    timestamp = attr.ib(type=Timestamp, validator=generic_type_validator)
 
-    offset_bytes = attr.ib(type=bytes, validator=type_validator())
+    offset_bytes = attr.ib(type=bytes, validator=generic_type_validator)
     """Raw git representation of the timezone, as an offset from UTC.
     It should follow this format: ``+HHMM`` or ``-HHMM`` (including ``+0000`` and
     ``-0000``).
@@ -586,15 +743,15 @@
         return self._parse_offset_bytes(self.offset_bytes)
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class Origin(HashableObject, BaseModel):
     """Represents a software source: a VCS and an URL."""
 
     object_type: Final = "origin"
 
-    url = attr.ib(type=str, validator=type_validator())
+    url = attr.ib(type=str, validator=generic_type_validator)
 
-    id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"")
+    id = attr.ib(type=Sha1Git, validator=generic_type_validator, default=b"")
 
     def unique_key(self) -> KeyType:
         return {"url": self.url}
@@ -610,18 +767,18 @@
         )
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class OriginVisit(BaseModel):
     """Represents an origin visit with a given type at a given point in time, by a
     SWH loader."""
 
     object_type: Final = "origin_visit"
 
-    origin = attr.ib(type=str, validator=type_validator())
-    date = attr.ib(type=datetime.datetime, validator=type_validator())
-    type = attr.ib(type=str, validator=type_validator())
+    origin = attr.ib(type=str, validator=generic_type_validator)
+    date = attr.ib(type=datetime.datetime, validator=generic_type_validator)
+    type = attr.ib(type=str, validator=generic_type_validator)
     """Should not be set before calling 'origin_visit_add()'."""
-    visit = attr.ib(type=Optional[int], validator=type_validator(), default=None)
+    visit = attr.ib(type=Optional[int], validator=generic_type_validator, default=None)
 
     @date.validator
     def check_date(self, attribute, value):
@@ -641,16 +798,16 @@
         return {"origin": self.origin, "date": str(self.date)}
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class OriginVisitStatus(BaseModel):
     """Represents a visit update of an origin at a given point in time."""
 
     object_type: Final = "origin_visit_status"
 
-    origin = attr.ib(type=str, validator=type_validator())
-    visit = attr.ib(type=int, validator=type_validator())
+    origin = attr.ib(type=str, validator=generic_type_validator)
+    visit = attr.ib(type=int, validator=generic_type_validator)
 
-    date = attr.ib(type=datetime.datetime, validator=type_validator())
+    date = attr.ib(type=datetime.datetime, validator=generic_type_validator)
     status = attr.ib(
         type=str,
         validator=attr.validators.in_(
@@ -658,13 +815,13 @@
         ),
     )
     snapshot = attr.ib(
-        type=Optional[Sha1Git], validator=type_validator(), repr=hash_repr
+        type=Optional[Sha1Git], validator=generic_type_validator, repr=hash_repr
     )
     # Type is optional be to able to use it before adding it to the database model
-    type = attr.ib(type=Optional[str], validator=type_validator(), default=None)
+    type = attr.ib(type=Optional[str], validator=generic_type_validator, default=None)
     metadata = attr.ib(
         type=Optional[ImmutableDict[str, object]],
-        validator=type_validator(),
+        validator=generic_type_validator,
         converter=freeze_optional_dict,
         default=None,
     )
@@ -707,14 +864,14 @@
         return f"ObjectType.{self.name}"
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class SnapshotBranch(BaseModel):
     """Represents one of the branches of a snapshot."""
 
     object_type: Final = "snapshot_branch"
 
-    target = attr.ib(type=bytes, validator=type_validator(), repr=hash_repr)
-    target_type = attr.ib(type=TargetType, validator=type_validator())
+    target = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr)
+    target_type = attr.ib(type=TargetType, validator=generic_type_validator)
 
     @target.validator
     def check_target(self, attribute, value):
@@ -729,7 +886,7 @@
         return cls(target=d["target"], target_type=TargetType(d["target_type"]))
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class Snapshot(HashableObject, BaseModel):
     """Represents the full state of an origin at a given point in time."""
 
@@ -737,10 +894,12 @@
 
     branches = attr.ib(
         type=ImmutableDict[bytes, Optional[SnapshotBranch]],
-        validator=type_validator(),
+        validator=generic_type_validator,
         converter=freeze_optional_dict,
     )
-    id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr)
+    id = attr.ib(
+        type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr
+    )
 
     def _compute_hash_from_attributes(self) -> bytes:
         return _compute_hash_from_manifest(
@@ -763,26 +922,34 @@
         return CoreSWHID(object_type=SwhidObjectType.SNAPSHOT, object_id=self.id)
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class Release(HashableObjectWithManifest, BaseModel):
     object_type: Final = "release"
 
-    name = attr.ib(type=bytes, validator=type_validator())
-    message = attr.ib(type=Optional[bytes], validator=type_validator())
-    target = attr.ib(type=Optional[Sha1Git], validator=type_validator(), repr=hash_repr)
-    target_type = attr.ib(type=ObjectType, validator=type_validator())
-    synthetic = attr.ib(type=bool, validator=type_validator())
-    author = attr.ib(type=Optional[Person], validator=type_validator(), default=None)
+    name = attr.ib(type=bytes, validator=generic_type_validator)
+    message = attr.ib(type=Optional[bytes], validator=generic_type_validator)
+    target = attr.ib(
+        type=Optional[Sha1Git], validator=generic_type_validator, repr=hash_repr
+    )
+    target_type = attr.ib(type=ObjectType, validator=generic_type_validator)
+    synthetic = attr.ib(type=bool, validator=generic_type_validator)
+    author = attr.ib(
+        type=Optional[Person], validator=generic_type_validator, default=None
+    )
     date = attr.ib(
-        type=Optional[TimestampWithTimezone], validator=type_validator(), default=None
+        type=Optional[TimestampWithTimezone],
+        validator=generic_type_validator,
+        default=None,
     )
     metadata = attr.ib(
         type=Optional[ImmutableDict[str, object]],
-        validator=type_validator(),
+        validator=generic_type_validator,
         converter=freeze_optional_dict,
         default=None,
     )
-    id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr)
+    id = attr.ib(
+        type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr
+    )
     raw_manifest = attr.ib(type=Optional[bytes], default=None)
 
     def _compute_hash_from_attributes(self) -> bytes:
@@ -839,31 +1006,37 @@
     return tuple((k, v) for k, v in value)
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class Revision(HashableObjectWithManifest, BaseModel):
     object_type: Final = "revision"
 
-    message = attr.ib(type=Optional[bytes], validator=type_validator())
-    author = attr.ib(type=Optional[Person], validator=type_validator())
-    committer = attr.ib(type=Optional[Person], validator=type_validator())
-    date = attr.ib(type=Optional[TimestampWithTimezone], validator=type_validator())
+    message = attr.ib(type=Optional[bytes], validator=generic_type_validator)
+    author = attr.ib(type=Optional[Person], validator=generic_type_validator)
+    committer = attr.ib(type=Optional[Person], validator=generic_type_validator)
+    date = attr.ib(
+        type=Optional[TimestampWithTimezone], validator=generic_type_validator
+    )
     committer_date = attr.ib(
-        type=Optional[TimestampWithTimezone], validator=type_validator()
+        type=Optional[TimestampWithTimezone], validator=generic_type_validator
     )
-    type = attr.ib(type=RevisionType, validator=type_validator())
-    directory = attr.ib(type=Sha1Git, validator=type_validator(), repr=hash_repr)
-    synthetic = attr.ib(type=bool, validator=type_validator())
+    type = attr.ib(type=RevisionType, validator=generic_type_validator)
+    directory = attr.ib(type=Sha1Git, validator=generic_type_validator, repr=hash_repr)
+    synthetic = attr.ib(type=bool, validator=generic_type_validator)
     metadata = attr.ib(
         type=Optional[ImmutableDict[str, object]],
-        validator=type_validator(),
+        validator=generic_type_validator,
         converter=freeze_optional_dict,
         default=None,
     )
-    parents = attr.ib(type=Tuple[Sha1Git, ...], validator=type_validator(), default=())
-    id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr)
+    parents = attr.ib(
+        type=Tuple[Sha1Git, ...], validator=generic_type_validator, default=()
+    )
+    id = attr.ib(
+        type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr
+    )
     extra_headers = attr.ib(
         type=Tuple[Tuple[bytes, bytes], ...],
-        validator=type_validator(),
+        validator=generic_type_validator,
         converter=tuplify_extra_headers,
         default=(),
     )
@@ -951,14 +1124,14 @@
 _DIR_ENTRY_TYPES = ["file", "dir", "rev"]
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class DirectoryEntry(BaseModel):
     object_type: Final = "directory_entry"
 
-    name = attr.ib(type=bytes, validator=type_validator())
+    name = attr.ib(type=bytes, validator=generic_type_validator)
     type = attr.ib(type=str, validator=attr.validators.in_(_DIR_ENTRY_TYPES))
-    target = attr.ib(type=Sha1Git, validator=type_validator(), repr=hash_repr)
-    perms = attr.ib(type=int, validator=type_validator(), converter=int, repr=oct)
+    target = attr.ib(type=Sha1Git, validator=generic_type_validator, repr=hash_repr)
+    perms = attr.ib(type=int, validator=generic_type_validator, converter=int, repr=oct)
     """Usually one of the values of `swh.model.from_disk.DentryPerms`."""
 
     @name.validator
@@ -967,12 +1140,14 @@
             raise ValueError(f"{value!r} is not a valid directory entry name.")
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class Directory(HashableObjectWithManifest, BaseModel):
     object_type: Final = "directory"
 
-    entries = attr.ib(type=Tuple[DirectoryEntry, ...], validator=type_validator())
-    id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr)
+    entries = attr.ib(type=Tuple[DirectoryEntry, ...], validator=generic_type_validator)
+    id = attr.ib(
+        type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr
+    )
     raw_manifest = attr.ib(type=Optional[bytes], default=None)
 
     def _compute_hash_from_attributes(self) -> bytes:
@@ -1086,7 +1261,7 @@
         return (True, dir_)
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class BaseContent(BaseModel):
     status = attr.ib(
         type=str, validator=attr.validators.in_(["visible", "hidden", "absent"])
@@ -1122,16 +1297,16 @@
         return {algo: getattr(self, algo) for algo in DEFAULT_ALGORITHMS}
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class Content(BaseContent):
     object_type: Final = "content"
 
-    sha1 = attr.ib(type=bytes, validator=type_validator(), repr=hash_repr)
-    sha1_git = attr.ib(type=Sha1Git, validator=type_validator(), repr=hash_repr)
-    sha256 = attr.ib(type=bytes, validator=type_validator(), repr=hash_repr)
-    blake2s256 = attr.ib(type=bytes, validator=type_validator(), repr=hash_repr)
+    sha1 = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr)
+    sha1_git = attr.ib(type=Sha1Git, validator=generic_type_validator, repr=hash_repr)
+    sha256 = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr)
+    blake2s256 = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr)
 
-    length = attr.ib(type=int, validator=type_validator())
+    length = attr.ib(type=int, validator=generic_type_validator)
 
     status = attr.ib(
         type=str,
@@ -1139,11 +1314,11 @@
         default="visible",
     )
 
-    data = attr.ib(type=Optional[bytes], validator=type_validator(), default=None)
+    data = attr.ib(type=Optional[bytes], validator=generic_type_validator, default=None)
 
     ctime = attr.ib(
         type=Optional[datetime.datetime],
-        validator=type_validator(),
+        validator=generic_type_validator,
         default=None,
         eq=False,
     )
@@ -1205,29 +1380,33 @@
         return CoreSWHID(object_type=SwhidObjectType.CONTENT, object_id=self.sha1_git)
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class SkippedContent(BaseContent):
     object_type: Final = "skipped_content"
 
-    sha1 = attr.ib(type=Optional[bytes], validator=type_validator(), repr=hash_repr)
+    sha1 = attr.ib(
+        type=Optional[bytes], validator=generic_type_validator, repr=hash_repr
+    )
     sha1_git = attr.ib(
-        type=Optional[Sha1Git], validator=type_validator(), repr=hash_repr
+        type=Optional[Sha1Git], validator=generic_type_validator, repr=hash_repr
+    )
+    sha256 = attr.ib(
+        type=Optional[bytes], validator=generic_type_validator, repr=hash_repr
     )
-    sha256 = attr.ib(type=Optional[bytes], validator=type_validator(), repr=hash_repr)
     blake2s256 = attr.ib(
-        type=Optional[bytes], validator=type_validator(), repr=hash_repr
+        type=Optional[bytes], validator=generic_type_validator, repr=hash_repr
     )
 
-    length = attr.ib(type=Optional[int], validator=type_validator())
+    length = attr.ib(type=Optional[int], validator=generic_type_validator)
 
     status = attr.ib(type=str, validator=attr.validators.in_(["absent"]))
-    reason = attr.ib(type=Optional[str], validator=type_validator(), default=None)
+    reason = attr.ib(type=Optional[str], validator=generic_type_validator, default=None)
 
-    origin = attr.ib(type=Optional[str], validator=type_validator(), default=None)
+    origin = attr.ib(type=Optional[str], validator=generic_type_validator, default=None)
 
     ctime = attr.ib(
         type=Optional[datetime.datetime],
-        validator=type_validator(),
+        validator=generic_type_validator,
         default=None,
         eq=False,
     )
@@ -1298,19 +1477,19 @@
         return f"MetadataAuthorityType.{self.name}"
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class MetadataAuthority(BaseModel):
     """Represents an entity that provides metadata about an origin or
     software artifact."""
 
     object_type: Final = "metadata_authority"
 
-    type = attr.ib(type=MetadataAuthorityType, validator=type_validator())
-    url = attr.ib(type=str, validator=type_validator())
+    type = attr.ib(type=MetadataAuthorityType, validator=generic_type_validator)
+    url = attr.ib(type=str, validator=generic_type_validator)
     metadata = attr.ib(
         type=Optional[ImmutableDict[str, Any]],
         default=None,
-        validator=type_validator(),
+        validator=generic_type_validator,
         converter=freeze_optional_dict,
     )
 
@@ -1332,19 +1511,19 @@
         return {"type": self.type.value, "url": self.url}
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class MetadataFetcher(BaseModel):
     """Represents a software component used to fetch metadata from a metadata
     authority, and ingest them into the Software Heritage archive."""
 
     object_type: Final = "metadata_fetcher"
 
-    name = attr.ib(type=str, validator=type_validator())
-    version = attr.ib(type=str, validator=type_validator())
+    name = attr.ib(type=str, validator=generic_type_validator)
+    version = attr.ib(type=str, validator=generic_type_validator)
     metadata = attr.ib(
         type=Optional[ImmutableDict[str, Any]],
         default=None,
-        validator=type_validator(),
+        validator=generic_type_validator,
         converter=freeze_optional_dict,
     )
 
@@ -1369,40 +1548,42 @@
     return value.astimezone(datetime.timezone.utc).replace(microsecond=0)
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class RawExtrinsicMetadata(HashableObject, BaseModel):
     object_type: Final = "raw_extrinsic_metadata"
 
     # target object
-    target = attr.ib(type=ExtendedSWHID, validator=type_validator())
+    target = attr.ib(type=ExtendedSWHID, validator=generic_type_validator)
 
     # source
     discovery_date = attr.ib(type=datetime.datetime, converter=normalize_discovery_date)
-    authority = attr.ib(type=MetadataAuthority, validator=type_validator())
-    fetcher = attr.ib(type=MetadataFetcher, validator=type_validator())
+    authority = attr.ib(type=MetadataAuthority, validator=generic_type_validator)
+    fetcher = attr.ib(type=MetadataFetcher, validator=generic_type_validator)
 
     # the metadata itself
-    format = attr.ib(type=str, validator=type_validator())
-    metadata = attr.ib(type=bytes, validator=type_validator())
+    format = attr.ib(type=str, validator=generic_type_validator)
+    metadata = attr.ib(type=bytes, validator=generic_type_validator)
 
     # context
-    origin = attr.ib(type=Optional[str], default=None, validator=type_validator())
-    visit = attr.ib(type=Optional[int], default=None, validator=type_validator())
+    origin = attr.ib(type=Optional[str], default=None, validator=generic_type_validator)
+    visit = attr.ib(type=Optional[int], default=None, validator=generic_type_validator)
     snapshot = attr.ib(
-        type=Optional[CoreSWHID], default=None, validator=type_validator()
+        type=Optional[CoreSWHID], default=None, validator=generic_type_validator
     )
     release = attr.ib(
-        type=Optional[CoreSWHID], default=None, validator=type_validator()
+        type=Optional[CoreSWHID], default=None, validator=generic_type_validator
     )
     revision = attr.ib(
-        type=Optional[CoreSWHID], default=None, validator=type_validator()
+        type=Optional[CoreSWHID], default=None, validator=generic_type_validator
     )
-    path = attr.ib(type=Optional[bytes], default=None, validator=type_validator())
+    path = attr.ib(type=Optional[bytes], default=None, validator=generic_type_validator)
     directory = attr.ib(
-        type=Optional[CoreSWHID], default=None, validator=type_validator()
+        type=Optional[CoreSWHID], default=None, validator=generic_type_validator
     )
 
-    id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr)
+    id = attr.ib(
+        type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr
+    )
 
     def _compute_hash_from_attributes(self) -> bytes:
         return _compute_hash_from_manifest(
@@ -1592,16 +1773,18 @@
         )
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class ExtID(HashableObject, BaseModel):
     object_type: Final = "extid"
 
-    extid_type = attr.ib(type=str, validator=type_validator())
-    extid = attr.ib(type=bytes, validator=type_validator())
-    target = attr.ib(type=CoreSWHID, validator=type_validator())
-    extid_version = attr.ib(type=int, validator=type_validator(), default=0)
+    extid_type = attr.ib(type=str, validator=generic_type_validator)
+    extid = attr.ib(type=bytes, validator=generic_type_validator)
+    target = attr.ib(type=CoreSWHID, validator=generic_type_validator)
+    extid_version = attr.ib(type=int, validator=generic_type_validator, default=0)
 
-    id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr)
+    id = attr.ib(
+        type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr
+    )
 
     @classmethod
     def from_dict(cls, d):
diff --git a/swh/model/tests/test_model.py b/swh/model/tests/test_model.py
--- a/swh/model/tests/test_model.py
+++ b/swh/model/tests/test_model.py
@@ -43,7 +43,8 @@
     TargetType,
     Timestamp,
     TimestampWithTimezone,
-    type_validator,
+    generic_type_validator,
+    optimized_validator,
 )
 import swh.model.swhids
 from swh.model.swhids import CoreSWHID, ExtendedSWHID, ObjectType
@@ -275,8 +276,34 @@
         for value in values
     ],
 )
-def test_type_validator_valid(type_, value):
-    type_validator()(None, attr.ib(type=type_), value)
+def test_generic_type_validator_valid(type_, value):
+    generic_type_validator(None, attr.ib(type=type_), value)
+
+
+@pytest.mark.parametrize(
+    "type_,value",
+    [
+        pytest.param(type_, value, id=f"type={type_}, value={value}")
+        for (type_, values, _) in _TYPE_VALIDATOR_PARAMETERS
+        for value in values
+    ],
+)
+def test_optimized_type_validator_valid(type_, value):
+    validator = optimized_validator(type_)
+    validator(None, attr.ib(type=type_), value)
+
+
+@pytest.mark.parametrize(
+    "type_,value",
+    [
+        pytest.param(type_, value, id=f"type={type_}, value={value}")
+        for (type_, _, values) in _TYPE_VALIDATOR_PARAMETERS
+        for value in values
+    ],
+)
+def test_generic_type_validator_invalid(type_, value):
+    with pytest.raises(AttributeTypeError):
+        generic_type_validator(None, attr.ib(type=type_), value)
 
 
 @pytest.mark.parametrize(
@@ -287,9 +314,10 @@
         for value in values
     ],
 )
-def test_type_validator_invalid(type_, value):
+def test_optimized_type_validator_invalid(type_, value):
+    validator = optimized_validator(type_)
     with pytest.raises(AttributeTypeError):
-        type_validator()(None, attr.ib(type=type_), value)
+        validator(None, attr.ib(type=type_), value)
 
 
 @pytest.mark.parametrize("object_type, objects", TEST_OBJECTS.items())