diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py
--- a/swh/core/tests/test_utils.py
+++ b/swh/core/tests/test_utils.py
@@ -31,3 +31,93 @@
             out.append(list(d))  # force generator resolution for checks
 
         self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]])
+
+    @istest
+    def backslashescape_errors(self):
+        # undecodable bytes are escaped instead of raising
+        raw_data_err = b'abcd\x80'
+        with self.assertRaises(UnicodeDecodeError):
+            raw_data_err.decode('utf-8', 'strict')
+
+        self.assertEqual(
+            raw_data_err.decode('utf-8', 'backslashescape'),
+            'abcd\\x80',
+        )
+
+        # valid utf-8 is decoded as usual
+        raw_data_ok = b'abcd\xc3\xa9'
+        self.assertEqual(
+            raw_data_ok.decode('utf-8', 'backslashescape'),
+            raw_data_ok.decode('utf-8', 'strict'),
+        )
+
+        # encoding errors fall back to the builtin backslashreplace
+        unicode_data = 'abcdef\u00a3'
+        self.assertEqual(
+            unicode_data.encode('ascii', 'backslashescape'),
+            b'abcdef\\xa3',
+        )
+
+    @istest
+    def encode_with_unescape(self):
+        valid_data = '\\x01020304\\x00'
+        valid_data_encoded = b'\x01020304\x00'
+
+        self.assertEqual(
+            valid_data_encoded,
+            utils.encode_with_unescape(valid_data)
+        )
+
+    @istest
+    def encode_with_unescape_invalid_escape(self):
+        invalid_data = 'test\\abcd'
+
+        with self.assertRaises(ValueError) as exc:
+            utils.encode_with_unescape(invalid_data)
+
+        self.assertIn('invalid escape', exc.exception.args[0])
+        self.assertIn('position 4', exc.exception.args[0])
+
+    @istest
+    def encode_with_unescape_truncated_escape(self):
+        # escapes cut short and lone trailing backslashes are rejected too
+        for invalid_data in ('test\\x', 'test\\x4', 'test\\'):
+            with self.assertRaises(ValueError) as exc:
+                utils.encode_with_unescape(invalid_data)
+
+            self.assertIn('invalid escape', exc.exception.args[0])
+            self.assertIn('position 4', exc.exception.args[0])
+
+    @istest
+    def decode_with_escape(self):
+        backslashes = b'foo\\bar\\\\baz'
+        backslashes_escaped = 'foo\\\\bar\\\\\\\\baz'
+
+        self.assertEqual(
+            backslashes_escaped,
+            utils.decode_with_escape(backslashes),
+        )
+
+        valid_utf8 = b'foo\xc3\xa2'
+        valid_utf8_escaped = 'foo\u00e2'
+
+        self.assertEqual(
+            valid_utf8_escaped,
+            utils.decode_with_escape(valid_utf8),
+        )
+
+        invalid_utf8 = b'foo\xa2'
+        invalid_utf8_escaped = 'foo\\xa2'
+
+        self.assertEqual(
+            invalid_utf8_escaped,
+            utils.decode_with_escape(invalid_utf8),
+        )
+
+        valid_utf8_nul = b'foo\xc3\xa2\x00'
+        valid_utf8_nul_escaped = 'foo\u00e2\\x00'
+
+        self.assertEqual(
+            valid_utf8_nul_escaped,
+            utils.decode_with_escape(valid_utf8_nul),
+        )
diff --git a/swh/core/utils.py b/swh/core/utils.py
--- a/swh/core/utils.py
+++ b/swh/core/utils.py
@@ -5,6 +5,7 @@
 
 
 import itertools
+import codecs
 
 
 def grouper(iterable, n):
@@ -22,3 +23,84 @@
     args = [iter(iterable)] * n
     for _data in itertools.zip_longest(*args, fillvalue=None):
         yield (d for d in _data if d is not None)
+
+
+def backslashescape_errors(exception):
+    """Codec error handler escaping undecodable bytes as \\xHH.
+
+    For a UnicodeDecodeError, each offending byte is replaced by its \\xHH
+    escape (HH being two hex digits) and decoding resumes right after
+    them. Any other error is delegated to codecs.backslashreplace_errors.
+    """
+    if isinstance(exception, UnicodeDecodeError):
+        bad_data = exception.object[exception.start:exception.end]
+        escaped = ''.join(r'\x%02x' % x for x in bad_data)
+        return escaped, exception.end
+
+    return codecs.backslashreplace_errors(exception)
+
+
+codecs.register_error('backslashescape', backslashescape_errors)
+
+
+def encode_with_unescape(value):
+    """Encode an unicode string containing \\xHH backslash escapes.
+
+    Doubled backslashes are turned back into a single backslash and \\xHH
+    escapes into the byte they stand for; the rest is encoded as utf-8.
+
+    Raises:
+        ValueError: if a backslash is neither doubled nor followed by "x"
+            and exactly two hexadecimal digits.
+    """
+    slices = []
+    start = 0
+    odd_backslashes = False
+    i = 0
+    while i < len(value):
+        if value[i] == '\\':
+            odd_backslashes = not odd_backslashes
+        elif odd_backslashes:
+            # bytes.fromhex() silently accepts empty or blank input, so the
+            # two hex digits are validated explicitly
+            hex_digits = value[i+1:i+3]
+            well_formed = (
+                value[i] == 'x' and len(hex_digits) == 2 and
+                all(c in '0123456789abcdefABCDEF' for c in hex_digits)
+            )
+            if not well_formed:
+                raise ValueError('invalid escape for %r at position %d' %
+                                 (value, i-1))
+            slices.append(
+                value[start:i-1].replace('\\\\', '\\').encode('utf-8')
+            )
+            slices.append(bytes.fromhex(hex_digits))
+
+            odd_backslashes = False
+            start = i = i + 3
+            continue
+
+        i += 1
+
+    if odd_backslashes:
+        # a lone backslash ending the string does not start any escape
+        raise ValueError('invalid escape for %r at position %d' %
+                         (value, len(value) - 1))
+
+    slices.append(
+        value[start:].replace('\\\\', '\\').encode('utf-8')
+    )
+
+    return b''.join(slices)
+
+
+def decode_with_escape(value):
+    """Decode a bytestring as utf-8, escaping the bytes of invalid utf-8
+    sequences as \\xHH. We also escape NUL bytes as they are invalid in
+    JSON strings. Backslashes are doubled first so that they cannot be
+    mistaken for the start of an escape.
+    """
+    # escape backslashes
+    value = value.replace(b'\\', b'\\\\')
+    value = value.replace(b'\x00', b'\\x00')
+    return value.decode('utf-8', 'backslashescape')