diff --git a/swh/model/model.py b/swh/model/model.py --- a/swh/model/model.py +++ b/swh/model/model.py @@ -23,6 +23,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union import attr +from attr._make import _AndValidator +from attr.validators import and_ from attrs_strict import AttributeTypeError import dateutil.parser import iso8601 @@ -93,72 +95,238 @@ return value -def _check_type(type_, value): +def generic_type_validator(instance, attribute, value): + """validates the type of an attribute value whatever the attribute type""" + raise NotImplementedError("generic type check should have been optimized") + + +def _true_validator(instance, attribute, value, expected_type=None, origin_value=None): + pass + + +def _none_validator(instance, attribute, value, expected_type=None, origin_value=None): + if value is not None: + if origin_value is None: + origin_value = value + raise AttributeTypeError(origin_value, attribute) + + +def _origin_type_validator( + instance, attribute, value, expected_type=None, origin_value=None +): + # This is functionally equivalent to using just this: + # return isinstance(value, type) + # but using type equality before isinstance allows very quick checks + # when the exact class is used (which is the overwhelming majority of cases) + # while still allowing subclasses to be used. + if expected_type is None: + expected_type = attribute.type + if not (type(value) == expected_type or isinstance(value, expected_type)): + if origin_value is None: + origin_value = value + raise AttributeTypeError(origin_value, attribute) + + +def _tuple_infinite_validator( + instance, + attribute, + value, + expected_type=None, + origin_value=None, +): + type_ = type(value) + if origin_value is None: + origin_value = value + if type_ != tuple and not isinstance(value, tuple): + raise AttributeTypeError(origin_value, attribute) + if expected_type is None: + expected_type = attribute.type + args = expected_type.__args__ + # assert len(args) == 2 and args[1] is Ellipsis + expected_value_type = args[0] + validator = optimized_validator(expected_value_type) + for i in value: + validator( + instance, + attribute, + i, + expected_type=expected_value_type, + origin_value=origin_value, + ) + + +def _tuple_bytes_bytes_validator( + instance, + attribute, + value, + expected_type=None, + origin_value=None, +): + type_ = type(value) + if type_ != tuple and not isinstance(value, tuple): + if origin_value is None: + origin_value = value + raise AttributeTypeError(origin_value, attribute) + if len(value) != 2: + if origin_value is None: + origin_value = value + raise AttributeTypeError(origin_value, attribute) + if type(value[0]) is not bytes or type(value[1]) is not bytes: + if origin_value is None: + origin_value = value + raise AttributeTypeError(origin_value, attribute) + + +def _tuple_finite_validator( + instance, + attribute, + value, + expected_type=None, + origin_value=None, +): + # might be useful to optimise the sub-validator tuple, in practice, we only + # have [bytes, bytes] + type_ = type(value) + if origin_value is None: + origin_value = value + if type_ != tuple and not isinstance(value, tuple): + raise AttributeTypeError(origin_value, attribute) + if expected_type is None: + expected_type = attribute.type + args = expected_type.__args__ + + # assert len(args) != 2 or args[1] is Ellipsis + if len(args) != len(value): + raise AttributeTypeError(origin_value, attribute) + for item_type, item in zip(args, value): + validator = optimized_validator(item_type) 
+ validator( + instance, + attribute, + item, + expected_type=item_type, + origin_value=origin_value, + ) + + +def _immutable_dict_validator( + instance, + attribute, + value, + expected_type=None, + origin_value=None, +): + value_type = type(value) + if origin_value is None: + origin_value = value + if value_type != ImmutableDict and not isinstance(value, ImmutableDict): + raise AttributeTypeError(origin_value, attribute) + + if expected_type is None: + expected_type = attribute.type + (expected_key_type, expected_value_type) = expected_type.__args__ + + key_validator = optimized_validator(expected_key_type) + value_validator = optimized_validator(expected_value_type) + + for (item_key, item_value) in value.items(): + key_validator( + instance, + attribute, + item_key, + expected_type=expected_key_type, + origin_value=origin_value, + ) + value_validator( + instance, + attribute, + item_value, + expected_type=expected_value_type, + origin_value=origin_value, + ) + + +def optimized_validator(type_): if type_ is object or type_ is Any: - return True + return _true_validator if type_ is None: - return value is None + return _none_validator origin = getattr(type_, "__origin__", None) # Non-generic type, check it directly if origin is None: - # This is functionally equivalent to using just this: - # return isinstance(value, type) - # but using type equality before isinstance allows very quick checks - # when the exact class is used (which is the overwhelming majority of cases) - # while still allowing subclasses to be used. - return type(value) == type_ or isinstance(value, type_) - - # Check the type of the value itself - # - # For the same reason as above, this condition is functionally equivalent to: - # if origin is not Union and not isinstance(value, origin): - if origin is not Union and type(value) != origin and not isinstance(value, origin): - return False + return _origin_type_validator # Then, if it's a container, check its items. if origin is tuple: args = type_.__args__ if len(args) == 2 and args[1] is Ellipsis: # Infinite tuple - return all(_check_type(args[0], item) for item in value) + return _tuple_infinite_validator + elif args == (bytes, bytes): + return _tuple_bytes_bytes_validator else: - # Finite tuple - if len(args) != len(value): - return False - - return all( - _check_type(item_type, item) for (item_type, item) in zip(args, value) - ) + return _tuple_finite_validator elif origin is Union: args = type_.__args__ - return any(_check_type(variant, value) for variant in args) - elif origin is ImmutableDict: - (key_type, value_type) = type_.__args__ - return all( - _check_type(key_type, key) and _check_type(value_type, value) - for (key, value) in value.items() - ) - else: - # No need to check dict or list. because they are converted to ImmutableDict - # and tuple respectively. - raise NotImplementedError(f"Type-checking {type_}") - - -def type_validator(): - """Like attrs_strict.type_validator(), but stricter. 
- - It is an attrs validator, which checks attributes have the specified type, - using type equality instead of ``isinstance()``, for improved performance - """ + all_validators = tuple((optimized_validator(t), t) for t in args) + + def union_validator( + instance, + attribute, + value, + expected_type=None, + origin_value=None, + ): + if origin_value is None: + origin_value = value + for (validator, type_) in all_validators: + try: + validator( + instance, + attribute, + value, + expected_type=type_, + origin_value=origin_value, + ) + except AttributeTypeError: + pass + else: + break + else: + raise AttributeTypeError(origin_value, attribute) - def validator(instance, attribute, value): - if not _check_type(attribute.type, value): - raise AttributeTypeError(value, attribute) + return union_validator + elif origin is ImmutableDict: + return _immutable_dict_validator + # No need to check dict or list. because they are converted to ImmutableDict + # and tuple respectively. + raise NotImplementedError(f"Type-checking {type_}") + + +def optimize_all_validators(cls, old_fields): + """process validators to turn them into a faster version … eventually""" + new_fields = [] + for f in old_fields: + validator = f.validator + if validator is generic_type_validator: + validator = optimized_validator(f.type) + elif isinstance(validator, _AndValidator): + new_and = [] + for v in validator._validators: + if v is generic_type_validator: + v = optimized_validator(f.type) + new_and.append(v) + validator = and_(*new_and) + else: + validator = None - return validator + if validator is not None: + f = f.evolve(validator=validator) + new_fields.append(f) + return new_fields ModelType = TypeVar("ModelType", bound="BaseModel") @@ -285,15 +453,15 @@ ) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class Person(BaseModel): """Represents the author/committer of a revision or release.""" object_type: Final = "person" - fullname = attr.ib(type=bytes, validator=type_validator()) - name = attr.ib(type=Optional[bytes], validator=type_validator(), eq=False) - email = attr.ib(type=Optional[bytes], validator=type_validator(), eq=False) + fullname = attr.ib(type=bytes, validator=generic_type_validator) + name = attr.ib(type=Optional[bytes], validator=generic_type_validator, eq=False) + email = attr.ib(type=Optional[bytes], validator=generic_type_validator, eq=False) @classmethod def from_fullname(cls, fullname: bytes): @@ -367,37 +535,41 @@ return super().from_dict(d) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class Timestamp(BaseModel): """Represents a naive timestamp from a VCS.""" object_type: Final = "timestamp" - seconds = attr.ib(type=int, validator=type_validator()) - microseconds = attr.ib(type=int, validator=type_validator()) + seconds = attr.ib(type=int) + microseconds = attr.ib(type=int) @seconds.validator def check_seconds(self, attribute, value): """Check that seconds fit in a 64-bits signed integer.""" + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) if not (-(2**63) <= value < 2**63): raise ValueError("Seconds must be a signed 64-bits integer.") @microseconds.validator def check_microseconds(self, attribute, value): """Checks that microseconds are positive and < 1000000.""" + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) if not (0 <= value < 10**6): raise ValueError("Microseconds must be in [0, 1000000[.") 
-@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class TimestampWithTimezone(BaseModel): """Represents a TZ-aware timestamp from a VCS.""" object_type: Final = "timestamp_with_timezone" - timestamp = attr.ib(type=Timestamp, validator=type_validator()) + timestamp = attr.ib(type=Timestamp, validator=generic_type_validator) - offset_bytes = attr.ib(type=bytes, validator=type_validator()) + offset_bytes = attr.ib(type=bytes, validator=generic_type_validator) """Raw git representation of the timezone, as an offset from UTC. It should follow this format: ``+HHMM`` or ``-HHMM`` (including ``+0000`` and ``-0000``). @@ -586,15 +758,15 @@ return self._parse_offset_bytes(self.offset_bytes) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class Origin(HashableObject, BaseModel): """Represents a software source: a VCS and an URL.""" object_type: Final = "origin" - url = attr.ib(type=str, validator=type_validator()) + url = attr.ib(type=str, validator=generic_type_validator) - id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"") + id = attr.ib(type=Sha1Git, validator=generic_type_validator, default=b"") def unique_key(self) -> KeyType: return {"url": self.url} @@ -610,22 +782,24 @@ ) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class OriginVisit(BaseModel): """Represents an origin visit with a given type at a given point in time, by a SWH loader.""" object_type: Final = "origin_visit" - origin = attr.ib(type=str, validator=type_validator()) - date = attr.ib(type=datetime.datetime, validator=type_validator()) - type = attr.ib(type=str, validator=type_validator()) + origin = attr.ib(type=str, validator=generic_type_validator) + date = attr.ib(type=datetime.datetime) + type = attr.ib(type=str, validator=generic_type_validator) """Should not be set before calling 'origin_visit_add()'.""" - visit = attr.ib(type=Optional[int], validator=type_validator(), default=None) + visit = attr.ib(type=Optional[int], validator=generic_type_validator, default=None) @date.validator def check_date(self, attribute, value): """Checks the date has a timezone.""" + if value.__class__ is not datetime.datetime: + raise AttributeTypeError(value, attribute) if value is not None and value.tzinfo is None: raise ValueError("date must be a timezone-aware datetime.") @@ -641,16 +815,16 @@ return {"origin": self.origin, "date": str(self.date)} -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class OriginVisitStatus(BaseModel): """Represents a visit update of an origin at a given point in time.""" object_type: Final = "origin_visit_status" - origin = attr.ib(type=str, validator=type_validator()) - visit = attr.ib(type=int, validator=type_validator()) + origin = attr.ib(type=str, validator=generic_type_validator) + visit = attr.ib(type=int, validator=generic_type_validator) - date = attr.ib(type=datetime.datetime, validator=type_validator()) + date = attr.ib(type=datetime.datetime) status = attr.ib( type=str, validator=attr.validators.in_( @@ -658,13 +832,13 @@ ), ) snapshot = attr.ib( - type=Optional[Sha1Git], validator=type_validator(), repr=hash_repr + type=Optional[Sha1Git], validator=generic_type_validator, repr=hash_repr ) # Type is optional be to able to use it before adding it to the database model - type = attr.ib(type=Optional[str], validator=type_validator(), 
default=None) + type = attr.ib(type=Optional[str], validator=generic_type_validator, default=None) metadata = attr.ib( type=Optional[ImmutableDict[str, object]], - validator=type_validator(), + validator=generic_type_validator, converter=freeze_optional_dict, default=None, ) @@ -672,6 +846,8 @@ @date.validator def check_date(self, attribute, value): """Checks the date has a timezone.""" + if value.__class__ is not datetime.datetime: + raise AttributeTypeError(value, attribute) if value is not None and value.tzinfo is None: raise ValueError("date must be a timezone-aware datetime.") @@ -707,19 +883,21 @@ return f"ObjectType.{self.name}" -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class SnapshotBranch(BaseModel): """Represents one of the branches of a snapshot.""" object_type: Final = "snapshot_branch" - target = attr.ib(type=bytes, validator=type_validator(), repr=hash_repr) - target_type = attr.ib(type=TargetType, validator=type_validator()) + target = attr.ib(type=bytes, repr=hash_repr) + target_type = attr.ib(type=TargetType, validator=generic_type_validator) @target.validator def check_target(self, attribute, value): """Checks the target type is not an alias, checks the target is a valid sha1_git.""" + if value.__class__ is not bytes: + raise AttributeTypeError(value, attribute) if self.target_type != TargetType.ALIAS and self.target is not None: if len(value) != 20: raise ValueError("Wrong length for bytes identifier: %d" % len(value)) @@ -729,7 +907,7 @@ return cls(target=d["target"], target_type=TargetType(d["target_type"])) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class Snapshot(HashableObject, BaseModel): """Represents the full state of an origin at a given point in time.""" @@ -737,10 +915,12 @@ branches = attr.ib( type=ImmutableDict[bytes, Optional[SnapshotBranch]], - validator=type_validator(), + validator=generic_type_validator, converter=freeze_optional_dict, ) - id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr) + id = attr.ib( + type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr + ) def _compute_hash_from_attributes(self) -> bytes: return _compute_hash_from_manifest( @@ -763,26 +943,34 @@ return CoreSWHID(object_type=SwhidObjectType.SNAPSHOT, object_id=self.id) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class Release(HashableObjectWithManifest, BaseModel): object_type: Final = "release" - name = attr.ib(type=bytes, validator=type_validator()) - message = attr.ib(type=Optional[bytes], validator=type_validator()) - target = attr.ib(type=Optional[Sha1Git], validator=type_validator(), repr=hash_repr) - target_type = attr.ib(type=ObjectType, validator=type_validator()) - synthetic = attr.ib(type=bool, validator=type_validator()) - author = attr.ib(type=Optional[Person], validator=type_validator(), default=None) + name = attr.ib(type=bytes, validator=generic_type_validator) + message = attr.ib(type=Optional[bytes], validator=generic_type_validator) + target = attr.ib( + type=Optional[Sha1Git], validator=generic_type_validator, repr=hash_repr + ) + target_type = attr.ib(type=ObjectType, validator=generic_type_validator) + synthetic = attr.ib(type=bool, validator=generic_type_validator) + author = attr.ib( + type=Optional[Person], validator=generic_type_validator, default=None + ) date = attr.ib( - 
type=Optional[TimestampWithTimezone], validator=type_validator(), default=None + type=Optional[TimestampWithTimezone], + validator=generic_type_validator, + default=None, ) metadata = attr.ib( type=Optional[ImmutableDict[str, object]], - validator=type_validator(), + validator=generic_type_validator, converter=freeze_optional_dict, default=None, ) - id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr) + id = attr.ib( + type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr + ) raw_manifest = attr.ib(type=Optional[bytes], default=None) def _compute_hash_from_attributes(self) -> bytes: @@ -839,31 +1027,37 @@ return tuple((k, v) for k, v in value) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class Revision(HashableObjectWithManifest, BaseModel): object_type: Final = "revision" - message = attr.ib(type=Optional[bytes], validator=type_validator()) - author = attr.ib(type=Optional[Person], validator=type_validator()) - committer = attr.ib(type=Optional[Person], validator=type_validator()) - date = attr.ib(type=Optional[TimestampWithTimezone], validator=type_validator()) + message = attr.ib(type=Optional[bytes], validator=generic_type_validator) + author = attr.ib(type=Optional[Person], validator=generic_type_validator) + committer = attr.ib(type=Optional[Person], validator=generic_type_validator) + date = attr.ib( + type=Optional[TimestampWithTimezone], validator=generic_type_validator + ) committer_date = attr.ib( - type=Optional[TimestampWithTimezone], validator=type_validator() + type=Optional[TimestampWithTimezone], validator=generic_type_validator ) - type = attr.ib(type=RevisionType, validator=type_validator()) - directory = attr.ib(type=Sha1Git, validator=type_validator(), repr=hash_repr) - synthetic = attr.ib(type=bool, validator=type_validator()) + type = attr.ib(type=RevisionType, validator=generic_type_validator) + directory = attr.ib(type=Sha1Git, validator=generic_type_validator, repr=hash_repr) + synthetic = attr.ib(type=bool, validator=generic_type_validator) metadata = attr.ib( type=Optional[ImmutableDict[str, object]], - validator=type_validator(), + validator=generic_type_validator, converter=freeze_optional_dict, default=None, ) - parents = attr.ib(type=Tuple[Sha1Git, ...], validator=type_validator(), default=()) - id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr) + parents = attr.ib( + type=Tuple[Sha1Git, ...], validator=generic_type_validator, default=() + ) + id = attr.ib( + type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr + ) extra_headers = attr.ib( type=Tuple[Tuple[bytes, bytes], ...], - validator=type_validator(), + validator=generic_type_validator, converter=tuplify_extra_headers, default=(), ) @@ -951,28 +1145,32 @@ _DIR_ENTRY_TYPES = ["file", "dir", "rev"] -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class DirectoryEntry(BaseModel): object_type: Final = "directory_entry" - name = attr.ib(type=bytes, validator=type_validator()) + name = attr.ib(type=bytes) type = attr.ib(type=str, validator=attr.validators.in_(_DIR_ENTRY_TYPES)) - target = attr.ib(type=Sha1Git, validator=type_validator(), repr=hash_repr) - perms = attr.ib(type=int, validator=type_validator(), converter=int, repr=oct) + target = attr.ib(type=Sha1Git, validator=generic_type_validator, repr=hash_repr) + perms = attr.ib(type=int, validator=generic_type_validator, 
converter=int, repr=oct) """Usually one of the values of `swh.model.from_disk.DentryPerms`.""" @name.validator def check_name(self, attribute, value): + if value.__class__ is not bytes: + raise AttributeTypeError(value, attribute) if b"/" in value: raise ValueError(f"{value!r} is not a valid directory entry name.") -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class Directory(HashableObjectWithManifest, BaseModel): object_type: Final = "directory" - entries = attr.ib(type=Tuple[DirectoryEntry, ...], validator=type_validator()) - id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr) + entries = attr.ib(type=Tuple[DirectoryEntry, ...], validator=generic_type_validator) + id = attr.ib( + type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr + ) raw_manifest = attr.ib(type=Optional[bytes], default=None) def _compute_hash_from_attributes(self) -> bytes: @@ -1086,7 +1284,7 @@ return (True, dir_) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class BaseContent(BaseModel): status = attr.ib( type=str, validator=attr.validators.in_(["visible", "hidden", "absent"]) @@ -1122,16 +1320,16 @@ return {algo: getattr(self, algo) for algo in DEFAULT_ALGORITHMS} -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class Content(BaseContent): object_type: Final = "content" - sha1 = attr.ib(type=bytes, validator=type_validator(), repr=hash_repr) - sha1_git = attr.ib(type=Sha1Git, validator=type_validator(), repr=hash_repr) - sha256 = attr.ib(type=bytes, validator=type_validator(), repr=hash_repr) - blake2s256 = attr.ib(type=bytes, validator=type_validator(), repr=hash_repr) + sha1 = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr) + sha1_git = attr.ib(type=Sha1Git, validator=generic_type_validator, repr=hash_repr) + sha256 = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr) + blake2s256 = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr) - length = attr.ib(type=int, validator=type_validator()) + length = attr.ib(type=int) status = attr.ib( type=str, @@ -1139,11 +1337,10 @@ default="visible", ) - data = attr.ib(type=Optional[bytes], validator=type_validator(), default=None) + data = attr.ib(type=Optional[bytes], validator=generic_type_validator, default=None) ctime = attr.ib( type=Optional[datetime.datetime], - validator=type_validator(), default=None, eq=False, ) @@ -1151,14 +1348,19 @@ @length.validator def check_length(self, attribute, value): """Checks the length is positive.""" + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) if value < 0: raise ValueError("Length must be positive.") @ctime.validator def check_ctime(self, attribute, value): """Checks the ctime has a timezone.""" - if value is not None and value.tzinfo is None: - raise ValueError("ctime must be a timezone-aware datetime.") + if value is not None: + if value.__class__ is not datetime.datetime: + raise AttributeTypeError(value, attribute) + if value.tzinfo is None: + raise ValueError("ctime must be a timezone-aware datetime.") def to_dict(self): content = super().to_dict() @@ -1205,29 +1407,33 @@ return CoreSWHID(object_type=SwhidObjectType.CONTENT, object_id=self.sha1_git) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class SkippedContent(BaseContent): 
object_type: Final = "skipped_content" - sha1 = attr.ib(type=Optional[bytes], validator=type_validator(), repr=hash_repr) + sha1 = attr.ib( + type=Optional[bytes], validator=generic_type_validator, repr=hash_repr + ) sha1_git = attr.ib( - type=Optional[Sha1Git], validator=type_validator(), repr=hash_repr + type=Optional[Sha1Git], validator=generic_type_validator, repr=hash_repr + ) + sha256 = attr.ib( + type=Optional[bytes], validator=generic_type_validator, repr=hash_repr ) - sha256 = attr.ib(type=Optional[bytes], validator=type_validator(), repr=hash_repr) blake2s256 = attr.ib( - type=Optional[bytes], validator=type_validator(), repr=hash_repr + type=Optional[bytes], validator=generic_type_validator, repr=hash_repr ) - length = attr.ib(type=Optional[int], validator=type_validator()) + length = attr.ib(type=Optional[int]) status = attr.ib(type=str, validator=attr.validators.in_(["absent"])) - reason = attr.ib(type=Optional[str], validator=type_validator(), default=None) + reason = attr.ib(type=Optional[str], default=None) - origin = attr.ib(type=Optional[str], validator=type_validator(), default=None) + origin = attr.ib(type=Optional[str], validator=generic_type_validator, default=None) ctime = attr.ib( type=Optional[datetime.datetime], - validator=type_validator(), + validator=generic_type_validator, default=None, eq=False, ) @@ -1238,18 +1444,25 @@ assert self.reason == value if value is None: raise ValueError("Must provide a reason if content is absent.") + elif value.__class__ is not str: + raise AttributeTypeError(value, attribute) @length.validator def check_length(self, attribute, value): """Checks the length is positive or -1.""" - if value < -1: + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) + elif value < -1: raise ValueError("Length must be positive or -1.") @ctime.validator def check_ctime(self, attribute, value): """Checks the ctime has a timezone.""" - if value is not None and value.tzinfo is None: - raise ValueError("ctime must be a timezone-aware datetime.") + if value is not None: + if value.__class__ is not datetime.datetime: + raise AttributeTypeError(value, attribute) + elif value.tzinfo is None: + raise ValueError("ctime must be a timezone-aware datetime.") def to_dict(self): content = super().to_dict() @@ -1298,19 +1511,19 @@ return f"MetadataAuthorityType.{self.name}" -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class MetadataAuthority(BaseModel): """Represents an entity that provides metadata about an origin or software artifact.""" object_type: Final = "metadata_authority" - type = attr.ib(type=MetadataAuthorityType, validator=type_validator()) - url = attr.ib(type=str, validator=type_validator()) + type = attr.ib(type=MetadataAuthorityType, validator=generic_type_validator) + url = attr.ib(type=str, validator=generic_type_validator) metadata = attr.ib( type=Optional[ImmutableDict[str, Any]], default=None, - validator=type_validator(), + validator=generic_type_validator, converter=freeze_optional_dict, ) @@ -1332,19 +1545,19 @@ return {"type": self.type.value, "url": self.url} -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class MetadataFetcher(BaseModel): """Represents a software component used to fetch metadata from a metadata authority, and ingest them into the Software Heritage archive.""" object_type: Final = "metadata_fetcher" - name = attr.ib(type=str, validator=type_validator()) - version = 
attr.ib(type=str, validator=type_validator()) + name = attr.ib(type=str, validator=generic_type_validator) + version = attr.ib(type=str, validator=generic_type_validator) metadata = attr.ib( type=Optional[ImmutableDict[str, Any]], default=None, - validator=type_validator(), + validator=generic_type_validator, converter=freeze_optional_dict, ) @@ -1369,40 +1582,34 @@ return value.astimezone(datetime.timezone.utc).replace(microsecond=0) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class RawExtrinsicMetadata(HashableObject, BaseModel): object_type: Final = "raw_extrinsic_metadata" # target object - target = attr.ib(type=ExtendedSWHID, validator=type_validator()) + target = attr.ib(type=ExtendedSWHID, validator=generic_type_validator) # source discovery_date = attr.ib(type=datetime.datetime, converter=normalize_discovery_date) - authority = attr.ib(type=MetadataAuthority, validator=type_validator()) - fetcher = attr.ib(type=MetadataFetcher, validator=type_validator()) + authority = attr.ib(type=MetadataAuthority, validator=generic_type_validator) + fetcher = attr.ib(type=MetadataFetcher, validator=generic_type_validator) # the metadata itself - format = attr.ib(type=str, validator=type_validator()) - metadata = attr.ib(type=bytes, validator=type_validator()) + format = attr.ib(type=str, validator=generic_type_validator) + metadata = attr.ib(type=bytes, validator=generic_type_validator) # context - origin = attr.ib(type=Optional[str], default=None, validator=type_validator()) - visit = attr.ib(type=Optional[int], default=None, validator=type_validator()) - snapshot = attr.ib( - type=Optional[CoreSWHID], default=None, validator=type_validator() + origin = attr.ib(type=Optional[str], default=None, validator=generic_type_validator) + visit = attr.ib(type=Optional[int], default=None) + snapshot = attr.ib(type=Optional[CoreSWHID], default=None) + release = attr.ib(type=Optional[CoreSWHID], default=None) + revision = attr.ib(type=Optional[CoreSWHID], default=None) + path = attr.ib(type=Optional[bytes], default=None) + directory = attr.ib(type=Optional[CoreSWHID], default=None) + + id = attr.ib( + type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr ) - release = attr.ib( - type=Optional[CoreSWHID], default=None, validator=type_validator() - ) - revision = attr.ib( - type=Optional[CoreSWHID], default=None, validator=type_validator() - ) - path = attr.ib(type=Optional[bytes], default=None, validator=type_validator()) - directory = attr.ib( - type=Optional[CoreSWHID], default=None, validator=type_validator() - ) - - id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr) def _compute_hash_from_attributes(self) -> bytes: return _compute_hash_from_manifest( @@ -1414,12 +1621,15 @@ if value is None: return - if self.target.object_type not in ( - SwhidExtendedObjectType.SNAPSHOT, - SwhidExtendedObjectType.RELEASE, - SwhidExtendedObjectType.REVISION, - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + if value.__class__ is not str: + raise AttributeTypeError(value, attribute) + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.SNAPSHOT + or obj_type is SwhidExtendedObjectType.RELEASE + or obj_type is SwhidExtendedObjectType.REVISION + or obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'origin' context for " @@ -1438,13 +1648,16 @@ def check_visit(self, 
attribute, value): if value is None: return + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) - if self.target.object_type not in ( - SwhidExtendedObjectType.SNAPSHOT, - SwhidExtendedObjectType.RELEASE, - SwhidExtendedObjectType.REVISION, - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.SNAPSHOT + or obj_type is SwhidExtendedObjectType.RELEASE + or obj_type is SwhidExtendedObjectType.REVISION + or obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'visit' context for " @@ -1461,61 +1674,87 @@ def check_snapshot(self, attribute, value): if value is None: return + if value.__class__ is not CoreSWHID: + raise AttributeTypeError(value, attribute) - if self.target.object_type not in ( - SwhidExtendedObjectType.RELEASE, - SwhidExtendedObjectType.REVISION, - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.RELEASE + or obj_type is SwhidExtendedObjectType.REVISION + or obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'snapshot' context for " f"{self.target.object_type.name.lower()} object: {value}" ) - self._check_swhid(SwhidObjectType.SNAPSHOT, value) + if value.object_type != SwhidObjectType.SNAPSHOT: + raise ValueError( + f"Expected SWHID type 'snapshot', " + f"got '{value.object_type.name.lower()}' in {value}" + ) @release.validator def check_release(self, attribute, value): if value is None: return + if value.__class__ is not CoreSWHID: + raise AttributeTypeError(value, attribute) - if self.target.object_type not in ( - SwhidExtendedObjectType.REVISION, - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.REVISION + or obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'release' context for " f"{self.target.object_type.name.lower()} object: {value}" ) - self._check_swhid(SwhidObjectType.RELEASE, value) + if value.object_type != SwhidObjectType.RELEASE: + raise ValueError( + f"Expected SWHID type 'release', " + f"got '{value.object_type.name.lower()}' in {value}" + ) @revision.validator def check_revision(self, attribute, value): if value is None: return - if self.target.object_type not in ( - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + if value.__class__ is not CoreSWHID: + raise AttributeTypeError(value, attribute) + + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'revision' context for " f"{self.target.object_type.name.lower()} object: {value}" ) - self._check_swhid(SwhidObjectType.REVISION, value) + if value.object_type != SwhidObjectType.REVISION: + raise ValueError( + f"Expected SWHID type 'revision', " + f"got '{value.object_type.name.lower()}' in {value}" + ) @path.validator def check_path(self, attribute, value): if value is None: return - if self.target.object_type not in ( - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + if value.__class__ is not bytes: + raise AttributeTypeError(value, attribute) + + obj_type = 
self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'path' context for " @@ -1527,22 +1766,19 @@ if value is None: return - if self.target.object_type not in (SwhidExtendedObjectType.CONTENT,): + if value.__class__ is not CoreSWHID: + raise AttributeTypeError(value, attribute) + + if self.target.object_type is not SwhidExtendedObjectType.CONTENT: raise ValueError( f"Unexpected 'directory' context for " f"{self.target.object_type.name.lower()} object: {value}" ) - self._check_swhid(SwhidObjectType.DIRECTORY, value) - - def _check_swhid(self, expected_object_type, swhid): - if isinstance(swhid, str): - raise ValueError(f"Expected SWHID, got a string: {swhid}") - - if swhid.object_type != expected_object_type: + if value.object_type != SwhidObjectType.DIRECTORY: raise ValueError( - f"Expected SWHID type '{expected_object_type.name.lower()}', " - f"got '{swhid.object_type.name.lower()}' in {swhid}" + f"Expected SWHID type 'directory', " + f"got '{value.object_type.name.lower()}' in {value}" ) def to_dict(self): @@ -1592,16 +1828,18 @@ ) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators) class ExtID(HashableObject, BaseModel): object_type: Final = "extid" - extid_type = attr.ib(type=str, validator=type_validator()) - extid = attr.ib(type=bytes, validator=type_validator()) - target = attr.ib(type=CoreSWHID, validator=type_validator()) - extid_version = attr.ib(type=int, validator=type_validator(), default=0) + extid_type = attr.ib(type=str, validator=generic_type_validator) + extid = attr.ib(type=bytes, validator=generic_type_validator) + target = attr.ib(type=CoreSWHID, validator=generic_type_validator) + extid_version = attr.ib(type=int, validator=generic_type_validator, default=0) - id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"", repr=hash_repr) + id = attr.ib( + type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr + ) @classmethod def from_dict(cls, d): diff --git a/swh/model/tests/test_model.py b/swh/model/tests/test_model.py --- a/swh/model/tests/test_model.py +++ b/swh/model/tests/test_model.py @@ -43,7 +43,7 @@ TargetType, Timestamp, TimestampWithTimezone, - type_validator, + optimized_validator, ) import swh.model.swhids from swh.model.swhids import CoreSWHID, ExtendedSWHID, ObjectType @@ -167,6 +167,16 @@ [("foo", "bar"), ("", ""), _custom_namedtuple("", ""), _custom_tuple(("", ""))], [("foo",), ("foo", "bar", "baz"), ("foo", 42), (42, "foo")], ), + ( + Tuple[bytes, bytes], + [ + (b"foo", b"bar"), + (b"", b""), + _custom_namedtuple(b"", b""), + _custom_tuple((b"", b"")), + ], + [(b"foo",), (b"foo", b"bar", b"baz"), (b"foo", 42), (42, b"foo")], + ), ( Tuple[str, ...], [ @@ -275,8 +285,9 @@ for value in values ], ) -def test_type_validator_valid(type_, value): - type_validator()(None, attr.ib(type=type_), value) +def test_optimized_type_validator_valid(type_, value): + validator = optimized_validator(type_) + validator(None, attr.ib(type=type_), value) @pytest.mark.parametrize( @@ -287,9 +298,10 @@ for value in values ], ) -def test_type_validator_invalid(type_, value): +def test_optimized_type_validator_invalid(type_, value): + validator = optimized_validator(type_) with pytest.raises(AttributeTypeError): - type_validator()(None, attr.ib(type=type_), value) + validator(None, attr.ib(type=type_), value) @pytest.mark.parametrize("object_type, objects", TEST_OBJECTS.items())
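
For readers skimming the patch, here is a minimal, self-contained sketch of the pattern it introduces: fields keep a shared placeholder validator (generic_type_validator), and an attrs field_transformer hook rewrites each field once, at class-definition time, with a validator specialized to its declared type, so typing generics are no longer inspected on every attribute check. The names SimplePerson, specialized_validator and optimize_validators below are illustrative only; the patch itself uses optimized_validator / optimize_all_validators and raises attrs_strict.AttributeTypeError rather than TypeError, and it handles many more type shapes (tuples, Union, ImmutableDict).

    # Illustrative sketch only -- simplified dispatch, not the swh.model implementation.
    from typing import Optional

    import attr


    def generic_type_validator(instance, attribute, value):
        """Placeholder validator; replaced by the field transformer below."""
        raise NotImplementedError("should have been replaced at class-definition time")


    def specialized_validator(type_):
        """Build a validator for one declared type (covers plain classes and Optional[X] only)."""
        origin = getattr(type_, "__origin__", None)
        if origin is None:
            # Plain class such as bytes or int: exact-type check first (cheaper than
            # isinstance() in the overwhelmingly common case), then fall back to isinstance().
            def validate(instance, attribute, value):
                if type(value) is not type_ and not isinstance(value, type_):
                    raise TypeError(f"{attribute.name}={value!r} is not a {type_!r}")

            return validate

        # Optional[X] is Union[X, None]: accept None, otherwise check X.
        inner_type = [a for a in type_.__args__ if a is not type(None)][0]
        inner = specialized_validator(inner_type)

        def validate_optional(instance, attribute, value):
            if value is not None:
                inner(instance, attribute, value)

        return validate_optional


    def optimize_validators(cls, fields):
        """attrs field_transformer: swap the placeholder for a type-specialized validator."""
        return [
            f.evolve(validator=specialized_validator(f.type))
            if f.validator is generic_type_validator
            else f
            for f in fields
        ]


    @attr.s(frozen=True, slots=True, field_transformer=optimize_validators)
    class SimplePerson:
        fullname = attr.ib(type=bytes, validator=generic_type_validator)
        name = attr.ib(type=Optional[bytes], validator=generic_type_validator)


    SimplePerson(fullname=b"Ada", name=None)    # passes the specialized validators
    # SimplePerson(fullname="Ada", name=None)   # would raise TypeError at construction

Doing the specialization in the field_transformer keeps object construction free of typing introspection, which is what the per-type helpers added above (exact-type checks before isinstance, dedicated tuple/dict/union validators) rely on for their speedup.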