diff --git a/requirements-test.txt b/requirements-test.txt
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -5,3 +5,4 @@
 swh.core[testing-core]
 swh.model[testing]
 swh.storage[testing]
+swh.web[testing]
diff --git a/swh/scanner/exceptions.py b/swh/scanner/exceptions.py
--- a/swh/scanner/exceptions.py
+++ b/swh/scanner/exceptions.py
@@ -6,4 +6,21 @@
 
 class APIError(Exception):
     def __str__(self):
-        return 'API Error: "%s"' % self.args
+        # Join the args explicitly: '%s' % self.args raises TypeError unless
+        # the exception was built with exactly one argument.
+        return '"%s"' % ', '.join(str(arg) for arg in self.args)
+
+
+def error_response(reason: str, status_code: int, api_url: str):
+    """Raise an APIError describing a failed request.
+
+    Args:
+        reason: reason phrase of the HTTP response
+        status_code: status code of the HTTP response
+        api_url: URL of the endpoint that was requested
+
+    Raises:
+        APIError: always, with message "<status_code> <reason>: '<api_url>'"
+    """
+    error_msg = f'{status_code} {reason}: \'{api_url}\''
+    raise APIError(error_msg)
diff --git a/swh/scanner/scanner.py b/swh/scanner/scanner.py
--- a/swh/scanner/scanner.py
+++ b/swh/scanner/scanner.py
@@ -10,7 +10,7 @@
 from typing import List, Dict, Tuple, Iterator
 from pathlib import PosixPath
 
-from .exceptions import APIError
+from .exceptions import error_response
 from .model import Tree
 
 from swh.model.cli import pid_of_file, pid_of_dir
@@ -49,9 +49,7 @@
     async def make_request(pids):
         async with session.post(endpoint, json=pids) as resp:
             if resp.status != 200:
-                error_message = '%s with given values %s' % (
-                    resp.text, str(pids))
-                raise APIError(error_message)
+                error_response(resp.reason, resp.status, endpoint)
 
             return await resp.json()
 
diff --git a/swh/scanner/tests/conftest.py b/swh/scanner/tests/conftest.py
--- a/swh/scanner/tests/conftest.py
+++ b/swh/scanner/tests/conftest.py
@@ -60,6 +60,13 @@
     return app
 
 
+@pytest.fixture(scope='session')
+def live_server(live_server):
+    """Start the live server once and reuse it for the whole session."""
+    live_server.start()
+    return live_server
+
+
 @pytest.fixture
 def test_folder():
     tests_path = PosixPath(os.path.abspath(__file__)).parent
diff --git a/swh/scanner/tests/flask_api.py b/swh/scanner/tests/flask_api.py
--- a/swh/scanner/tests/flask_api.py
+++ b/swh/scanner/tests/flask_api.py
@@ -7,6 +7,8 @@
 
 from .data import present_pids
 
+from swh.web.common.exc import LargePayloadExc
+
 
 def create_app():
     app = Flask(__name__)
@@ -14,6 +16,11 @@
     @app.route('/known/', methods=['POST'])
     def known():
         pids = request.get_json()
+
+        if len(pids) > 900:
+            raise LargePayloadExc('The maximum number of PIDs this endpoint '
+                                  'can receive is 900')
+
         res = {pid: {'known': False} for pid in pids}
         for pid in pids:
             if pid in present_pids:
diff --git a/swh/scanner/tests/test_scanner.py b/swh/scanner/tests/test_scanner.py
--- a/swh/scanner/tests/test_scanner.py
+++ b/swh/scanner/tests/test_scanner.py
@@ -35,6 +35,18 @@
             pids_discovery([], aiosession, 'http://example.org/api/'))
 
 
+def test_scanner_raise_apierror_input_size_limit(
+        event_loop, aiosession, live_server):
+    """More than 900 PIDs must make pids_discovery raise APIError."""
+    api_url = live_server.url() + '/'
+    request = ["swh:1:cnt:7c4c57ba9ff496ad179b8f65b1d286edbda34c9a"
+               for _ in range(901)]  # /known/ is limited to 900 PIDs
+
+    with pytest.raises(APIError):
+        event_loop.run_until_complete(
+            pids_discovery(request, aiosession, api_url))
+
+
 def test_scanner_get_subpaths(tmp_path, temp_paths):
     for subpath, pid in get_subpaths(tmp_path):
         assert subpath in temp_paths['paths']
@@ -47,7 +59,6 @@
 
 
 def test_scanner_result(live_server, event_loop, test_folder):
-    live_server.start()
     api_url = live_server.url() + '/'
 
     result_path = test_folder.joinpath(PosixPath('sample-folder-result.json'))