diff --git a/swh/loader/git/dumb.py b/swh/loader/git/dumb.py
--- a/swh/loader/git/dumb.py
+++ b/swh/loader/git/dumb.py
@@ -11,12 +11,12 @@
 import struct
 from tempfile import SpooledTemporaryFile
 from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Set, cast
+import urllib.parse
 
-from dulwich.client import HttpGitClient
 from dulwich.errors import NotGitRepository
 from dulwich.objects import S_IFGITLINK, Commit, ShaFile, Tree
 from dulwich.pack import Pack, PackData, PackIndex, load_pack_index_file
-from urllib3.response import HTTPResponse
+import requests
 
 from swh.loader.git.utils import HexBytes
 
@@ -26,18 +26,7 @@
 logger = logging.getLogger(__name__)
 
 
-class DumbHttpGitClient(HttpGitClient):
-    """Simple wrapper around dulwich.client.HTTPGitClient
-    """
-
-    def __init__(self, base_url: str):
-        super().__init__(base_url)
-        self.user_agent = "Software Heritage dumb Git loader"
-
-    def get(self, url: str) -> HTTPResponse:
-        logger.debug("Fetching %s", url)
-        response, _ = self._http_request(url, headers={"User-Agent": self.user_agent})
-        return response
+HEADERS = {"User-Agent": "Software Heritage dumb Git loader"}
 
 
 def check_protocol(repo_url: str) -> bool:
@@ -52,12 +41,15 @@
     """
     if not repo_url.startswith("http"):
         return False
-    http_client = DumbHttpGitClient(repo_url)
-    url = http_client.get_url("info/refs?service=git-upload-pack")
-    response = http_client.get(url)
-    content_type = response.getheader("Content-Type")
+    # urljoin() replaces the last path segment of a base lacking a trailing "/"
+    url = urllib.parse.urljoin(
+        repo_url.rstrip("/") + "/", "info/refs?service=git-upload-pack"
+    )
+    logger.debug("Fetching %s", url)
+    response = requests.get(url, headers=HEADERS)
+    content_type = response.headers.get("Content-Type")
     return (
-        response.status in (200, 304,)
+        response.status_code in (200, 304,)
         # header is not mandatory in protocol specification
         and (content_type is None or not content_type.startswith("application/x-git-"))
     )
@@ -75,7 +67,8 @@
     """
 
     def __init__(self, repo_url: str, base_repo: RepoRepresentation):
-        self.http_client = DumbHttpGitClient(repo_url)
+        self._session = requests.Session()
+        self.repo_url = repo_url
         self.base_repo = base_repo
         self.objects: Dict[bytes, Set[bytes]] = defaultdict(set)
         self.refs = self._get_refs()
@@ -124,10 +117,22 @@
         return map(self._get_git_object, self.objects[object_type])
 
     def _http_get(self, path: str) -> SpooledTemporaryFile:
-        url = self.http_client.get_url(path)
-        response = self.http_client.get(url)
+        """Fetch ``path`` below the repository URL into a rewound temporary file.
+
+        Raises:
+            NotGitRepository: the server answered 404 for the resource
+            requests.HTTPError: the server answered with another 4xx/5xx status
+        """
+        # urljoin() replaces the last path segment of a base lacking a trailing "/"
+        url = urllib.parse.urljoin(self.repo_url.rstrip("/") + "/", path)
+        logger.debug("Fetching %s", url)
+        response = self._session.get(url, headers=HEADERS)
+        if response.status_code == 404:
+            # NOTE(review): meant to keep the exception the dulwich client raised
+            # for a missing resource; confirm callers still rely on it
+            raise NotGitRepository()
+        response.raise_for_status()
         buffer = SpooledTemporaryFile(max_size=100 * 1024 * 1024)
-        buffer.write(response.data)
+        buffer.write(response.content)
         buffer.flush()
         buffer.seek(0)
         return buffer