diff --git a/swh/loader/package/tests/test_utils.py b/swh/loader/package/tests/test_utils.py index a61194d..bd7df6c 100644 --- a/swh/loader/package/tests/test_utils.py +++ b/swh/loader/package/tests/test_utils.py @@ -1,227 +1,238 @@ # Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json import os from unittest.mock import MagicMock from urllib.error import URLError from urllib.parse import quote import pytest from swh.loader.exception import NotFound import swh.loader.package from swh.loader.package.utils import api_info, download, release_name def test_version_generation(): assert ( swh.loader.package.__version__ != "devel" ), "Make sure swh.loader.core is installed (e.g. pip install -e .)" @pytest.mark.fs def test_download_fail_to_download(tmp_path, requests_mock): url = "https://pypi.org/pypi/arrow/json" status_code = 404 requests_mock.get(url, status_code=status_code) with pytest.raises(ValueError) as e: download(url, tmp_path) assert e.value.args[0] == "Fail to query '%s'. Reason: %s" % (url, status_code) _filename = "requests-0.0.1.tar.gz" _data = "this is something" def _check_download_ok(url, dest, filename=_filename, hashes=None): actual_filepath, actual_hashes = download(url, dest, hashes=hashes) actual_filename = os.path.basename(actual_filepath) assert actual_filename == filename assert actual_hashes["length"] == len(_data) assert ( actual_hashes["checksums"]["sha1"] == "fdd1ce606a904b08c816ba84f3125f2af44d92b2" ) assert ( actual_hashes["checksums"]["sha256"] == "1d9224378d77925d612c9f926eb9fb92850e6551def8328011b6a972323298d5" ) @pytest.mark.fs def test_download_ok(tmp_path, requests_mock): """Download without issue should provide filename and hashes""" url = f"https://pypi.org/pypi/requests/{_filename}" requests_mock.get(url, text=_data, headers={"content-length": str(len(_data))}) _check_download_ok(url, dest=str(tmp_path)) @pytest.mark.fs def test_download_ok_no_header(tmp_path, requests_mock): """Download without issue should provide filename and hashes""" url = f"https://pypi.org/pypi/requests/{_filename}" requests_mock.get(url, text=_data) # no header information _check_download_ok(url, dest=str(tmp_path)) @pytest.mark.fs def test_download_ok_with_hashes(tmp_path, requests_mock): """Download without issue should provide filename and hashes""" url = f"https://pypi.org/pypi/requests/{_filename}" requests_mock.get(url, text=_data, headers={"content-length": str(len(_data))}) # good hashes for such file good = { "sha1": "fdd1ce606a904b08c816ba84f3125f2af44d92b2", "sha256": "1d9224378d77925d612c9f926eb9fb92850e6551def8328011b6a972323298d5", # noqa } _check_download_ok(url, dest=str(tmp_path), hashes=good) @pytest.mark.fs def test_download_fail_hashes_mismatch(tmp_path, requests_mock): """Mismatch hash after download should raise """ url = f"https://pypi.org/pypi/requests/{_filename}" requests_mock.get(url, text=_data, headers={"content-length": str(len(_data))}) # good hashes for such file good = { "sha1": "fdd1ce606a904b08c816ba84f3125f2af44d92b2", "sha256": "1d9224378d77925d612c9f926eb9fb92850e6551def8328011b6a972323298d5", # noqa } for hash_algo in good.keys(): wrong_hash = good[hash_algo].replace("1", "0") expected_hashes = good.copy() expected_hashes[hash_algo] = wrong_hash # set the wrong hash expected_msg = "Failure when fetching %s. " "Checksum mismatched: %s != %s" % ( url, wrong_hash, good[hash_algo], ) with pytest.raises(ValueError, match=expected_msg): download(url, dest=str(tmp_path), hashes=expected_hashes) @pytest.mark.fs def test_ftp_download_ok(tmp_path, mocker): """Download without issue should provide filename and hashes""" url = f"ftp://pypi.org/pypi/requests/{_filename}" cm = MagicMock() cm.getstatus.return_value = 200 cm.read.side_effect = [_data.encode(), b""] cm.__enter__.return_value = cm mocker.patch("swh.loader.package.utils.urlopen").return_value = cm _check_download_ok(url, dest=str(tmp_path)) @pytest.mark.fs def test_ftp_download_ko(tmp_path, mocker): """Download without issue should provide filename and hashes""" filename = "requests-0.0.1.tar.gz" url = "ftp://pypi.org/pypi/requests/%s" % filename mocker.patch("swh.loader.package.utils.urlopen").side_effect = URLError("FTP error") with pytest.raises(URLError): download(url, dest=str(tmp_path)) @pytest.mark.fs def test_download_with_redirection(tmp_path, requests_mock): """Download with redirection should use the targeted URL to extract filename""" url = "https://example.org/project/requests/download" redirection_url = f"https://example.org/project/requests/files/{_filename}" requests_mock.get(url, status_code=302, headers={"location": redirection_url}) requests_mock.get( redirection_url, text=_data, headers={"content-length": str(len(_data))} ) _check_download_ok(url, dest=str(tmp_path)) +def test_download_extracting_filename_from_url(tmp_path, requests_mock): + """Extracting filename from url must sanitize the filename first""" + url = "https://example.org/project/requests-0.0.1.tar.gz?a=b&c=d&foo=bar" + + requests_mock.get( + url, status_code=200, text=_data, headers={"content-length": str(len(_data))} + ) + + _check_download_ok(url, dest=str(tmp_path)) + + @pytest.mark.fs @pytest.mark.parametrize( "filename", [f'"{_filename}"', _filename, '"filename with spaces.tar.gz"'] ) def test_download_filename_from_content_disposition(tmp_path, requests_mock, filename): """Filename should be extracted from content-disposition request header when available.""" url = "https://example.org/download/requests/tar.gz/v0.0.1" requests_mock.get( url, text=_data, headers={ "content-length": str(len(_data)), "content-disposition": f"attachment; filename={filename}", }, ) _check_download_ok(url, dest=str(tmp_path), filename=filename.strip('"')) @pytest.mark.fs @pytest.mark.parametrize("filename", ['"archive école.tar.gz"', "archive_école.tgz"]) def test_download_utf8_filename_from_content_disposition( tmp_path, requests_mock, filename ): """Filename should be extracted from content-disposition request header when available.""" url = "https://example.org/download/requests/tar.gz/v0.0.1" data = "this is something" requests_mock.get( url, text=data, headers={ "content-length": str(len(data)), "content-disposition": f"attachment; filename*=utf-8''{quote(filename)}", }, ) _check_download_ok(url, dest=str(tmp_path), filename=filename.strip('"')) def test_api_info_failure(requests_mock): """Failure to fetch info/release information should raise""" url = "https://pypi.org/pypi/requests/json" status_code = 400 requests_mock.get(url, status_code=status_code) with pytest.raises(NotFound) as e0: api_info(url) assert e0.value.args[0] == "Fail to query '%s'. Reason: %s" % (url, status_code) def test_api_info(requests_mock): """Fetching json info from pypi project should be ok""" url = "https://pypi.org/pypi/requests/json" requests_mock.get(url, text='{"version": "0.0.1"}') actual_info = json.loads(api_info(url)) assert actual_info == { "version": "0.0.1", } def test_release_name(): for version, filename, expected_release in [ ("0.0.1", None, "releases/0.0.1"), ("0.0.2", "something", "releases/0.0.2/something"), ]: assert release_name(version, filename) == expected_release diff --git a/swh/loader/package/utils.py b/swh/loader/package/utils.py index 841b7bc..4757581 100644 --- a/swh/loader/package/utils.py +++ b/swh/loader/package/utils.py @@ -1,184 +1,185 @@ # Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import copy import functools import itertools import logging import os import re from typing import Callable, Dict, Optional, Tuple, TypeVar -from urllib.parse import unquote +from urllib.parse import unquote, urlsplit from urllib.request import urlopen import requests from swh.loader.exception import NotFound from swh.loader.package import DEFAULT_PARAMS from swh.model.hashutil import HASH_BLOCK_SIZE, MultiHash from swh.model.model import Person logger = logging.getLogger(__name__) DOWNLOAD_HASHES = set(["sha1", "sha256", "length"]) EMPTY_AUTHOR = Person(fullname=b"", name=None, email=None,) def api_info(url: str, **extra_params) -> bytes: """Basic api client to retrieve information on project. This deals with fetching json metadata about pypi projects. Args: url (str): The api url (e.g PyPI, npm, etc...) Raises: NotFound in case of query failures (for some reasons: 404, ...) Returns: The associated response's information """ response = requests.get(url, **{**DEFAULT_PARAMS, **extra_params}) if response.status_code != 200: raise NotFound(f"Fail to query '{url}'. Reason: {response.status_code}") return response.content def _content_disposition_filename(header: str) -> Optional[str]: fname = None fnames = re.findall(r"filename[\*]?=([^;]+)", header) if fnames and "utf-8''" in fnames[0].lower(): # RFC 5987 fname = re.sub("utf-8''", "", fnames[0], flags=re.IGNORECASE) fname = unquote(fname) elif fnames: fname = fnames[0] if fname: fname = os.path.basename(fname.strip().strip('"')) return fname def download( url: str, dest: str, hashes: Dict = {}, filename: Optional[str] = None, auth: Optional[Tuple[str, str]] = None, extra_request_headers: Optional[Dict[str, str]] = None, ) -> Tuple[str, Dict]: """Download a remote tarball from url, uncompresses and computes swh hashes on it. Args: url: Artifact uri to fetch, uncompress and hash dest: Directory to write the archive to hashes: Dict of expected hashes (key is the hash algo) for the artifact to download (those hashes are expected to be hex string) auth: Optional tuple of login/password (for http authentication service, e.g. deposit) Raises: ValueError in case of any error when fetching/computing (length, checksums mismatched...) Returns: Tuple of local (filepath, hashes of filepath) """ params = copy.deepcopy(DEFAULT_PARAMS) if auth is not None: params["auth"] = auth if extra_request_headers is not None: params["headers"].update(extra_request_headers) # so the connection does not hang indefinitely (read/connection timeout) timeout = params.get("timeout", 60) if url.startswith("ftp://"): response = urlopen(url, timeout=timeout) chunks = (response.read(HASH_BLOCK_SIZE) for _ in itertools.count()) response_data = itertools.takewhile(bool, chunks) else: response = requests.get(url, **params, timeout=timeout, stream=True) if response.status_code != 200: raise ValueError( "Fail to query '%s'. Reason: %s" % (url, response.status_code) ) # update URL to response one as requests follow redirection by default # on GET requests url = response.url # try to extract filename from content-disposition header if available if filename is None and "content-disposition" in response.headers: filename = _content_disposition_filename( response.headers["content-disposition"] ) response_data = response.iter_content(chunk_size=HASH_BLOCK_SIZE) - filename = filename if filename else os.path.basename(url) + filename = filename if filename else os.path.basename(urlsplit(url).path) + logger.debug("filename: %s", filename) filepath = os.path.join(dest, filename) logger.debug("filepath: %s", filepath) h = MultiHash(hash_names=DOWNLOAD_HASHES) with open(filepath, "wb") as f: for chunk in response_data: h.update(chunk) f.write(chunk) response.close() # Also check the expected hashes if provided if hashes: actual_hashes = h.hexdigest() for algo_hash in hashes.keys(): actual_digest = actual_hashes[algo_hash] expected_digest = hashes[algo_hash] if actual_digest != expected_digest: raise ValueError( "Failure when fetching %s. " "Checksum mismatched: %s != %s" % (url, expected_digest, actual_digest) ) computed_hashes = h.hexdigest() length = computed_hashes.pop("length") extrinsic_metadata = { "length": length, "filename": filename, "checksums": computed_hashes, "url": url, } logger.debug("extrinsic_metadata", extrinsic_metadata) return filepath, extrinsic_metadata def release_name(version: str, filename: Optional[str] = None) -> str: if filename: return "releases/%s/%s" % (version, filename) return "releases/%s" % version TReturn = TypeVar("TReturn") TSelf = TypeVar("TSelf") _UNDEFINED = object() def cached_method(f: Callable[[TSelf], TReturn]) -> Callable[[TSelf], TReturn]: cache_name = f"_cached_{f.__name__}" @functools.wraps(f) def newf(self): value = getattr(self, cache_name, _UNDEFINED) if value is _UNDEFINED: value = f(self) setattr(self, cache_name, value) return value return newf