# Copyright (C) 2016-2022  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import codecs
from contextlib import contextmanager
import itertools
import os
import re
from typing import Iterable, Tuple, TypeVar


@contextmanager
def cwd(path):
    """Contextually change the working directory to do thy bidding.
    Then gets back to the original location.
    """
    prev_cwd = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Always restore the previous working directory, even on error.
        os.chdir(prev_cwd)


def grouper(iterable, n):
    """
    Collect data into fixed-length size iterables. The last block might
    contain fewer elements as it will hold only the remaining number
    of elements.

    The invariant here is that the number of elements in the input
    iterable and the sum of the number of elements of all iterables
    generated from this function should be equal.

    If ``iterable`` is an iterable of bytes or strings that you need to join
    later, then :func:`iter_chunks` is preferable, as it avoids this join
    by slicing directly.

    Args:
        iterable (Iterable): an iterable
        n (int): size of block to slice the iterable into

    Yields:
        fixed-length blocks as iterables. As mentioned, the last iterable
        might be less populated.
    """
    args = [iter(iterable)] * n
    # A private sentinel distinguishes "no more data" from any real value
    # (including None) that the input iterable could legitimately yield.
    stop_value = object()
    for _data in itertools.zip_longest(*args, fillvalue=stop_value):
        yield (d for d in _data if d is not stop_value)


# Constrained TypeVar: iter_chunks works on homogeneous streams of either
# bytes or str, never a mix of the two.
TStr = TypeVar("TStr", bytes, str)


def iter_chunks(
    iterable: Iterable[TStr], chunk_size: int, *, remainder: bool = False
) -> Iterable[TStr]:
    """
    Reads ``bytes`` objects (resp. ``str`` objects) from the ``iterable``,
    and yields them as chunks of exactly ``chunk_size`` bytes (resp.
    characters).

    ``iterable`` is typically obtained by repeatedly calling a method like
    :meth:`io.RawIOBase.read`, which only guarantees an upper bound on the
    size of each returned block; this function re-slices that stream into
    blocks of exactly the requested size.

    Args:
        iterable: the input data
        chunk_size: the exact size of chunks to return
        remainder: if True, a last chunk with size strictly smaller than
            ``chunk_size`` may be returned, if the data stream from the
            ``iterable`` had a length that is not a multiple of ``chunk_size``
    """
    buf = None
    iterator = iter(iterable)
    while True:
        # Loop invariant: whatever is buffered is strictly shorter than a
        # full chunk (otherwise it would already have been yielded).
        assert buf is None or len(buf) < chunk_size
        try:
            new_data = next(iterator)
        except StopIteration:
            if remainder and buf:
                yield buf  # may be shorter than ``chunk_size``
            return

        if buf:
            buf += new_data
        else:
            # buf is None or empty: rebinding spares a copy.
            buf = new_data

        new_buf = None
        for i in range(0, len(buf), chunk_size):
            chunk = buf[i : i + chunk_size]
            if len(chunk) == chunk_size:
                yield chunk
            else:
                # Only the final slice can be short; carry it over.
                assert not new_buf
                new_buf = chunk
        buf = new_buf


def backslashescape_errors(exception):
    """Codec error handler escaping undecodable bytes as ``\\xNN``.

    Falls back to the stdlib ``backslashreplace`` handler for encode errors.
    """
    if isinstance(exception, UnicodeDecodeError):
        bad_data = exception.object[exception.start : exception.end]
        escaped = "".join(r"\x%02x" % x for x in bad_data)
        return escaped, exception.end

    return codecs.backslashreplace_errors(exception)


codecs.register_error("backslashescape", backslashescape_errors)


def encode_with_unescape(value):
    """Encode an unicode string containing \\x backslash escapes"""
    slices = []
    start = 0
    odd_backslashes = False
    i = 0
    while i < len(value):
        if value[i] == "\\":
            odd_backslashes = not odd_backslashes
        else:
            if odd_backslashes:
                if value[i] != "x":
                    raise ValueError(
                        "invalid escape for %r at position %d" % (value, i - 1)
                    )
                slices.append(
                    value[start : i - 1].replace("\\\\", "\\").encode("utf-8")
                )
                slices.append(bytes.fromhex(value[i + 1 : i + 3]))

                odd_backslashes = False
                start = i = i + 3
                continue

        i += 1

    slices.append(value[start:i].replace("\\\\", "\\").encode("utf-8"))

    return b"".join(slices)


def decode_with_escape(value):
    """Decode a bytestring as utf-8, escaping the bytes of invalid utf-8
    sequences as \\x<hex value>. We also escape NUL bytes as they are invalid
    in JSON strings.
    """
    # escape backslashes
    value = value.replace(b"\\", b"\\\\")
    value = value.replace(b"\x00", b"\\x00")
    return value.decode("utf-8", "backslashescape")


def commonname(path0, path1, as_str=False):
    """Compute the commonname between the path0 and path1."""
    return path1.split(path0)[1]


def numfile_sortkey(fname: str) -> Tuple[int, str]:
    """Simple function to sort filenames of the form:

    nnxxx.ext

    where nn is a number according to the numbers.

    Returns a tuple (order, remaining), where 'order' is the numeric (int)
    value extracted from the file name, and 'remaining' is the remaining part
    of the file name.

    Typically used to sort sql/nn-swh-xxx.sql files.

    Unmatched file names will return 999999 as order value.
    """
    m = re.match(r"(\d*)(.*)", fname)
    assert m is not None
    num, rem = m.groups()
    return (int(num) if num else 999999, rem)


def basename_sortkey(fname: str) -> Tuple[int, str]:
    "like numfile_sortkey but on basenames"
    return numfile_sortkey(os.path.basename(fname))