diff --git a/swh/core/json.py b/swh/core/json.py index 59902e5..5b181bf 100644 --- a/swh/core/json.py +++ b/swh/core/json.py @@ -1,52 +1,53 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime from json import JSONDecoder, JSONEncoder import dateutil.parser class SWHJSONEncoder(JSONEncoder): def default(self, o): if isinstance(o, bytes): return { 'swhtype': 'bytes', 'd': base64.b85encode(o).decode('ascii'), } elif isinstance(o, datetime.datetime): return { 'swhtype': 'datetime', 'd': o.isoformat(), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(JSONDecoder): def decode_data(self, o): if isinstance(o, dict): - datatype = o.get('swhtype') - if datatype == 'bytes': - return base64.b85decode(o['d']) - elif datatype == 'datetime': - return dateutil.parser.parse(o['d']) + if set(o.keys()) == {'d', 'swhtype'}: + datatype = o['swhtype'] + if datatype == 'bytes': + return base64.b85decode(o['d']) + elif datatype == 'datetime': + return dateutil.parser.parse(o['d']) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s, idx=0): data, index = super().raw_decode(s, idx) return self.decode_data(data), index diff --git a/swh/core/tests/test_json.py b/swh/core/tests/test_json.py index d42b30f..66d56da 100644 --- a/swh/core/tests/test_json.py +++ b/swh/core/tests/test_json.py @@ -1,46 +1,52 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json import unittest from nose.tools import istest from swh.core.json import SWHJSONDecoder, SWHJSONEncoder class JSON(unittest.TestCase): def setUp(self): self.tz = datetime.timezone(datetime.timedelta(minutes=118)) self.data = { "bytes": b"123456789\x99\xaf\xff\x00\x12", "datetime_naive": datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), "datetime_tz": datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz), "datetime_utc": datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc), + "swhtype": "fake", + "swh_dict": {"swhtype": 42, "d": "test"}, + "random_dict": {"swhtype": 43}, } self.encoded_data = { "bytes": {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"}, "datetime_naive": {"swhtype": "datetime", "d": "2015-01-01T12:04:42.231455"}, "datetime_tz": {"swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+01:58"}, "datetime_utc": {"swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+00:00"}, + "swhtype": "fake", + "swh_dict": {"swhtype": 42, "d": "test"}, + "random_dict": {"swhtype": 43}, } @istest def round_trip(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) @istest def encode(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data))