diff --git a/requirements-test.txt b/requirements-test.txt index 5c83e89..04a3b9a 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,7 +1,8 @@ pytest aioresponses pytest_asyncio pytest_flask swh.core[testing-core] swh.model[testing] swh.storage[testing] +swh.web[testing] diff --git a/swh/scanner/exceptions.py b/swh/scanner/exceptions.py index 952cad7..ca5b83e 100644 --- a/swh/scanner/exceptions.py +++ b/swh/scanner/exceptions.py @@ -1,9 +1,14 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information class APIError(Exception): def __str__(self): - return 'API Error: "%s"' % self.args + return '"%s"' % self.args + + +def error_response(reason: str, status_code: int, api_url: str): + error_msg = f'{status_code} {reason}: \'{api_url}\'' + raise APIError(error_msg) diff --git a/swh/scanner/scanner.py b/swh/scanner/scanner.py index e4ac2d8..81034f4 100644 --- a/swh/scanner/scanner.py +++ b/swh/scanner/scanner.py @@ -1,145 +1,143 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import itertools import asyncio import aiohttp from typing import List, Dict, Tuple, Iterator from pathlib import PosixPath -from .exceptions import APIError +from .exceptions import error_response from .model import Tree from swh.model.cli import pid_of_file, pid_of_dir from swh.model.identifiers import ( parse_persistent_identifier, DIRECTORY, CONTENT ) async def pids_discovery( pids: List[str], session: aiohttp.ClientSession, api_url: str, ) -> Dict[str, Dict[str, bool]]: """API Request to get information about the persistent identifiers given in input. Args: pids: a list of persistent identifier api_url: url for the API request Returns: A dictionary with: key: persistent identifier searched value: value['known'] = True if the pid is found value['known'] = False if the pid is not found """ endpoint = api_url + 'known/' chunk_size = 1000 requests = [] def get_chunk(pids): for i in range(0, len(pids), chunk_size): yield pids[i:i + chunk_size] async def make_request(pids): async with session.post(endpoint, json=pids) as resp: if resp.status != 200: - error_message = '%s with given values %s' % ( - resp.text, str(pids)) - raise APIError(error_message) + error_response(resp.reason, resp.status, endpoint) return await resp.json() if len(pids) > chunk_size: for pids_chunk in get_chunk(pids): requests.append(asyncio.create_task( make_request(pids_chunk))) res = await asyncio.gather(*requests) # concatenate list of dictionaries return dict(itertools.chain.from_iterable(e.items() for e in res)) else: return await make_request(pids) def get_subpaths( path: PosixPath) -> Iterator[Tuple[PosixPath, str]]: """Find the persistent identifier of the directories and files under a given path. Args: path: the root path Yields: pairs of: path, the relative persistent identifier """ def pid_of(path): if path.is_dir(): return pid_of_dir(bytes(path)) elif path.is_file() or path.is_symlink(): return pid_of_file(bytes(path)) dirpath, dnames, fnames = next(os.walk(path)) for node in itertools.chain(dnames, fnames): sub_path = PosixPath(dirpath).joinpath(node) yield (sub_path, pid_of(sub_path)) async def parse_path( path: PosixPath, session: aiohttp.ClientSession, api_url: str ) -> Iterator[Tuple[str, str, bool]]: """Check if the sub paths of the given path are present in the archive or not. Args: path: the source path api_url: url for the API request Returns: a map containing tuples with: a subpath of the given path, the pid of the subpath and the result of the api call """ parsed_paths = dict(get_subpaths(path)) parsed_pids = await pids_discovery( list(parsed_paths.values()), session, api_url) def unpack(tup): subpath, pid = tup return (subpath, pid, parsed_pids[pid]['known']) return map(unpack, parsed_paths.items()) async def run( root: PosixPath, api_url: str, source_tree: Tree) -> None: """Start scanning from the given root. It fills the source tree with the path discovered. Args: root: the root path to scan api_url: url for the API request """ async def _scan(root, session, api_url, source_tree): for path, pid, found in await parse_path(root, session, api_url): obj_type = parse_persistent_identifier(pid).object_type if obj_type == CONTENT: source_tree.addNode(path, pid if found else None) elif obj_type == DIRECTORY: if found: source_tree.addNode(path, pid) else: source_tree.addNode(path) await _scan(path, session, api_url, source_tree) async with aiohttp.ClientSession() as session: await _scan(root, session, api_url, source_tree) diff --git a/swh/scanner/tests/flask_api.py b/swh/scanner/tests/flask_api.py index 5dd637f..b24211b 100644 --- a/swh/scanner/tests/flask_api.py +++ b/swh/scanner/tests/flask_api.py @@ -1,24 +1,31 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from flask import Flask, request from .data import present_pids +from swh.web.common.exc import LargePayloadExc + def create_app(): app = Flask(__name__) @app.route('/known/', methods=['POST']) def known(): pids = request.get_json() + + if len(pids) > 900: + raise LargePayloadExc('The maximum number of PIDs this endpoint ' + 'can receive is 900') + res = {pid: {'known': False} for pid in pids} for pid in pids: if pid in present_pids: res[pid]['known'] = True return res return app diff --git a/swh/scanner/tests/test_scanner.py b/swh/scanner/tests/test_scanner.py index 4d3a089..c1b609b 100644 --- a/swh/scanner/tests/test_scanner.py +++ b/swh/scanner/tests/test_scanner.py @@ -1,65 +1,76 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest import json from pathlib import PosixPath from .data import correct_api_response from swh.scanner.scanner import pids_discovery, get_subpaths, run from swh.scanner.model import Tree from swh.scanner.exceptions import APIError aio_url = 'http://example.org/api/known/' def test_scanner_correct_api_request(mock_aioresponse, event_loop, aiosession): mock_aioresponse.post(aio_url, status=200, content_type='application/json', body=json.dumps(correct_api_response)) actual_result = event_loop.run_until_complete( pids_discovery([], aiosession, 'http://example.org/api/')) assert correct_api_response == actual_result def test_scanner_raise_apierror(mock_aioresponse, event_loop, aiosession): mock_aioresponse.post(aio_url, content_type='application/json', status=413) with pytest.raises(APIError): event_loop.run_until_complete( pids_discovery([], aiosession, 'http://example.org/api/')) +def test_scanner_raise_apierror_input_size_limit( + event_loop, aiosession, live_server): + + api_url = live_server.url() + '/' + request = ["swh:1:cnt:7c4c57ba9ff496ad179b8f65b1d286edbda34c9a" + for i in range(901)] # /known/ is limited at 900 + + with pytest.raises(APIError): + event_loop.run_until_complete( + pids_discovery(request, aiosession, api_url)) + + def test_scanner_get_subpaths(tmp_path, temp_paths): for subpath, pid in get_subpaths(tmp_path): assert subpath in temp_paths['paths'] assert pid in temp_paths['pids'] @pytest.mark.options(debug=False) def test_app(app): assert not app.debug def test_scanner_result(live_server, event_loop, test_folder): - live_server.start() api_url = live_server.url() + '/' result_path = test_folder.joinpath(PosixPath('sample-folder-result.json')) with open(result_path, 'r') as json_file: expected_result = json.loads(json_file.read()) sample_folder = test_folder.joinpath(PosixPath('sample-folder')) source_tree = Tree(sample_folder) event_loop.run_until_complete( run(sample_folder, api_url, source_tree)) actual_result = source_tree.getJsonTree() assert actual_result == expected_result