diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py index f16b4c0..931647e 100644 --- a/swh/core/tests/test_utils.py +++ b/swh/core/tests/test_utils.py @@ -1,67 +1,110 @@ # Copyright (C) 2015-2016 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import unittest from nose.tools import istest from swh.core import utils class UtilsLib(unittest.TestCase): @istest def grouper(self): # given actual_data = utils.grouper((i for i in range(0, 9)), 2) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks self.assertEqual(out, [[0, 1], [2, 3], [4, 5], [6, 7], [8]]) # given actual_data = utils.grouper((i for i in range(9, 0, -1)), 4) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]]) @istest def backslashescape_errors(self): raw_data_err = b'abcd\x80' with self.assertRaises(UnicodeDecodeError): raw_data_err.decode('utf-8', 'strict') self.assertEquals( raw_data_err.decode('utf-8', 'backslashescape'), 'abcd\\x80', ) raw_data_ok = b'abcd\xc3\xa9' self.assertEquals( raw_data_ok.decode('utf-8', 'backslashescape'), raw_data_ok.decode('utf-8', 'strict'), ) unicode_data = 'abcdef\u00a3' self.assertEquals( unicode_data.encode('ascii', 'backslashescape'), b'abcdef\\xa3', ) @istest - def decode_invalid(self): - # given - invalid_str = b'my invalid \xff \xff string' + def encode_with_unescape(self): + valid_data = '\\x01020304\\x00' + valid_data_encoded = b'\x01020304\x00' + + self.assertEquals( + valid_data_encoded, + utils.encode_with_unescape(valid_data) + ) + + @istest + def encode_with_unescape_invalid_escape(self): + invalid_data = 'test\\abcd' - # when - actual_data = utils.decode_with_escape(invalid_str) + with self.assertRaises(ValueError) as exc: + utils.encode_with_unescape(invalid_data) + + self.assertIn('invalid escape', exc.exception.args[0]) + self.assertIn('position 4', exc.exception.args[0]) + + @istest + def decode_with_escape(self): + backslashes = b'foo\\bar\\\\baz' + backslashes_escaped = 'foo\\\\bar\\\\\\\\baz' + + self.assertEquals( + backslashes_escaped, + utils.decode_with_escape(backslashes), + ) - # then - self.assertEqual(actual_data, 'my invalid \\xff \\xff string') + valid_utf8 = b'foo\xc3\xa2' + valid_utf8_escaped = 'foo\u00e2' + + self.assertEquals( + valid_utf8_escaped, + utils.decode_with_escape(valid_utf8), + ) + + invalid_utf8 = b'foo\xa2' + invalid_utf8_escaped = 'foo\\xa2' + + self.assertEquals( + invalid_utf8_escaped, + utils.decode_with_escape(invalid_utf8), + ) + + valid_utf8_nul = b'foo\xc3\xa2\x00' + valid_utf8_nul_escaped = 'foo\u00e2\\x00' + + self.assertEquals( + valid_utf8_nul_escaped, + utils.decode_with_escape(valid_utf8_nul), + ) diff --git a/swh/core/utils.py b/swh/core/utils.py index ef28fc2..e52b8ef 100644 --- a/swh/core/utils.py +++ b/swh/core/utils.py @@ -1,47 +1,79 @@ # Copyright (C) 2016 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import itertools import codecs def grouper(iterable, n): """Collect data into fixed-length chunks or blocks. Args: iterable: an iterable n: size of block fillvalue: value to use for the last block Returns: fixed-length chunks of blocks as iterables """ args = [iter(iterable)] * n for _data in itertools.zip_longest(*args, fillvalue=None): yield (d for d in _data if d is not None) def backslashescape_errors(exception): if isinstance(exception, UnicodeDecodeError): bad_data = exception.object[exception.start:exception.end] escaped = ''.join(r'\x%02x' % x for x in bad_data) return escaped, exception.end return codecs.backslashreplace_errors(exception) codecs.register_error('backslashescape', backslashescape_errors) +def encode_with_unescape(value): + """Encode an unicode string containing \\x backslash escapes""" + slices = [] + start = 0 + odd_backslashes = False + i = 0 + while i < len(value): + if value[i] == '\\': + odd_backslashes = not odd_backslashes + else: + if odd_backslashes: + if value[i] != 'x': + raise ValueError('invalid escape for %r at position %d' % + (value, i-1)) + slices.append( + value[start:i-1].replace('\\\\', '\\').encode('utf-8') + ) + slices.append(bytes.fromhex(value[i+1:i+3])) + + odd_backslashes = False + start = i = i + 3 + continue + + i += 1 + + slices.append( + value[start:i].replace('\\\\', '\\').encode('utf-8') + ) + + return b''.join(slices) + + def decode_with_escape(value): """Decode a bytestring as utf-8, escaping the bytes of invalid utf-8 sequences as \\x. We also escape NUL bytes as they are invalid in JSON strings. """ # escape backslashes value = value.replace(b'\\', b'\\\\') value = value.replace(b'\x00', b'\\x00') return value.decode('utf-8', 'backslashescape')