# File: swh/core/tests/test_utils.py
# Copyright (C) 2015-2018  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import unittest

from swh.core import utils


class UtilsLib(unittest.TestCase):
    # Unit tests for the helpers exposed by swh.core.utils.

    def test_grouper(self):
        # ascending input cut into blocks of 2; the blocks are lazy, so
        # each one is materialized as a list before comparing
        groups = utils.grouper((i for i in range(0, 9)), 2)
        resolved = [list(group) for group in groups]
        self.assertEqual(resolved, [[0, 1], [2, 3], [4, 5], [6, 7], [8]])

        # descending input cut into blocks of 4
        groups = utils.grouper((i for i in range(9, 0, -1)), 4)
        resolved = [list(group) for group in groups]
        self.assertEqual(resolved, [[9, 8, 7, 6], [5, 4, 3, 2], [1]])

    def test_grouper_with_stop_value(self):
        # pairs as elements, with a tuple passed as stop value
        groups = utils.grouper(((i, i+1) for i in range(0, 9)), 2,
                               stop_value=(None, None))
        resolved = [list(group) for group in groups]
        expected_pairs = [
            [(0, 1), (1, 2)],
            [(2, 3), (3, 4)],
            [(4, 5), (5, 6)],
            [(6, 7), (7, 8)],
            [(8, 9)]]
        self.assertEqual(resolved, expected_pairs)

        # integers as elements, with a string passed as stop value
        groups = utils.grouper((i for i in range(9, 0, -1)), 4,
                               stop_value='a')
        resolved = [list(group) for group in groups]
        self.assertEqual(resolved, [[9, 8, 7, 6], [5, 4, 3, 2], [1]])

    def test_backslashescape_errors(self):
        # a lone 0x80 byte is not valid utf-8: strict decoding must fail...
        undecodable = b'abcd\x80'
        with self.assertRaises(UnicodeDecodeError):
            undecodable.decode('utf-8', 'strict')

        # ...while the custom handler escapes the offending byte
        escaped = undecodable.decode('utf-8', 'backslashescape')
        self.assertEqual(escaped, 'abcd\\x80')

        # valid utf-8 decodes the same way with both handlers
        decodable = b'abcd\xc3\xa9'
        lenient = decodable.decode('utf-8', 'backslashescape')
        strict = decodable.decode('utf-8', 'strict')
        self.assertEqual(lenient, strict)

        # encoding direction: non-ascii characters get escaped
        text = 'abcdef\u00a3'
        self.assertEqual(
            text.encode('ascii', 'backslashescape'),
            b'abcdef\\xa3',
        )

    def test_encode_with_unescape(self):
        escaped_text = '\\x01020304\\x00'
        expected_bytes = b'\x01020304\x00'

        actual_bytes = utils.encode_with_unescape(escaped_text)

        self.assertEqual(expected_bytes, actual_bytes)

    def test_encode_with_unescape_invalid_escape(self):
        bad_escape = 'test\\abcd'

        with self.assertRaises(ValueError) as exc:
            utils.encode_with_unescape(bad_escape)

        message = exc.exception.args[0]
        self.assertIn('invalid escape', message)
        self.assertIn('position 4', message)

    def test_decode_with_escape(self):
        # (raw bytes, expected decoded text), checked in order
        cases = [
            # backslashes get doubled
            (b'foo\\bar\\\\baz', 'foo\\\\bar\\\\\\\\baz'),
            # valid utf-8 is decoded as is
            (b'foo\xc3\xa2', 'foo\u00e2'),
            # invalid utf-8 bytes are escaped
            (b'foo\xa2', 'foo\\xa2'),
            # NUL bytes are escaped too
            (b'foo\xc3\xa2\x00', 'foo\u00e2\\x00'),
        ]

        for raw, expected in cases:
            self.assertEqual(
                expected,
                utils.decode_with_escape(raw),
            )

    def test_commonname(self):
        # str paths
        str_result = utils.commonname('/some/where/to/',
                                      '/some/where/to/go/to')
        self.assertEqual('go/to', str_result)

        # bytes paths
        bytes_result = utils.commonname(b'/some/where/to/',
                                        b'/some/where/to/go/to')
        self.assertEqual(b'go/to', bytes_result)


# File: swh/core/utils.py
# Copyright (C) 2016-2017  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import os
import itertools
import codecs
import re
from contextlib import contextmanager


@contextmanager
def cwd(path):
    """Contextually change the working directory to do thy bidding.
    Then gets back to the original location.

    Args:
        path: directory to switch to while the ``with`` body runs

    """
    prev_cwd = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        # restore the previous directory even when the body raised
        os.chdir(prev_cwd)


def grouper(iterable, n, stop_value=None):
    """Collect data into fixed-length blocks. The last block might
       contain fewer elements as it only holds the remaining ones.

    The invariant here is that the number of elements in the input
    iterable and the sum of the number of elements of all iterables
    generated from this function should be equal.

    Args:
        iterable (Iterable): an iterable
        n (int): size of block to slice the iterable into
        stop_value (Optional[Something]): accepted for backward
          compatibility and ignored. Padding of the last block is done
          with a private marker, so no input element (``None``
          included) is ever mistaken for padding and dropped.

    Yields:
        fixed-length blocks as iterables. As mentioned, the last
        iterable might be less populated.

    """
    # Fresh object per call: it cannot be identical to any input element,
    # which is what guarantees the invariant stated above.
    padding = object()
    # n references to the *same* iterator, so that zip_longest pulls n
    # consecutive elements for each block
    args = [iter(iterable)] * n
    for _data in itertools.zip_longest(*args, fillvalue=padding):
        yield (d for d in _data if d is not padding)


def backslashescape_errors(exception):
    """Codec error handler registered below as ``backslashescape``.

    On a decoding error, the offending bytes are replaced by their
    ``\\xNN`` escapes and processing resumes right after them. Any other
    error is delegated to :func:`codecs.backslashreplace_errors`.

    Args:
        exception (UnicodeError): the error raised by the codec

    Returns:
        tuple (replacement, position at which to resume)

    """
    if isinstance(exception, UnicodeDecodeError):
        bad_data = exception.object[exception.start:exception.end]
        escaped = ''.join(r'\x%02x' % x for x in bad_data)
        return escaped, exception.end

    return codecs.backslashreplace_errors(exception)


codecs.register_error('backslashescape', backslashescape_errors)


def encode_with_unescape(value):
    """Encode an unicode string containing \\x backslash escapes

    Doubled backslashes are turned back into single ones, ``\\xNN``
    escapes into the matching byte, and everything else is encoded as
    utf-8.

    Args:
        value (str): the escaped string

    Returns:
        bytes: the encoded, unescaped value

    Raises:
        ValueError: when a lone backslash is followed by anything but
          ``x``

    """
    slices = []
    start = 0
    # True while the run of backslashes just read has an odd length,
    # i.e. the last backslash starts an escape sequence
    odd_backslashes = False
    i = 0
    while i < len(value):
        if value[i] == '\\':
            odd_backslashes = not odd_backslashes
        else:
            if odd_backslashes:
                if value[i] != 'x':
                    raise ValueError('invalid escape for %r at position %d' %
                                     (value, i-1))
                # flush the plain text preceding the escape (without its
                # leading backslash), then the escaped byte itself
                slices.append(
                    value[start:i-1].replace('\\\\', '\\').encode('utf-8')
                )
                slices.append(bytes.fromhex(value[i+1:i+3]))

                odd_backslashes = False
                # skip over 'x' and the two hex digits
                start = i = i + 3
                continue

        i += 1

    # flush whatever follows the last escape
    slices.append(
        value[start:i].replace('\\\\', '\\').encode('utf-8')
    )

    return b''.join(slices)


def decode_with_escape(value):
    """Decode a bytestring as utf-8, escaping the bytes of invalid utf-8
    sequences as \\x. We also escape NUL bytes as they are invalid in JSON
    strings.

    Args:
        value (bytes): the bytestring to decode

    Returns:
        str: the decoded, escaped value

    """
    # escape backslashes first, so that the escapes added afterwards stay
    # distinguishable from backslashes already present in the input
    value = value.replace(b'\\', b'\\\\')
    value = value.replace(b'\x00', b'\\x00')
    return value.decode('utf-8', 'backslashescape')


def commonname(path0, path1, as_str=False):
    """Compute the commonname between the path0 and path1.

    Args:
        path0 (str or bytes): prefix to strip
        path1 (str or bytes): path containing path0, same type as path0
        as_str (bool): unused, kept for backward compatibility

    Returns:
        what follows the first occurrence of path0 in path1

    Raises:
        IndexError: when path0 does not occur in path1

    """
    # maxsplit=1: later occurrences of path0 inside path1 belong to the
    # result and must not cut it short
    return path1.split(path0, 1)[1]


def numfile_sortkey(fname):
    """Simple function to sort filenames of the form:

      nnxxx.ext

    where nn is a number according to the numbers.

    Typically used to sort sql/nn-swh-xxx.sql files.

    Args:
        fname (str): the filename

    Returns:
        tuple (leading number, rest of the name); 99 is used as the
        number when the name does not start with a digit

    """
    num, rem = re.match(r'(\d*)(.*)', fname).groups()
    # explicit conditional: `num and int(num) or 99` would turn a leading
    # 0 into 99
    return (int(num) if num else 99, rem)