def test_iter_chunks():
    def make_cb(iterable):
        """Build a zero-argument callable that hands out the items of
        ``iterable`` one call at a time.

        This behaves like ``unittest.mock.MagicMock(side_effect=iterable)``,
        except that once the items are exhausted every further call returns
        an empty string instead of raising :exc:`StopIteration`.
        """
        remaining = iter(iterable)
        return lambda: next(remaining, "")

    def chunks(input_, remainder):
        return list(utils.iter_chunks(make_cb(input_), 3, remainder=remainder))

    # Each case is (pieces returned by the callback, remainder flag, expected).
    cases = [
        # all even, remainder=False
        (["ab", "cd", "ef"], False, ["abc", "def"]),
        (["abc", "def"], False, ["abc", "def"]),
        (["abcd", "ef"], False, ["abc", "def"]),
        # all even, remainder=True
        (["ab", "cd", "ef"], True, ["abc", "def"]),
        (["abc", "def"], True, ["abc", "def"]),
        (["abcd", "ef"], True, ["abc", "def"]),
        # uneven, remainder=False: the trailing "g" is dropped
        (["ab", "cd", "ef", "g"], False, ["abc", "def"]),
        (["ab", "cd", "efg"], False, ["abc", "def"]),
        (["abc", "def", "g"], False, ["abc", "def"]),
        (["abcd", "ef", "g"], False, ["abc", "def"]),
        # uneven, remainder=True: the trailing "g" is yielded as a short chunk
        (["ab", "cd", "ef", "g"], True, ["abc", "def", "g"]),
        (["ab", "cd", "efg"], True, ["abc", "def", "g"]),
        (["abc", "def", "g"], True, ["abc", "def", "g"]),
        (["abcd", "ef", "g"], True, ["abc", "def", "g"]),
    ]
    for pieces, keep_remainder, expected in cases:
        assert chunks(pieces, keep_remainder) == expected
# Constrained to the two sequence types ``iter_chunks`` can slice and
# concatenate; a given call works on one of them throughout.
TStr = TypeVar("TStr", bytes, str)


def iter_chunks(
    f: Callable[[], TStr], chunk_size: int, *, remainder: bool = False
) -> Iterable[TStr]:
    """
    Reads ``bytes`` objects (resp. ``str`` objects) from the callable ``f``,
    and yields them as chunks of exactly ``chunk_size`` bytes (resp. characters).

    ``f`` is typically a method like :meth:`io.RawIOBase.read`, which only
    guarantees an upper bound on the size of what it returns; whereas this
    function yields chunks of exactly the requested size.

    The stream is considered finished the first time ``f`` returns a falsy
    value (e.g. ``b""`` or ``""``); ``f`` is not called again after that.

    Args:
        f: the callable generating data
        chunk_size: the exact size of chunks to return; must be strictly positive
        remainder: if True, a last chunk with size strictly smaller than ``chunk_size``
            may be returned, if the data stream from ``f`` had a length that is not
            a multiple of ``chunk_size``. If False, that trailing data is dropped.

    Yields:
        successive chunks of length ``chunk_size``, plus possibly a shorter
        final chunk when ``remainder`` is True.

    Raises:
        ValueError: if ``chunk_size`` is not strictly positive. As this is a
            generator, the error is raised when iteration starts, not when
            ``iter_chunks`` is called.
    """
    # Without this check, zero fails with an opaque ``range()`` error and a
    # negative size silently discards all the data.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be strictly positive, got {chunk_size!r}")

    # Data read from ``f`` but not yet yielded; always shorter than
    # ``chunk_size`` at the top of the loop.
    buf = None
    while True:
        new_data = f()
        if not new_data:
            if remainder and buf:
                yield buf  # may be shorter than ``chunk_size``
            return

        # Reusing ``new_data`` directly when nothing is pending spares a copy.
        buf = buf + new_data if buf else new_data

        # Yield every complete chunk, then keep only the incomplete tail.
        full_len = len(buf) - len(buf) % chunk_size
        for start in range(0, full_len, chunk_size):
            yield buf[start : start + chunk_size]
        buf = buf[full_len:]