# ---------------------------------------------------------------------------
# File: swh/loader/mercurial/hgutil.py
# ---------------------------------------------------------------------------
# Copyright (C) 2020-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import io
import os
import signal
import time
import traceback
from typing import Dict, NewType

from billiard import Process, Queue

# The internal Mercurial API is not guaranteed to be stable.
from mercurial import context, error, hg, smartset, util  # type: ignore
import mercurial.ui  # type: ignore

NULLID = mercurial.node.nullid
HgNodeId = NewType("HgNodeId", bytes)
Repository = hg.localrepo
BaseContext = context.basectx
LRUCacheDict = util.lrucachedict
HgSpanSet = smartset._spanset
HgFilteredSet = smartset.filteredset
LookupError = error.LookupError

# After SIGTERM, the clone subprocess is polled this many times, sleeping
# this long between polls (10 * 0.1s = 1 second), before SIGKILL is sent.
_TERMINATE_GRACE_POLLS = 10
_TERMINATE_POLL_INTERVAL = 0.1


def repository(path: str) -> hg.localrepo:
    """Open the Mercurial repository located at ``path``."""
    ui = mercurial.ui.ui.load()
    return hg.repository(ui, path.encode())


def branches(repo: hg.localrepo) -> Dict[bytes, HgNodeId]:
    """List repository named branches and their tip node.

    Closed branches are left out of the result.
    """
    return {
        tag: tip
        for tag, _heads, tip, isclosed in repo.branchmap().iterbranches()
        if not isclosed
    }


class CloneTimeout(Exception):
    """Raised by :func:`clone` when the clone outlives its timeout.

    ``args`` is ``(src, timeout, killed)``; ``killed`` is True when the
    subprocess had to be stopped with SIGKILL.
    """


class CloneFailure(Exception):
    """Raised by :func:`clone` when the clone subprocess fails.

    ``args`` is ``(src, dest, message)``; ``message`` is the traceback text
    sent back by the subprocess, or a description of its exit status when no
    traceback was received.
    """


def _clone_task(src: str, dest: str, errors: Queue) -> None:
    """Clone task to run in a subprocess.

    Args:
        src: clone source
        dest: clone destination
        errors: message queue to communicate errors
    """
    try:
        hg.clone(
            ui=mercurial.ui.ui.load(),
            peeropts={},
            source=src.encode(),
            dest=dest.encode(),
            update=False,
        )
    except Exception:
        # Ship the formatted traceback to the parent process, then let the
        # exception propagate so the subprocess exits with an error status.
        exc_buffer = io.StringIO()
        traceback.print_exc(file=exc_buffer)
        errors.put_nowait(exc_buffer.getvalue())
        raise


def clone(src: str, dest: str, timeout: float) -> None:
    """Clone a repository with timeout.

    Args:
        src: clone source
        dest: clone destination
        timeout: timeout in seconds

    Raises:
        CloneTimeout: the clone was still running after ``timeout`` seconds
            and has been terminated (or killed).
        CloneFailure: the clone subprocess reported an exception, or exited
            with a non-zero status without reporting one.
    """
    errors: Queue = Queue()
    process = Process(target=_clone_task, args=(src, dest, errors))
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        # Give it a second (literally), then kill it
        # Can't use `process.join(1)` here, billiard appears to be bugged
        # https://github.com/celery/billiard/issues/270
        killed = False
        for _ in range(_TERMINATE_GRACE_POLLS):
            time.sleep(_TERMINATE_POLL_INTERVAL)
            if not process.is_alive():
                break
        else:
            # Still alive after the grace period: SIGTERM was not enough.
            killed = True
            os.kill(process.pid, signal.SIGKILL)
        raise CloneTimeout(src, timeout, killed)

    if not errors.empty():
        raise CloneFailure(src, dest, errors.get())

    # The subprocess may die without queueing a traceback (e.g. when it is
    # killed by a signal); do not report such a clone as successful.
    exitcode = process.exitcode
    if exitcode != 0:
        raise CloneFailure(
            src, dest, f"clone subprocess exited with code {exitcode}"
        )


# ---------------------------------------------------------------------------
# File: swh/loader/mercurial/tests/test_hgutil.py
# ---------------------------------------------------------------------------
# Copyright (C) 2020-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import signal
import time
import traceback

from mercurial import hg  # type: ignore
import pytest

from .. import hgutil


def test_clone_timeout(monkeypatch):
    """A clone that ignores SIGTERM is killed and reported as a timeout."""
    src = "https://www.mercurial-scm.org/repo/hello"
    dest = "/dev/null"
    timeout = 0.1

    def clone(*args, **kwargs):
        # ignore SIGTERM to force sigkill
        signal.signal(signal.SIGTERM, lambda signum, frame: None)
        time.sleep(2)

    monkeypatch.setattr(hg, "clone", clone)

    with pytest.raises(hgutil.CloneTimeout) as e:
        hgutil.clone(src, dest, timeout)
    killed = True
    assert e.value.args == (src, timeout, killed)


def test_clone_error(caplog, tmp_path, monkeypatch):
    """An exception in the clone subprocess surfaces as CloneFailure."""
    src = "https://www.mercurial-scm.org/repo/hello"
    dest = "/dev/null"
    expected_traceback = "Some traceback"

    def clone(*args, **kwargs):
        raise ValueError()

    def print_exc(file):
        file.write(expected_traceback)

    monkeypatch.setattr(hg, "clone", clone)
    monkeypatch.setattr(traceback, "print_exc", print_exc)

    with pytest.raises(hgutil.CloneFailure) as e:
        hgutil.clone(src, dest, 1)
    assert e.value.args == (src, dest, expected_traceback)