diff --git a/swh/loader/package/tests/test_utils.py b/swh/loader/package/tests/test_utils.py
--- a/swh/loader/package/tests/test_utils.py
+++ b/swh/loader/package/tests/test_utils.py
@@ -6,6 +6,8 @@
 
 import json
 import os
+from unittest.mock import MagicMock
+from urllib.error import URLError
 
 import pytest
 
@@ -130,6 +132,47 @@
         download(url, dest=str(tmp_path), hashes=expected_hashes)
 
 
+@pytest.mark.fs
+def test_ftp_download_ok(tmp_path, mocker):
+    """Download without issue should provide filename and hashes"""
+    filename = "requests-0.0.1.tar.gz"
+    url = "ftp://pypi.org/pypi/requests/%s" % filename
+    data = b"this is something"
+
+    cm = MagicMock()
+    cm.getstatus.return_value = 200
+    # download() reads FTP responses chunk by chunk until read() returns b"",
+    # so the mock must eventually signal end-of-stream.
+    cm.read.side_effect = [data, b""]
+    cm.__enter__.return_value = cm
+    mocker.patch("swh.loader.package.utils.urlopen").return_value = cm
+
+    actual_filepath, actual_hashes = download(url, dest=str(tmp_path))
+
+    actual_filename = os.path.basename(actual_filepath)
+    assert actual_filename == filename
+    assert actual_hashes["length"] == len(data)
+    assert (
+        actual_hashes["checksums"]["sha1"] == "fdd1ce606a904b08c816ba84f3125f2af44d92b2"
+    )  # noqa
+    assert (
+        actual_hashes["checksums"]["sha256"]
+        == "1d9224378d77925d612c9f926eb9fb92850e6551def8328011b6a972323298d5"
+    )
+
+
+@pytest.mark.fs
+def test_ftp_download_ko(tmp_path, mocker):
+    """Download failure on an ftp url should propagate the URLError"""
+    filename = "requests-0.0.1.tar.gz"
+    url = "ftp://pypi.org/pypi/requests/%s" % filename
+
+    mocker.patch("swh.loader.package.utils.urlopen").side_effect = URLError("FTP error")
+
+    with pytest.raises(URLError):
+        download(url, dest=str(tmp_path))
+
+
 def test_api_info_failure(requests_mock):
     """Failure to fetch info/release information should raise"""
     url = "https://pypi.org/pypi/requests/json"
diff --git a/swh/loader/package/utils.py b/swh/loader/package/utils.py
--- a/swh/loader/package/utils.py
+++ b/swh/loader/package/utils.py
@@ -3,11 +3,13 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from contextlib import closing
 import copy
 import functools
 import logging
 import os
 from typing import Callable, Dict, Optional, Tuple, TypeVar
+from urllib.request import urlopen
 
 import requests
 
@@ -79,9 +81,26 @@
         params["headers"].update(extra_request_headers)
     # so the connection does not hang indefinitely (read/connection timeout)
     timeout = params.get("timeout", 60)
-    response = requests.get(url, **params, timeout=timeout, stream=True)
-    if response.status_code != 200:
-        raise ValueError("Fail to query '%s'. Reason: %s" % (url, response.status_code))
+
+    if url.startswith("ftp://"):
+        # requests cannot fetch ftp:// URLs, so fall back to urllib. Stream the
+        # payload in HASH_BLOCK_SIZE chunks, like the HTTP branch, instead of
+        # reading the whole archive into memory; the connection is closed once
+        # the generator below is exhausted.
+        ftp_response = urlopen(url, timeout=timeout)
+
+        def _iter_ftp_chunks():
+            with closing(ftp_response):
+                yield from iter(lambda: ftp_response.read(HASH_BLOCK_SIZE), b"")
+
+        response_data = _iter_ftp_chunks()
+    else:
+        response = requests.get(url, **params, timeout=timeout, stream=True)
+        if response.status_code != 200:
+            raise ValueError(
+                "Fail to query '%s'. Reason: %s" % (url, response.status_code)
+            )
+        response_data = response.iter_content(chunk_size=HASH_BLOCK_SIZE)
 
     filename = filename if filename else os.path.basename(url)
     logger.debug("filename: %s", filename)
@@ -90,7 +109,7 @@
 
     h = MultiHash(hash_names=DOWNLOAD_HASHES)
     with open(filepath, "wb") as f:
-        for chunk in response.iter_content(chunk_size=HASH_BLOCK_SIZE):
+        for chunk in response_data:
             h.update(chunk)
             f.write(chunk)
 