# Review notes for this patch. These lines sit ahead of the first
# "diff --git" header and belong to no hunk.
#
# swh/core/tests/test_utils.py
#   * Bumps the copyright range from 2015-2017 to 2015-2018.
#   * Adds test_grouper_with_stop_value. It groups a generator of
#     (i, i+1) tuples by 2 and the descending range 9..1 by 4, and checks
#     that the last group holds only the leftover elements.
#   * NOTE(review): neither input contains None, so this test presumably
#     also passes against the old "d is not None" filter. Consider adding
#     an input that contains None elements -- TODO confirm.
#
# swh/core/utils.py
#   * grouper() used to pad the last zip_longest() tuple with None and then
#     filter out every None, which also removed None values that came from
#     the input iterable itself.
#   * grouper() now pads with a fresh object() bound to stop_value and
#     filters by identity against that object only.
#   * The docstring is rewritten: it describes the possibly shorter last
#     block and the element-count invariant, adds types to Args, replaces
#     "Returns:" with "Yields:", and drops the "fillvalue" entry, which is
#     not a parameter of grouper(iterable, n).
diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py --- a/swh/core/tests/test_utils.py +++ b/swh/core/tests/test_utils.py @@ -1,4 +1,4 @@ -# Copyright (C) 2015-2017 The Software Heritage developers +# Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -29,6 +29,30 @@ self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]]) + def test_grouper_with_stop_value(self): + # given + actual_data = utils.grouper(((i, i+1) for i in range(0, 9)), 2) + + out = [] + for d in actual_data: + out.append(list(d)) # force generator resolution for checks + + self.assertEqual(out, [ + [(0, 1), (1, 2)], + [(2, 3), (3, 4)], + [(4, 5), (5, 6)], + [(6, 7), (7, 8)], + [(8, 9)]]) + + # given + actual_data = utils.grouper((i for i in range(9, 0, -1)), 4) + + out = [] + for d in actual_data: + out.append(list(d)) # force generator resolution for checks + + self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]]) + def test_backslashescape_errors(self): raw_data_err = b'abcd\x80' with self.assertRaises(UnicodeDecodeError): diff --git a/swh/core/utils.py b/swh/core/utils.py --- a/swh/core/utils.py +++ b/swh/core/utils.py @@ -26,20 +26,27 @@ def grouper(iterable, n): - """Collect data into fixed-length chunks or blocks. + """Collect data into fixed-length size iterables. The last block might + contain less elements as it will hold only the remaining number + of elements. + + The invariant here is that the number of elements in the input + iterable and the sum of the number of elements of all iterables + generated from this function should be equal. 
Args: - iterable: an iterable - n: size of block - fillvalue: value to use for the last block + iterable (Iterable): an iterable + n (int): size of block to slice the iterable into - Returns: - fixed-length chunks of blocks as iterables + Yields: + fixed-length blocks as iterables. As mentioned, the last + iterable might be less populated. """ args = [iter(iterable)] * n - for _data in itertools.zip_longest(*args, fillvalue=None): - yield (d for d in _data if d is not None) + stop_value = object() + for _data in itertools.zip_longest(*args, fillvalue=stop_value): + yield (d for d in _data if d is not stop_value) def backslashescape_errors(exception):