diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py
--- a/swh/core/tests/test_utils.py
+++ b/swh/core/tests/test_utils.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2018  The Software Heritage developers
+# Copyright (C) 2015-2022  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -54,6 +54,46 @@
     assert out == [[9, 8, 7, 6], [5, 4, 3, 2], [1]]
 
 
+def test_iter_chunks():
+    def chunks(input_, remainder):
+        return list(utils.iter_chunks(input_, 3, remainder=remainder))
+
+    # all even, remainder=False
+    assert chunks(["ab", "cd", "ef"], False) == ["abc", "def"]
+    assert chunks(["abc", "def"], False) == ["abc", "def"]
+    assert chunks(["abcd", "ef"], False) == ["abc", "def"]
+
+    # all even, remainder=True
+    assert chunks(["ab", "cd", "ef"], True) == ["abc", "def"]
+    assert chunks(["abc", "def"], True) == ["abc", "def"]
+    assert chunks(["abcd", "ef"], True) == ["abc", "def"]
+
+    # uneven, remainder=False
+    assert chunks([], False) == []
+    assert chunks(["ab"], False) == []
+    assert chunks(["ab", "cd", "ef", "g"], False) == ["abc", "def"]
+    assert chunks(["ab", "cd", "efg"], False) == ["abc", "def"]
+    assert chunks(["abc", "def", "g"], False) == ["abc", "def"]
+    assert chunks(["abcd", "ef", "g"], False) == ["abc", "def"]
+
+    # uneven, remainder=True
+    assert chunks([], True) == []
+    assert chunks(["ab"], True) == ["ab"]
+    assert chunks(["ab", "cd", "ef", "g"], True) == ["abc", "def", "g"]
+    assert chunks(["ab", "cd", "efg"], True) == ["abc", "def", "g"]
+    assert chunks(["abc", "def", "g"], True) == ["abc", "def", "g"]
+    assert chunks(["abcd", "ef", "g"], True) == ["abc", "def", "g"]
+
+    # bytes input
+    assert chunks([b"ab", b"cd", b"efg"], True) == [b"abc", b"def", b"g"]
+
+    # chunk_size must be positive
+    with pytest.raises(ValueError):
+        list(utils.iter_chunks(["abc"], 0))
+    with pytest.raises(ValueError):
+        list(utils.iter_chunks(["abc"], -1))
+
+
 def test_backslashescape_errors():
     raw_data_err = b"abcd\x80"
     with pytest.raises(UnicodeDecodeError):
diff --git a/swh/core/utils.py b/swh/core/utils.py
--- a/swh/core/utils.py
+++ b/swh/core/utils.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2016-2017  The Software Heritage developers
+# Copyright (C) 2016-2022  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -8,7 +8,7 @@
 import itertools
 import os
 import re
-from typing import Tuple
+from typing import Iterable, Tuple, TypeVar
 
 
 @contextmanager
@@ -26,13 +26,18 @@
 
 
 def grouper(iterable, n):
-    """Collect data into fixed-length size iterables. The last block might
-       contain less elements as it will hold only the remaining number
-       of elements.
+    """
+    Collect data into fixed-length size iterables. The last block might
+    contain fewer elements as it will hold only the remaining number
+    of elements.
+
+    The invariant here is that the number of elements in the input
+    iterable and the sum of the number of elements of all iterables
+    generated from this function should be equal.
 
-       The invariant here is that the number of elements in the input
-       iterable and the sum of the number of elements of all iterables
-       generated from this function should be equal.
+    If ``iterable`` is an iterable of bytes or strings that you need to join
+    later, then :func:`iter_chunks` is preferable, as it avoids this join
+    by slicing directly.
 
     Args:
         iterable (Iterable): an iterable
@@ -49,6 +54,64 @@
         yield (d for d in _data if d is not stop_value)
 
 
+TStr = TypeVar("TStr", bytes, str)
+
+
+def iter_chunks(
+    iterable: Iterable[TStr], chunk_size: int, *, remainder: bool = False
+) -> Iterable[TStr]:
+    """
+    Reads ``bytes`` objects (resp. ``str`` objects) from the ``iterable``,
+    and yields them as chunks of exactly ``chunk_size`` bytes (resp. characters).
+
+    ``iterable`` is typically obtained by repeatedly calling a method like
+    :meth:`io.RawIOBase.read`, which only guarantees an upper bound on the size,
+    whereas this function returns chunks of exactly ``chunk_size``.
+
+    Args:
+        iterable: the input data
+        chunk_size: the exact size of chunks to return; must be positive
+        remainder: if True, a last chunk with size strictly smaller than
+            ``chunk_size`` may be returned, if the data stream from the
+            ``iterable`` had a length that is not a multiple of ``chunk_size``
+
+    Raises:
+        ValueError: if ``chunk_size`` is not positive (as this is a generator,
+            the error is raised when iteration starts)
+    """
+    if chunk_size <= 0:
+        raise ValueError(f"chunk_size must be positive, got {chunk_size!r}")
+
+    buf = None
+    iterator = iter(iterable)
+    while True:
+        # invariant: any leftover carried over is shorter than a full chunk
+        assert buf is None or len(buf) < chunk_size
+        try:
+            new_data = next(iterator)
+        except StopIteration:
+            if remainder and buf:
+                yield buf  # may be shorter than ``chunk_size``
+            return
+
+        if buf:
+            buf += new_data
+        else:
+            # no leftover to prepend: reuse new_data as-is, sparing a copy
+            buf = new_data
+
+        new_buf = None
+        for i in range(0, len(buf), chunk_size):
+            chunk = buf[i : i + chunk_size]
+            if len(chunk) == chunk_size:
+                yield chunk
+            else:
+                # only the final slice can be short; keep it for the next round
+                assert not new_buf
+                new_buf = chunk
+        buf = new_buf
+
+
 def backslashescape_errors(exception):
     if isinstance(exception, UnicodeDecodeError):
         bad_data = exception.object[exception.start : exception.end]