diff --git a/swh/core/json.py b/swh/core/json.py new file mode 100644 index 0000000..59902e5 --- /dev/null +++ b/swh/core/json.py @@ -0,0 +1,52 @@ +# Copyright (C) 2015 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import base64 +import datetime +from json import JSONDecoder, JSONEncoder + +import dateutil.parser + + +class SWHJSONEncoder(JSONEncoder): + def default(self, o): + if isinstance(o, bytes): + return { + 'swhtype': 'bytes', + 'd': base64.b85encode(o).decode('ascii'), + } + elif isinstance(o, datetime.datetime): + return { + 'swhtype': 'datetime', + 'd': o.isoformat(), + } + try: + return super().default(o) + except TypeError as e: + try: + iterable = iter(o) + except TypeError: + raise e from None + else: + return list(iterable) + + +class SWHJSONDecoder(JSONDecoder): + def decode_data(self, o): + if isinstance(o, dict): + datatype = o.get('swhtype') + if datatype == 'bytes': + return base64.b85decode(o['d']) + elif datatype == 'datetime': + return dateutil.parser.parse(o['d']) + return {key: self.decode_data(value) for key, value in o.items()} + if isinstance(o, list): + return [self.decode_data(value) for value in o] + else: + return o + + def raw_decode(self, s, idx=0): + data, index = super().raw_decode(s, idx) + return self.decode_data(data), index diff --git a/swh/core/tests/test_json.py b/swh/core/tests/test_json.py new file mode 100644 index 0000000..d42b30f --- /dev/null +++ b/swh/core/tests/test_json.py @@ -0,0 +1,46 @@ +# Copyright (C) 2015 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import datetime +import json +import unittest + +from nose.tools import istest + +from swh.core.json import SWHJSONDecoder, SWHJSONEncoder + + +class JSON(unittest.TestCase): + def setUp(self): + self.tz = datetime.timezone(datetime.timedelta(minutes=118)) + + self.data = { + "bytes": b"123456789\x99\xaf\xff\x00\x12", + "datetime_naive": datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), + "datetime_tz": datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, + tzinfo=self.tz), + "datetime_utc": datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, + tzinfo=datetime.timezone.utc), + } + + self.encoded_data = { + "bytes": {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"}, + "datetime_naive": {"swhtype": "datetime", + "d": "2015-01-01T12:04:42.231455"}, + "datetime_tz": {"swhtype": "datetime", + "d": "2015-03-04T18:25:13.001234+01:58"}, + "datetime_utc": {"swhtype": "datetime", + "d": "2015-03-04T18:25:13.001234+00:00"}, + } + + @istest + def round_trip(self): + data = json.dumps(self.data, cls=SWHJSONEncoder) + self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) + + @istest + def encode(self): + data = json.dumps(self.data, cls=SWHJSONEncoder) + self.assertEqual(self.encoded_data, json.loads(data))