diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py
--- a/swh/core/tests/test_utils.py
+++ b/swh/core/tests/test_utils.py
@@ -31,3 +31,37 @@
             out.append(list(d))  # force generator resolution for checks
 
         self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]])
+
+    @istest
+    def backslashescape_errors(self):
+        raw_data_err = b'abcd\x80'
+        with self.assertRaises(UnicodeDecodeError):
+            raw_data_err.decode('utf-8', 'strict')
+
+        self.assertEqual(
+            raw_data_err.decode('utf-8', 'backslashescape'),
+            'abcd\\x80',
+        )
+
+        raw_data_ok = b'abcd\xc3\xa9'
+        self.assertEqual(
+            raw_data_ok.decode('utf-8', 'backslashescape'),
+            raw_data_ok.decode('utf-8', 'strict'),
+        )
+
+        unicode_data = 'abcdef\u00a3'
+        self.assertEqual(
+            unicode_data.encode('ascii', 'backslashescape'),
+            b'abcdef\\xa3',
+        )
+
+    @istest
+    def decode_invalid(self):
+        # given
+        invalid_str = b'my invalid \xff \xff string'
+
+        # when
+        actual_data = utils.decode_with_escape(invalid_str)
+
+        # then
+        self.assertEqual(actual_data, 'my invalid \\xff \\xff string')
diff --git a/swh/core/utils.py b/swh/core/utils.py
--- a/swh/core/utils.py
+++ b/swh/core/utils.py
@@ -5,6 +5,7 @@
 
 
 import itertools
+import codecs
 
 
 def grouper(iterable, n):
@@ -22,3 +23,35 @@
     args = [iter(iterable)] * n
     for _data in itertools.zip_longest(*args, fillvalue=None):
         yield (d for d in _data if d is not None)
+
+
+def backslashescape_errors(exception):
+    """Codec error handler escaping undecodable bytes as ``\\xNN``.
+
+    Decoding errors are replaced by the two-digit lowercase hex escape of
+    each offending byte; any other error is delegated to
+    ``codecs.backslashreplace_errors``.
+    """
+    if isinstance(exception, UnicodeDecodeError):
+        bad_data = exception.object[exception.start:exception.end]
+        escaped = ''.join(r'\x%02x' % x for x in bad_data)
+        return escaped, exception.end
+
+    return codecs.backslashreplace_errors(exception)
+
+
+codecs.register_error('backslashescape', backslashescape_errors)
+
+
+def decode_with_escape(value):
+    """Decode a bytestring as utf-8, escaping what cannot be represented.
+
+    Bytes of invalid utf-8 sequences are escaped as ``\\xNN``. Backslashes
+    are doubled, and NUL bytes are escaped as ``\\x00`` because they are
+    invalid in JSON strings.
+    """
+    # escape backslashes first: doing it later would also double the
+    # backslash of the NUL escape added below
+    value = value.replace(b'\\', b'\\\\')
+    # escape NUL bytes
+    value = value.replace(b'\x00', b'\\x00')
+    return value.decode('utf-8', 'backslashescape')