diff --git a/swh/model/fields/__init__.py b/swh/model/fields/__init__.py index b09b056..d2b3cef 100644 --- a/swh/model/fields/__init__.py +++ b/swh/model/fields/__init__.py @@ -1,13 +1,13 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # We do our imports here but we don't use them, so flake8 complains # flake8: noqa -from .simple import (validate_type, validate_int, validate_str, +from .simple import (validate_type, validate_int, validate_str, validate_bytes, validate_datetime, validate_enum) from .hashes import (validate_sha1, validate_sha1_git, validate_sha256) from .compound import (validate_against_schema, validate_all_keys, validate_any_key) diff --git a/swh/model/fields/simple.py b/swh/model/fields/simple.py index d850285..0f8b305 100644 --- a/swh/model/fields/simple.py +++ b/swh/model/fields/simple.py @@ -1,75 +1,80 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import numbers from ..exceptions import ValidationError def validate_type(value, type): """Validate that value is an integer""" if not isinstance(value, type): if isinstance(type, tuple): typestr = 'one of %s' % ', '.join(typ.__name__ for typ in type) else: typestr = type.__name__ raise ValidationError( 'Unexpected type %(type)s, expected %(expected_type)s', params={ 'type': value.__class__.__name__, 'expected_type': typestr, }, code='unexpected-type' ) return True def validate_int(value): """Validate that the given value is an int""" return validate_type(value, numbers.Integral) def validate_str(value): """Validate that the given value is a string""" return validate_type(value, str) +def validate_bytes(value): + """Validate that the given value is a bytes object""" + return validate_type(value, bytes) + + def validate_datetime(value): """Validate that the given value is either a datetime, or a numeric number of seconds since the UNIX epoch.""" errors = [] try: validate_type(value, (datetime.datetime, numbers.Real)) except ValidationError as e: errors.append(e) if isinstance(value, datetime.datetime) and value.tzinfo is None: errors.append(ValidationError( 'Datetimes must be timezone-aware in swh', code='datetime-without-tzinfo', )) if errors: raise ValidationError(errors) return True def validate_enum(value, expected_values): """Validate that value is contained in expected_values""" if value not in expected_values: raise ValidationError( 'Unexpected value %(value)s, expected one of %(expected_values)', params={ 'value': value, 'expected_values': ', '.join(sorted(expected_values)), }, code='unexpected-value', ) return True diff --git a/swh/model/tests/fields/test_simple.py b/swh/model/tests/fields/test_simple.py index d6c1591..9af424f 100644 --- a/swh/model/tests/fields/test_simple.py +++ b/swh/model/tests/fields/test_simple.py @@ -1,96 +1,128 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import unittest from nose.tools import istest from swh.model.exceptions import ValidationError from swh.model.fields import simple class ValidateSimple(unittest.TestCase): def setUp(self): self.valid_str = 'I am a valid string' + self.valid_bytes = b'I am a valid bytes object' + self.enum_values = {'an enum value', 'other', 'and another'} self.invalid_enum_value = 'invalid enum value' self.valid_int = 42 self.valid_real = 42.42 self.valid_datetime = datetime.datetime(1999, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) self.invalid_datetime_notz = datetime.datetime(1999, 1, 1, 12, 0, 0) @istest def validate_int(self): self.assertTrue(simple.validate_int(self.valid_int)) @istest def validate_int_invalid_type(self): with self.assertRaises(ValidationError) as cm: simple.validate_int(self.valid_str) exc = cm.exception self.assertEqual(exc.code, 'unexpected-type') self.assertEqual(exc.params['expected_type'], 'Integral') self.assertEqual(exc.params['type'], 'str') @istest def validate_str(self): self.assertTrue(simple.validate_str(self.valid_str)) @istest def validate_str_invalid_type(self): with self.assertRaises(ValidationError) as cm: simple.validate_str(self.valid_int) exc = cm.exception self.assertEqual(exc.code, 'unexpected-type') self.assertEqual(exc.params['expected_type'], 'str') self.assertEqual(exc.params['type'], 'int') + with self.assertRaises(ValidationError) as cm: + simple.validate_str(self.valid_bytes) + + exc = cm.exception + self.assertEqual(exc.code, 'unexpected-type') + self.assertEqual(exc.params['expected_type'], 'str') + self.assertEqual(exc.params['type'], 'bytes') + + @istest + def validate_bytes(self): + self.assertTrue(simple.validate_bytes(self.valid_bytes)) + + @istest + def validate_bytes_invalid_type(self): + with self.assertRaises(ValidationError) as cm: + simple.validate_bytes(self.valid_int) + + exc = cm.exception + self.assertEqual(exc.code, 'unexpected-type') + self.assertEqual(exc.params['expected_type'], 'bytes') + self.assertEqual(exc.params['type'], 'int') + + with self.assertRaises(ValidationError) as cm: + simple.validate_bytes(self.valid_str) + + exc = cm.exception + self.assertEqual(exc.code, 'unexpected-type') + self.assertEqual(exc.params['expected_type'], 'bytes') + self.assertEqual(exc.params['type'], 'str') + @istest def validate_datetime(self): self.assertTrue(simple.validate_datetime(self.valid_datetime)) self.assertTrue(simple.validate_datetime(self.valid_int)) self.assertTrue(simple.validate_datetime(self.valid_real)) @istest def validate_datetime_invalid_type(self): with self.assertRaises(ValidationError) as cm: simple.validate_datetime(self.valid_str) exc = cm.exception self.assertEqual(exc.code, 'unexpected-type') self.assertEqual(exc.params['expected_type'], 'one of datetime, Real') self.assertEqual(exc.params['type'], 'str') @istest def validate_datetime_invalide_tz(self): with self.assertRaises(ValidationError) as cm: simple.validate_datetime(self.invalid_datetime_notz) exc = cm.exception self.assertEqual(exc.code, 'datetime-without-tzinfo') @istest def validate_enum(self): for value in self.enum_values: self.assertTrue(simple.validate_enum(value, self.enum_values)) @istest def validate_enum_invalid_value(self): with self.assertRaises(ValidationError) as cm: simple.validate_enum(self.invalid_enum_value, self.enum_values) exc = cm.exception self.assertEqual(exc.code, 'unexpected-value') self.assertEqual(exc.params['value'], self.invalid_enum_value) self.assertEqual(exc.params['expected_values'], ', '.join(sorted(self.enum_values)))