diff --git a/swh/loader/package/tests/test_utils.py b/swh/loader/package/tests/test_utils.py index 5b5a544..99e5afd 100644 --- a/swh/loader/package/tests/test_utils.py +++ b/swh/loader/package/tests/test_utils.py @@ -1,160 +1,201 @@ # Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json import os +from unittest.mock import MagicMock +from urllib.error import URLError import pytest from swh.loader.exception import NotFound import swh.loader.package from swh.loader.package.utils import api_info, download, release_name def test_version_generation(): assert ( swh.loader.package.__version__ != "devel" ), "Make sure swh.loader.core is installed (e.g. pip install -e .)" @pytest.mark.fs def test_download_fail_to_download(tmp_path, requests_mock): url = "https://pypi.org/pypi/arrow/json" status_code = 404 requests_mock.get(url, status_code=status_code) with pytest.raises(ValueError) as e: download(url, tmp_path) assert e.value.args[0] == "Fail to query '%s'. Reason: %s" % (url, status_code) @pytest.mark.fs def test_download_ok(tmp_path, requests_mock): """Download without issue should provide filename and hashes""" filename = "requests-0.0.1.tar.gz" url = "https://pypi.org/pypi/requests/%s" % filename data = "this is something" requests_mock.get(url, text=data, headers={"content-length": str(len(data))}) actual_filepath, actual_hashes = download(url, dest=str(tmp_path)) actual_filename = os.path.basename(actual_filepath) assert actual_filename == filename assert actual_hashes["length"] == len(data) assert ( actual_hashes["checksums"]["sha1"] == "fdd1ce606a904b08c816ba84f3125f2af44d92b2" ) # noqa assert ( actual_hashes["checksums"]["sha256"] == "1d9224378d77925d612c9f926eb9fb92850e6551def8328011b6a972323298d5" ) @pytest.mark.fs def test_download_ok_no_header(tmp_path, requests_mock): """Download without issue should provide filename and hashes""" filename = "requests-0.0.1.tar.gz" url = "https://pypi.org/pypi/requests/%s" % filename data = "this is something" requests_mock.get(url, text=data) # no header information actual_filepath, actual_hashes = download(url, dest=str(tmp_path)) actual_filename = os.path.basename(actual_filepath) assert actual_filename == filename assert actual_hashes["length"] == len(data) assert ( actual_hashes["checksums"]["sha1"] == "fdd1ce606a904b08c816ba84f3125f2af44d92b2" ) # noqa assert ( actual_hashes["checksums"]["sha256"] == "1d9224378d77925d612c9f926eb9fb92850e6551def8328011b6a972323298d5" ) @pytest.mark.fs def test_download_ok_with_hashes(tmp_path, requests_mock): """Download without issue should provide filename and hashes""" filename = "requests-0.0.1.tar.gz" url = "https://pypi.org/pypi/requests/%s" % filename data = "this is something" requests_mock.get(url, text=data, headers={"content-length": str(len(data))}) # good hashes for such file good = { "sha1": "fdd1ce606a904b08c816ba84f3125f2af44d92b2", "sha256": "1d9224378d77925d612c9f926eb9fb92850e6551def8328011b6a972323298d5", # noqa } actual_filepath, actual_hashes = download(url, dest=str(tmp_path), hashes=good) actual_filename = os.path.basename(actual_filepath) assert actual_filename == filename assert actual_hashes["length"] == len(data) assert actual_hashes["checksums"]["sha1"] == good["sha1"] assert actual_hashes["checksums"]["sha256"] == good["sha256"] @pytest.mark.fs def test_download_fail_hashes_mismatch(tmp_path, requests_mock): """Mismatch hash after download should raise """ filename = "requests-0.0.1.tar.gz" url = "https://pypi.org/pypi/requests/%s" % filename data = "this is something" requests_mock.get(url, text=data, headers={"content-length": str(len(data))}) # good hashes for such file good = { "sha1": "fdd1ce606a904b08c816ba84f3125f2af44d92b2", "sha256": "1d9224378d77925d612c9f926eb9fb92850e6551def8328011b6a972323298d5", # noqa } for hash_algo in good.keys(): wrong_hash = good[hash_algo].replace("1", "0") expected_hashes = good.copy() expected_hashes[hash_algo] = wrong_hash # set the wrong hash expected_msg = "Failure when fetching %s. " "Checksum mismatched: %s != %s" % ( url, wrong_hash, good[hash_algo], ) with pytest.raises(ValueError, match=expected_msg): download(url, dest=str(tmp_path), hashes=expected_hashes) +@pytest.mark.fs +def test_ftp_download_ok(tmp_path, mocker): + """Download without issue should provide filename and hashes""" + filename = "requests-0.0.1.tar.gz" + url = "ftp://pypi.org/pypi/requests/%s" % filename + data = b"this is something" + + cm = MagicMock() + cm.getstatus.return_value = 200 + cm.read.side_effect = [data, b""] + cm.__enter__.return_value = cm + mocker.patch("swh.loader.package.utils.urlopen").return_value = cm + + actual_filepath, actual_hashes = download(url, dest=str(tmp_path)) + + actual_filename = os.path.basename(actual_filepath) + assert actual_filename == filename + assert actual_hashes["length"] == len(data) + assert ( + actual_hashes["checksums"]["sha1"] == "fdd1ce606a904b08c816ba84f3125f2af44d92b2" + ) # noqa + assert ( + actual_hashes["checksums"]["sha256"] + == "1d9224378d77925d612c9f926eb9fb92850e6551def8328011b6a972323298d5" + ) + + +@pytest.mark.fs +def test_ftp_download_ko(tmp_path, mocker): + """Download without issue should provide filename and hashes""" + filename = "requests-0.0.1.tar.gz" + url = "ftp://pypi.org/pypi/requests/%s" % filename + + mocker.patch("swh.loader.package.utils.urlopen").side_effect = URLError("FTP error") + + with pytest.raises(URLError): + download(url, dest=str(tmp_path)) + + def test_api_info_failure(requests_mock): """Failure to fetch info/release information should raise""" url = "https://pypi.org/pypi/requests/json" status_code = 400 requests_mock.get(url, status_code=status_code) with pytest.raises(NotFound) as e0: api_info(url) assert e0.value.args[0] == "Fail to query '%s'. Reason: %s" % (url, status_code) def test_api_info(requests_mock): """Fetching json info from pypi project should be ok""" url = "https://pypi.org/pypi/requests/json" requests_mock.get(url, text='{"version": "0.0.1"}') actual_info = json.loads(api_info(url)) assert actual_info == { "version": "0.0.1", } def test_release_name(): for version, filename, expected_release in [ ("0.0.1", None, "releases/0.0.1"), ("0.0.2", "something", "releases/0.0.2/something"), ]: assert release_name(version, filename) == expected_release diff --git a/swh/loader/package/utils.py b/swh/loader/package/utils.py index 8bf0d4c..1f00406 100644 --- a/swh/loader/package/utils.py +++ b/swh/loader/package/utils.py @@ -1,147 +1,160 @@ # Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import copy import functools +import itertools import logging import os from typing import Callable, Dict, Optional, Tuple, TypeVar +from urllib.request import urlopen import requests from swh.loader.exception import NotFound from swh.loader.package import DEFAULT_PARAMS from swh.model.hashutil import HASH_BLOCK_SIZE, MultiHash from swh.model.model import Person logger = logging.getLogger(__name__) DOWNLOAD_HASHES = set(["sha1", "sha256", "length"]) EMPTY_AUTHOR = Person(fullname=b"", name=None, email=None,) def api_info(url: str, **extra_params) -> bytes: """Basic api client to retrieve information on project. This deals with fetching json metadata about pypi projects. Args: url (str): The api url (e.g PyPI, npm, etc...) Raises: NotFound in case of query failures (for some reasons: 404, ...) Returns: The associated response's information """ response = requests.get(url, **{**DEFAULT_PARAMS, **extra_params}) if response.status_code != 200: raise NotFound(f"Fail to query '{url}'. Reason: {response.status_code}") return response.content def download( url: str, dest: str, hashes: Dict = {}, filename: Optional[str] = None, auth: Optional[Tuple[str, str]] = None, extra_request_headers: Optional[Dict[str, str]] = None, ) -> Tuple[str, Dict]: """Download a remote tarball from url, uncompresses and computes swh hashes on it. Args: url: Artifact uri to fetch, uncompress and hash dest: Directory to write the archive to hashes: Dict of expected hashes (key is the hash algo) for the artifact to download (those hashes are expected to be hex string) auth: Optional tuple of login/password (for http authentication service, e.g. deposit) Raises: ValueError in case of any error when fetching/computing (length, checksums mismatched...) Returns: Tuple of local (filepath, hashes of filepath) """ params = copy.deepcopy(DEFAULT_PARAMS) if auth is not None: params["auth"] = auth if extra_request_headers is not None: params["headers"].update(extra_request_headers) # so the connection does not hang indefinitely (read/connection timeout) timeout = params.get("timeout", 60) - response = requests.get(url, **params, timeout=timeout, stream=True) - if response.status_code != 200: - raise ValueError("Fail to query '%s'. Reason: %s" % (url, response.status_code)) + + if url.startswith("ftp://"): + response = urlopen(url, timeout=timeout) + chunks = (response.read(HASH_BLOCK_SIZE) for _ in itertools.count()) + response_data = itertools.takewhile(bool, chunks) + else: + response = requests.get(url, **params, timeout=timeout, stream=True) + if response.status_code != 200: + raise ValueError( + "Fail to query '%s'. Reason: %s" % (url, response.status_code) + ) + response_data = response.iter_content(chunk_size=HASH_BLOCK_SIZE) filename = filename if filename else os.path.basename(url) logger.debug("filename: %s", filename) filepath = os.path.join(dest, filename) logger.debug("filepath: %s", filepath) h = MultiHash(hash_names=DOWNLOAD_HASHES) with open(filepath, "wb") as f: - for chunk in response.iter_content(chunk_size=HASH_BLOCK_SIZE): + for chunk in response_data: h.update(chunk) f.write(chunk) + response.close() + # Also check the expected hashes if provided if hashes: actual_hashes = h.hexdigest() for algo_hash in hashes.keys(): actual_digest = actual_hashes[algo_hash] expected_digest = hashes[algo_hash] if actual_digest != expected_digest: raise ValueError( "Failure when fetching %s. " "Checksum mismatched: %s != %s" % (url, expected_digest, actual_digest) ) computed_hashes = h.hexdigest() length = computed_hashes.pop("length") extrinsic_metadata = { "length": length, "filename": filename, "checksums": computed_hashes, "url": url, } logger.debug("extrinsic_metadata", extrinsic_metadata) return filepath, extrinsic_metadata def release_name(version: str, filename: Optional[str] = None) -> str: if filename: return "releases/%s/%s" % (version, filename) return "releases/%s" % version TReturn = TypeVar("TReturn") TSelf = TypeVar("TSelf") _UNDEFINED = object() def cached_method(f: Callable[[TSelf], TReturn]) -> Callable[[TSelf], TReturn]: cache_name = f"_cached_{f.__name__}" @functools.wraps(f) def newf(self): value = getattr(self, cache_name, _UNDEFINED) if value is _UNDEFINED: value = f(self) setattr(self, cache_name, value) return value return newf