diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py
--- a/swh/core/tests/test_utils.py
+++ b/swh/core/tests/test_utils.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2017 The Software Heritage developers
+# Copyright (C) 2015-2018 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -29,6 +29,32 @@
 
         self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]])
 
+    def test_grouper_with_fillvalue(self):
+        # given
+        actual_data = utils.grouper(((i, i+1) for i in range(0, 9)), 2,
+                                    fillvalue=(None, None))
+
+        out = []
+        for d in actual_data:
+            out.append(list(d))  # force generator resolution for checks
+
+        self.assertEqual(out, [
+            [(0, 1), (1, 2)],
+            [(2, 3), (3, 4)],
+            [(4, 5), (5, 6)],
+            [(6, 7), (7, 8)],
+            [(8, 9)]])
+
+        # given
+        actual_data = utils.grouper((i for i in range(9, 0, -1)), 4,
+                                    fillvalue='a')
+
+        out = []
+        for d in actual_data:
+            out.append(list(d))  # force generator resolution for checks
+
+        self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]])
+
     def test_backslashescape_errors(self):
         raw_data_err = b'abcd\x80'
         with self.assertRaises(UnicodeDecodeError):
diff --git a/swh/core/utils.py b/swh/core/utils.py
--- a/swh/core/utils.py
+++ b/swh/core/utils.py
@@ -25,21 +25,26 @@
         os.chdir(prev_cwd)
 
 
-def grouper(iterable, n):
+def grouper(iterable, n, fillvalue=None):
     """Collect data into fixed-length chunks or blocks.
 
     Args:
-        iterable: an iterable
-        n: size of block
-        fillvalue: value to use for the last block
+        iterable (Iterable): an iterable
+        n (int): size of block to slice the iterable into
+        fillvalue (Optional[Any]): value used to pad the last block
+          when the iterable holds fewer than n remaining elements.
+          None by default, but can be anything relevant for the
+          caller (e.g. a tuple (None, None)). Padding is stripped
+          from the yielded blocks by identity, so any input element
+          that is the very same object as fillvalue is dropped too.
 
     Returns:
         fixed-length chunks of blocks as iterables
 
     """
     args = [iter(iterable)] * n
-    for _data in itertools.zip_longest(*args, fillvalue=None):
-        yield (d for d in _data if d is not None)
+    for _data in itertools.zip_longest(*args, fillvalue=fillvalue):
+        yield (d for d in _data if d is not fillvalue)
 
 
 def backslashescape_errors(exception):