diff --git a/swh/loader/mercurial/hgutil.py b/swh/loader/mercurial/hgutil.py
--- a/swh/loader/mercurial/hgutil.py
+++ b/swh/loader/mercurial/hgutil.py
@@ -5,19 +5,15 @@
 
 from collections import defaultdict
 from dataclasses import dataclass
-import io
-import os
-import signal
-import time
-import traceback
+from functools import partial
 from typing import Dict, List, Mapping, NewType, Optional, Set
 
-from billiard import Process, Queue
-
 # The internal Mercurial API is not guaranteed to be stable.
 from mercurial import bookmarks, context, error, hg, smartset, util  # type: ignore
 import mercurial.ui  # type: ignore
 
+from swh.loader.core.utils import clone_with_timeout
+
 NULLID = mercurial.node.nullid
 HgNodeId = NewType("HgNodeId", bytes)
 Repository = hg.localrepo
@@ -116,64 +112,26 @@
     )
 
 
-class CloneTimeout(Exception):
-    pass
-
-
-class CloneFailure(Exception):
-    pass
-
-
-def _clone_task(src: str, dest: str, errors: Queue) -> None:
-    """Clone task to run in a subprocess.
-
-    Args:
-        src: clone source
-        dest: clone destination
-        errors: message queue to communicate errors
-    """
-    try:
-        hg.clone(
-            ui=mercurial.ui.ui.load(),
-            peeropts={},
-            source=src.encode(),
-            dest=dest.encode(),
-            update=False,
-        )
-    except Exception as e:
-        exc_buffer = io.StringIO()
-        traceback.print_exc(file=exc_buffer)
-        errors.put_nowait(exc_buffer.getvalue())
-        raise e
-
-
-def clone(src: str, dest: str, timeout: float) -> None:
-    """Clone a repository with timeout.
-
-    Args:
-        src: clone source
-        dest: clone destination
-        timeout: timeout in seconds
-    """
-    errors: Queue = Queue()
-    process = Process(target=_clone_task, args=(src, dest, errors))
-    process.start()
-    process.join(timeout)
-
-    if process.is_alive():
-        process.terminate()
-        # Give it literally a second (in successive steps of 0.1 second), then kill it.
-        # Can't use `process.join(1)` here, billiard appears to be bugged
-        # https://github.com/celery/billiard/issues/270
-        killed = False
-        for _ in range(10):
-            time.sleep(0.1)
-            if not process.is_alive():
-                break
-        else:
-            killed = True
-            os.kill(process.pid, signal.SIGKILL)
-        raise CloneTimeout(src, timeout, killed)
-
-    if not errors.empty():
-        raise CloneFailure(src, dest, errors.get())
+def clone(src: str, dest: str, timeout: float) -> None:
+    """Clone a repository with timeout.
+
+    The Mercurial clone call is bound into a :func:`functools.partial` and
+    handed, together with ``timeout``, to
+    :func:`swh.loader.core.utils.clone_with_timeout`; refer to that function
+    for how timeouts and clone failures are reported.
+
+    Args:
+        src: clone source
+        dest: clone destination
+        timeout: timeout in seconds
+    """
+    closure = partial(
+        hg.clone,
+        ui=mercurial.ui.ui.load(),
+        peeropts={},
+        source=src.encode(),
+        dest=dest.encode(),
+        # update=False: skip checking out a working copy after the clone.
+        update=False,
+    )
+    clone_with_timeout(src, dest, closure, timeout)
diff --git a/swh/loader/mercurial/tests/test_hgutil.py b/swh/loader/mercurial/tests/test_hgutil.py
deleted file mode 100644
--- a/swh/loader/mercurial/tests/test_hgutil.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright (C) 2020-2021  The Software Heritage developers
-# See the AUTHORS file at the top-level directory of this distribution
-# License: GNU General Public License version 3, or any later version
-# See top-level LICENSE file for more information
-import signal
-import time
-import traceback
-
-from mercurial import hg  # type: ignore
-import pytest
-
-from .. import hgutil
-
-
-def test_clone_timeout(monkeypatch):
-    src = "https://www.mercurial-scm.org/repo/hello"
-    dest = "/dev/null"
-    timeout = 1
-    sleepy_time = 100 * timeout
-    assert sleepy_time > timeout
-
-    def clone(*args, **kwargs):
-        # ignore SIGTERM to force sigkill
-        signal.signal(signal.SIGTERM, lambda signum, frame: None)
-        time.sleep(sleepy_time)  # we make sure we exceed the timeout
-
-    monkeypatch.setattr(hg, "clone", clone)
-
-    with pytest.raises(hgutil.CloneTimeout) as e:
-        hgutil.clone(src, dest, timeout)
-    killed = True
-    assert e.value.args == (src, timeout, killed)
-
-
-def test_clone_error(caplog, tmp_path, monkeypatch):
-    src = "https://www.mercurial-scm.org/repo/hello"
-    dest = "/dev/null"
-    expected_traceback = "Some traceback"
-
-    def clone(*args, **kwargs):
-        raise ValueError()
-
-    def print_exc(file):
-        file.write(expected_traceback)
-
-    monkeypatch.setattr(hg, "clone", clone)
-    monkeypatch.setattr(traceback, "print_exc", print_exc)
-
-    with pytest.raises(hgutil.CloneFailure) as e:
-        hgutil.clone(src, dest, 1)
-    assert e.value.args == (src, dest, expected_traceback)