class UtilsLib(unittest.TestCase):
    # Unit tests for the helpers exposed by swh.core.utils.

    def test_grouper(self):
        # Chunks of 2 out of 9 ascending integers: the last chunk is short.
        # Each chunk is a lazy iterable, so it is materialized before
        # comparing.
        chunks = [
            list(chunk)
            for chunk in utils.grouper((n for n in range(9)), 2)
        ]
        self.assertEqual(chunks, [[0, 1], [2, 3], [4, 5], [6, 7], [8]])

        # Chunks of 4 out of 9 descending integers.
        chunks = [
            list(chunk)
            for chunk in utils.grouper((n for n in range(9, 0, -1)), 4)
        ]
        self.assertEqual(chunks, [[9, 8, 7, 6], [5, 4, 3, 2], [1]])

    def test_grouper_with_fillvalue(self):
        # Tuple elements with an explicit tuple fill value.
        couples = ((n, n + 1) for n in range(9))
        chunks = [
            list(chunk)
            for chunk in utils.grouper(couples, 2, fillvalue=(None, None))
        ]
        expected_chunks = [
            [(0, 1), (1, 2)],
            [(2, 3), (3, 4)],
            [(4, 5), (5, 6)],
            [(6, 7), (7, 8)],
            [(8, 9)],
        ]
        self.assertEqual(chunks, expected_chunks)

        # Integer elements with a string fill value.
        countdown = (n for n in range(9, 0, -1))
        chunks = [
            list(chunk)
            for chunk in utils.grouper(countdown, 4, fillvalue='a')
        ]
        self.assertEqual(chunks, [[9, 8, 7, 6], [5, 4, 3, 2], [1]])

    def test_backslashescape_errors(self):
        # Invalid utf-8: strict decoding fails, 'backslashescape' escapes.
        broken_bytes = b'abcd\x80'
        with self.assertRaises(UnicodeDecodeError):
            broken_bytes.decode('utf-8', 'strict')

        escaped_text = broken_bytes.decode('utf-8', 'backslashescape')
        self.assertEqual(escaped_text, 'abcd\\x80')

        # Valid utf-8: both error handlers give the same result.
        clean_bytes = b'abcd\xc3\xa9'
        lenient = clean_bytes.decode('utf-8', 'backslashescape')
        strict = clean_bytes.decode('utf-8', 'strict')
        self.assertEqual(lenient, strict)

        # Encoding falls back to the standard backslashreplace behaviour.
        pound_text = 'abcdef\u00a3'
        encoded = pound_text.encode('ascii', 'backslashescape')
        self.assertEqual(encoded, b'abcdef\\xa3')

    def test_encode_with_unescape(self):
        escaped_text = '\\x01020304\\x00'
        expected_bytes = b'\x01020304\x00'

        actual_bytes = utils.encode_with_unescape(escaped_text)

        self.assertEqual(expected_bytes, actual_bytes)

    def test_encode_with_unescape_invalid_escape(self):
        bad_text = 'test\\abcd'

        with self.assertRaises(ValueError) as exc:
            utils.encode_with_unescape(bad_text)

        message = exc.exception.args[0]
        self.assertIn('invalid escape', message)
        self.assertIn('position 4', message)

    def test_decode_with_escape(self):
        # (raw bytes, expected escaped text), checked in this order:
        # backslashes, valid utf-8, invalid utf-8, valid utf-8 with NUL.
        cases = [
            (b'foo\\bar\\\\baz', 'foo\\\\bar\\\\\\\\baz'),
            (b'foo\xc3\xa2', 'foo\u00e2'),
            (b'foo\xa2', 'foo\\xa2'),
            (b'foo\xc3\xa2\x00', 'foo\u00e2\\x00'),
        ]
        for raw, expected_text in cases:
            self.assertEqual(expected_text, utils.decode_with_escape(raw))

    def test_commonname(self):
        # str paths
        str_result = utils.commonname('/some/where/to/',
                                      '/some/where/to/go/to')
        self.assertEqual('go/to', str_result)

        # bytes paths
        bytes_result = utils.commonname(b'/some/where/to/',
                                        b'/some/where/to/go/to')
        self.assertEqual(b'go/to', bytes_result)
@contextmanager
def cwd(path):
    """Temporarily change the current working directory.

    The process is moved to ``path`` on entering the ``with`` block and
    moved back to the directory it started from on exit, even when the
    block raises.

    Args:
        path: directory to switch to (anything ``os.chdir`` accepts)

    """
    prev_cwd = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(prev_cwd)


def grouper(iterable, n, fillvalue=None):
    """Collect data into chunks of at most ``n`` elements.

    Every element of ``iterable`` ends up in exactly one chunk, in order.
    All chunks hold ``n`` elements except possibly the last one, which
    holds whatever is left.

    Args:
        iterable (Iterable): the data to slice
        n (int): maximum size of each chunk
        fillvalue: accepted for backward compatibility only and ignored.
            Padding of the last chunk is done with a private sentinel, so
            no value coming from ``iterable`` (``None`` included) can be
            mistaken for padding and dropped.

    Yields:
        one iterable per chunk

    """
    # A fresh object() cannot be identical to any element of the input,
    # which makes the identity test below safe for every possible value.
    stop_value = object()
    # n references to the *same* iterator: zip_longest pulls n consecutive
    # elements for each tuple it builds.
    args = [iter(iterable)] * n
    for _data in itertools.zip_longest(*args, fillvalue=stop_value):
        yield (d for d in _data if d is not stop_value)


def backslashescape_errors(exception):
    """Codec error handler registered under the name 'backslashescape'.

    On decoding errors, every offending byte is replaced with a textual
    ``\\xNN`` escape. Any other error is delegated to the standard
    ``backslashreplace`` handler.

    Args:
        exception (UnicodeError): the error raised by the codec

    Returns:
        a (replacement, position to resume at) tuple, as expected by the
        codecs machinery

    """
    if isinstance(exception, UnicodeDecodeError):
        bad_data = exception.object[exception.start:exception.end]
        escaped = ''.join(r'\x%02x' % x for x in bad_data)
        return escaped, exception.end

    return codecs.backslashreplace_errors(exception)


codecs.register_error('backslashescape', backslashescape_errors)


# Exactly two hexadecimal digits, as required after a "\x" escape.
_HEX_PAIR = re.compile(r'[0-9a-fA-F]{2}')


def encode_with_unescape(value):
    """Encode an unicode string containing \\x backslash escapes.

    ``\\xNN`` escapes are turned into the single byte NN, doubled
    backslashes into one backslash, and everything else is encoded as
    utf-8.

    Args:
        value (str): the escaped text

    Returns:
        bytes: the encoded value

    Raises:
        ValueError: if a backslash starts anything else than a complete
            ``\\xNN`` escape (or another backslash)

    """
    slices = []
    start = 0
    odd_backslashes = False
    i = 0
    while i < len(value):
        if value[i] == '\\':
            odd_backslashes = not odd_backslashes
        else:
            if odd_backslashes:
                # An unpaired backslash must introduce "\xNN". Checking the
                # two digits here also rejects an escape truncated by the
                # end of the string, which bytes.fromhex('') would accept.
                hex_pair = value[i+1:i+3]
                if value[i] != 'x' or not _HEX_PAIR.fullmatch(hex_pair):
                    raise ValueError('invalid escape for %r at position %d' %
                                     (value, i-1))
                # Text before the escape, minus the escaping backslash.
                slices.append(
                    value[start:i-1].replace('\\\\', '\\').encode('utf-8')
                )
                slices.append(bytes.fromhex(hex_pair))

                odd_backslashes = False
                start = i = i + 3
                continue

        i += 1

    # Remaining text after the last escape. A trailing unpaired backslash
    # is kept as-is.
    slices.append(
        value[start:i].replace('\\\\', '\\').encode('utf-8')
    )

    return b''.join(slices)


def decode_with_escape(value):
    """Decode a bytestring as utf-8, escaping the bytes of invalid utf-8
    sequences as \\x. We also escape NUL bytes as they are invalid
    in JSON strings.

    Args:
        value (bytes): the raw data

    Returns:
        str: the decoded text, with backslashes doubled

    """
    # escape backslashes first, so the escapes added below stay unambiguous
    value = value.replace(b'\\', b'\\\\')
    value = value.replace(b'\x00', b'\\x00')
    return value.decode('utf-8', 'backslashescape')


def commonname(path0, path1, as_str=False):
    """Compute the commonname between the path0 and path1.

    This is the part of ``path1`` which follows the first occurrence of
    ``path0``.

    Args:
        path0 (str or bytes): the prefix to strip
        path1 (str or bytes): the full path, same type as ``path0``
        as_str (bool): accepted but currently unused

    Returns:
        what remains of ``path1`` after ``path0``, same type as the inputs

    Raises:
        IndexError: if ``path0`` does not occur in ``path1``

    """
    # Split once only: a later occurrence of path0 inside path1 belongs to
    # the result and must not cut it short.
    return path1.split(path0, 1)[1]


def numfile_sortkey(fname):
    """Simple function to sort filenames of the form:

      nnxxx.ext

    where nn is a number according to the numbers.

    Typically used to sort sql/nn-swh-xxx.sql files.

    Args:
        fname (str): the file name

    Returns:
        a (number, rest of the name) tuple usable as sort key; names
        without a leading number get 99 as number

    """
    num, rem = re.match(r'(\d*)(.*)', fname).groups()
    # Only a *missing* prefix falls back to 99; a real prefix of 0 is kept.
    return (int(num) if num else 99, rem)