diff --git a/swh/loader/core/tests/test_utils.py b/swh/loader/core/tests/test_utils.py
--- a/swh/loader/core/tests/test_utils.py
+++ b/swh/loader/core/tests/test_utils.py
@@ -4,9 +4,18 @@
 # See top-level LICENSE file for more information
 
 import os
+import signal
+from time import sleep
 from unittest.mock import patch
 
-from swh.loader.core.utils import clean_dangling_folders
+import pytest
+
+from swh.loader.core.utils import (
+    CloneFailure,
+    CloneTimeout,
+    clean_dangling_folders,
+    clone_with_timeout,
+)
 
 
 def prepare_arborescence_from(tmpdir, folder_names):
@@ -93,3 +102,62 @@
     mock_pid_exists.assert_called_once_with(1468)
     mock_rmtree.assert_called_once_with(os.path.join(rootpath, path2))
     assert_dirs(actual_dirs, [path2, path1])
+
+
+def test_clone_with_timeout_no_error_no_timeout():
+    def succeed():
+        """This does nothing to simulate a successful clone"""
+
+    clone_with_timeout("foo", "bar", succeed, timeout=0.5)
+
+
+def test_clone_with_timeout_no_error_timeout():
+    def slow():
+        """This lasts for more than the timeout"""
+        sleep(1)
+
+    with pytest.raises(CloneTimeout):
+        clone_with_timeout("foo", "bar", slow, timeout=0.5)
+
+
+def test_clone_with_timeout_error():
+    """The child's traceback is propagated in the CloneFailure"""
+
+    def raise_something():
+        raise RuntimeError("panic!")
+
+    with pytest.raises(CloneFailure) as e:
+        clone_with_timeout("foo", "bar", raise_something, timeout=0.5)
+    src, dest, details = e.value.args
+    assert (src, dest) == ("foo", "bar")
+    assert "RuntimeError: panic!" in details
+
+
+def test_clone_with_timeout_nonzero_exit():
+    """A child dying without a Python traceback is still reported as a failure"""
+
+    def crash():
+        os._exit(1)  # bypasses the exception reporting done in the child
+
+    with pytest.raises(CloneFailure) as e:
+        clone_with_timeout("foo", "bar", crash, timeout=0.5)
+    assert e.value.args[:2] == ("foo", "bar")
+
+
+def test_clone_with_timeout_sigkill():
+    """A child ignoring SIGTERM gets SIGKILLed, and CloneTimeout reports it"""
+    src = "https://www.mercurial-scm.org/repo/hello"
+    dest = "/dev/null"
+    timeout = 0.5
+    sleepy_time = 100 * timeout
+    assert sleepy_time > timeout
+
+    def ignores_sigterm(*args, **kwargs):
+        # ignore SIGTERM to force sigkill
+        signal.signal(signal.SIGTERM, lambda signum, frame: None)
+        sleep(sleepy_time)  # we make sure we exceed the timeout
+
+    with pytest.raises(CloneTimeout) as e:
+        clone_with_timeout(src, dest, ignores_sigterm, timeout)
+    killed = True
+    assert e.value.args == (src, timeout, killed)
diff --git a/swh/loader/core/utils.py b/swh/loader/core/utils.py
--- a/swh/loader/core/utils.py
+++ b/swh/loader/core/utils.py
@@ -1,12 +1,18 @@
-# Copyright (C) 2018-2021  The Software Heritage developers
+# Copyright (C) 2018-2022  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 
+import io
 import os
 import shutil
+import signal
+import time
+import traceback
+from typing import Callable
 
+from billiard import Process, Queue  # type: ignore
 import psutil
 
 
@@ -43,3 +49,89 @@
         except Exception as e:
             if log:
                 log.warn("Fail to clean dangling path %s: %s", path_to_cleanup, e)
+
+
+class CloneTimeout(Exception):
+    pass
+
+
+class CloneFailure(Exception):
+    pass
+
+
+def _clone_task(clone_func: Callable[[], None], errors: Queue) -> None:
+    """Run ``clone_func``, reporting any exception's traceback to ``errors``.
+
+    The exception is re-raised afterwards so the child process still exits
+    with a non-zero status.
+    """
+    try:
+        clone_func()
+    except Exception:
+        exc_buffer = io.StringIO()
+        traceback.print_exc(file=exc_buffer)
+        errors.put_nowait(exc_buffer.getvalue())
+        raise
+
+
+def _wait_for_exit(process: Process, steps: int = 10, interval: float = 0.1) -> bool:
+    """Return True if ``process`` exits within ``steps * interval`` seconds.
+
+    Polls ``is_alive()`` instead of calling ``process.join(timeout)``, as
+    billiard appears to be bugged there:
+    https://github.com/celery/billiard/issues/270
+    """
+    for _ in range(steps):
+        time.sleep(interval)
+        if not process.is_alive():
+            return True
+    return False
+
+
+def clone_with_timeout(
+    src: str, dest: str, clone_func: Callable[[], None], timeout: float
+) -> None:
+    """Clone a repository with timeout.
+
+    ``clone_func`` runs in a separate process, which is terminated (and then
+    killed, if it is still alive about one second later) once ``timeout``
+    expires.
+
+    Args:
+        src: clone source
+        dest: clone destination
+        clone_func: callable that does the actual cloning
+        timeout: timeout in seconds
+
+    Raises:
+        CloneTimeout: ``(src, timeout, killed)`` if the clone did not finish
+            in time; ``killed`` is True when SIGKILL had to be sent.
+        CloneFailure: ``(src, dest, details)`` if the clone failed;
+            ``details`` is the child's traceback, or its exit code when no
+            traceback was reported.
+    """
+    errors: Queue = Queue()
+    process = Process(target=_clone_task, args=(clone_func, errors))
+    process.start()
+    process.join(timeout)
+
+    if process.is_alive():
+        process.terminate()
+        # Give it literally a second (in successive steps of 0.1 second),
+        # then kill it.
+        killed = not _wait_for_exit(process)
+        if killed:
+            os.kill(process.pid, signal.SIGKILL)
+            # Poll again so the killed child gets reaped (no zombie left).
+            _wait_for_exit(process)
+        raise CloneTimeout(src, timeout, killed)
+
+    if not errors.empty():
+        raise CloneFailure(src, dest, errors.get())
+
+    if process.exitcode != 0:
+        # The child died without reporting an exception (e.g. killed by a
+        # signal, or clone_func raised SystemExit): not a successful clone.
+        raise CloneFailure(
+            src, dest, f"clone process exited with code {process.exitcode}"
+        )