diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py index f4baefe..3a7c501 100644 --- a/swh/core/tests/test_utils.py +++ b/swh/core/tests/test_utils.py @@ -1,121 +1,133 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.core import utils def test_grouper(): # given actual_data = utils.grouper((i for i in range(0, 9)), 2) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks assert out == [[0, 1], [2, 3], [4, 5], [6, 7], [8]] # given actual_data = utils.grouper((i for i in range(9, 0, -1)), 4) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks assert out == [[9, 8, 7, 6], [5, 4, 3, 2], [1]] def test_grouper_with_stop_value(): # given actual_data = utils.grouper(((i, i + 1) for i in range(0, 9)), 2) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks assert out == [ [(0, 1), (1, 2)], [(2, 3), (3, 4)], [(4, 5), (5, 6)], [(6, 7), (7, 8)], [(8, 9)], ] # given actual_data = utils.grouper((i for i in range(9, 0, -1)), 4) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks assert out == [[9, 8, 7, 6], [5, 4, 3, 2], [1]] def test_backslashescape_errors(): raw_data_err = b"abcd\x80" with pytest.raises(UnicodeDecodeError): raw_data_err.decode("utf-8", "strict") assert raw_data_err.decode("utf-8", "backslashescape") == "abcd\\x80" raw_data_ok = b"abcd\xc3\xa9" assert raw_data_ok.decode("utf-8", "backslashescape") == raw_data_ok.decode( "utf-8", "strict" ) unicode_data = "abcdef\u00a3" assert unicode_data.encode("ascii", "backslashescape") == b"abcdef\\xa3" def test_encode_with_unescape(): valid_data = "\\x01020304\\x00" valid_data_encoded = b"\x01020304\x00" assert valid_data_encoded == utils.encode_with_unescape(valid_data) def test_encode_with_unescape_invalid_escape(): invalid_data = "test\\abcd" with pytest.raises(ValueError) as exc: utils.encode_with_unescape(invalid_data) assert "invalid escape" in exc.value.args[0] assert "position 4" in exc.value.args[0] def test_decode_with_escape(): backslashes = b"foo\\bar\\\\baz" backslashes_escaped = "foo\\\\bar\\\\\\\\baz" assert backslashes_escaped == utils.decode_with_escape(backslashes) valid_utf8 = b"foo\xc3\xa2" valid_utf8_escaped = "foo\u00e2" assert valid_utf8_escaped == utils.decode_with_escape(valid_utf8) invalid_utf8 = b"foo\xa2" invalid_utf8_escaped = "foo\\xa2" assert invalid_utf8_escaped == utils.decode_with_escape(invalid_utf8) valid_utf8_nul = b"foo\xc3\xa2\x00" valid_utf8_nul_escaped = "foo\u00e2\\x00" assert valid_utf8_nul_escaped == utils.decode_with_escape(valid_utf8_nul) def test_commonname(): # when actual_commonname = utils.commonname("/some/where/to/", "/some/where/to/go/to") # then assert "go/to" == actual_commonname # when actual_commonname2 = utils.commonname(b"/some/where/to/", b"/some/where/to/go/to") # then assert b"go/to" == actual_commonname2 + + +def test_numfile_sotkey(): + assert utils.numfile_sortkey("00-xxx.sql") == (0, "-xxx.sql") + assert utils.numfile_sortkey("01-xxx.sql") == (1, "-xxx.sql") + assert utils.numfile_sortkey("10-xxx.sql") == (10, "-xxx.sql") + assert utils.numfile_sortkey("99-xxx.sql") == (99, "-xxx.sql") + assert utils.numfile_sortkey("100-xxx.sql") == (100, "-xxx.sql") + assert utils.numfile_sortkey("00100-xxx.sql") == (100, "-xxx.sql") + assert utils.numfile_sortkey("1.sql") == (1, ".sql") + assert utils.numfile_sortkey("1") == (1, "") + assert utils.numfile_sortkey("toto-01.sql") == (999999, "toto-01.sql") diff --git a/swh/core/utils.py b/swh/core/utils.py index a14daa5..60a35ea 100644 --- a/swh/core/utils.py +++ b/swh/core/utils.py @@ -1,122 +1,132 @@ # Copyright (C) 2016-2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import codecs from contextlib import contextmanager import itertools import os import re +from typing import Tuple @contextmanager def cwd(path): """Contextually change the working directory to do thy bidding. Then gets back to the original location. """ prev_cwd = os.getcwd() os.chdir(path) try: yield finally: os.chdir(prev_cwd) def grouper(iterable, n): """Collect data into fixed-length size iterables. The last block might contain less elements as it will hold only the remaining number of elements. The invariant here is that the number of elements in the input iterable and the sum of the number of elements of all iterables generated from this function should be equal. Args: iterable (Iterable): an iterable n (int): size of block to slice the iterable into Yields: fixed-length blocks as iterables. As mentioned, the last iterable might be less populated. """ args = [iter(iterable)] * n stop_value = object() for _data in itertools.zip_longest(*args, fillvalue=stop_value): yield (d for d in _data if d is not stop_value) def backslashescape_errors(exception): if isinstance(exception, UnicodeDecodeError): bad_data = exception.object[exception.start : exception.end] escaped = "".join(r"\x%02x" % x for x in bad_data) return escaped, exception.end return codecs.backslashreplace_errors(exception) codecs.register_error("backslashescape", backslashescape_errors) def encode_with_unescape(value): """Encode an unicode string containing \\x backslash escapes""" slices = [] start = 0 odd_backslashes = False i = 0 while i < len(value): if value[i] == "\\": odd_backslashes = not odd_backslashes else: if odd_backslashes: if value[i] != "x": raise ValueError( "invalid escape for %r at position %d" % (value, i - 1) ) slices.append( value[start : i - 1].replace("\\\\", "\\").encode("utf-8") ) slices.append(bytes.fromhex(value[i + 1 : i + 3])) odd_backslashes = False start = i = i + 3 continue i += 1 slices.append(value[start:i].replace("\\\\", "\\").encode("utf-8")) return b"".join(slices) def decode_with_escape(value): """Decode a bytestring as utf-8, escaping the bytes of invalid utf-8 sequences as \\x. We also escape NUL bytes as they are invalid in JSON strings. """ # escape backslashes value = value.replace(b"\\", b"\\\\") value = value.replace(b"\x00", b"\\x00") return value.decode("utf-8", "backslashescape") def commonname(path0, path1, as_str=False): """Compute the commonname between the path0 and path1. """ return path1.split(path0)[1] -def numfile_sortkey(fname): +def numfile_sortkey(fname: str) -> Tuple[int, str]: """Simple function to sort filenames of the form: nnxxx.ext where nn is a number according to the numbers. + Returns a tuple (order, remaining), where 'order' is the numeric (int) + value extracted from the file name, and 'remaining' is the remaining part + of the file name. + Typically used to sort sql/nn-swh-xxx.sql files. + + Unmatched file names will return 999999 as order value. + """ - num, rem = re.match(r"(\d*)(.*)", fname).groups() - return (num and int(num) or 99, rem) + m = re.match(r"(\d*)(.*)", fname) + assert m is not None + num, rem = m.groups() + return (int(num) if num else 999999, rem)