diff --git a/swh/scanner/cli.py b/swh/scanner/cli.py
--- a/swh/scanner/cli.py
+++ b/swh/scanner/cli.py
@@ -142,32 +142,50 @@
     type=click.Choice(["auto", "bfs", "greedybfs", "filepriority", "dirpriority"]),
     help="The scan policy.",
 )
+@click.option(
+    "-e",
+    "--extra-info",
+    "extra_info",
+    multiple=True,
+    type=click.Choice(["origin"]),
+    help="Add selected additional information about known software artifacts.",
+)
 @click.pass_context
-def scan(ctx, root_path, api_url, patterns, out_fmt, interactive, policy):
+def scan(ctx, root_path, api_url, patterns, out_fmt, interactive, policy, extra_info):
     """Scan a source code project to discover files and directories already
     present in the archive.

     The source code project can be checked using different policies that can be set
-    using the -p/--policy option:
-
-    auto: it selects the best policy based on the source code, for codebase(s) with
-    less than 1000 file/dir contents all the nodes will be queried.
-
-    bfs: scan the source code in the BFS order, checking unknown directories only.
-
-    greedybfs: same as "bfs" policy, but lookup the status of source code artifacts in
-    chunks, in order to minimize the number of Web API round-trips with the archive.
-
-    filepriority: scan all the source code file contents, checking only unset
-    directories. (useful if the codebase contains a lot of source files)
-
-    dirpriority: scan all the source code directories and check only unknown
-    directory contents.
-    """
+    using the -p/--policy option:\n
+    \b
+    auto: it selects the best policy based on the source code; for codebases
+    with fewer than 1000 file/dir contents, all the nodes will be queried.
+
+    bfs: scan the source code in the BFS order, checking unknown directories only.
+
+    \b
+    greedybfs: same as "bfs" policy, but look up the status of source code artifacts
+    in chunks, in order to minimize the number of Web API round-trips with the
+    archive.
+
+    \b
+    filepriority: scan all the source code file contents, checking only unset
+    directories. (useful if the codebase contains a lot of source files)
+
+    dirpriority: scan all the source code directories and check only unknown
+    directory contents.
+
+    Other information about software artifacts can be requested with the -e/
+    --extra-info option:\n
+    \b
+    origin: search the origin URL of each source code file/dir using the in-memory
+    compressed graph.
+    """
     import swh.scanner.scanner as scanner

     config = setup_config(ctx, api_url)
-    scanner.scan(config, root_path, patterns, out_fmt, interactive, policy)
+    extra_info = set(extra_info)
+    scanner.scan(config, root_path, patterns, out_fmt, interactive, policy, extra_info)


 @scanner.group("db", help="Manage local knowledge base for swh-scanner")
diff --git a/swh/scanner/client.py b/swh/scanner/client.py
new file mode 100644
--- /dev/null
+++ b/swh/scanner/client.py
@@ -0,0 +1,97 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""
+Minimal async web client for the Software Heritage Web API.
+
+This module could be removed when
+`T2635 <https://forge.softwareheritage.org/T2635>`_ is implemented.
+"""
+
+import asyncio
+import itertools
+from typing import Any, Dict, List, Optional
+
+import aiohttp
+
+from swh.model.identifiers import CoreSWHID
+
+from .exceptions import error_response
+
+# Maximum number of SWHIDs that can be requested by a single call to the
+# Web API endpoint /known/
+QUERY_LIMIT = 1000
+
+KNOWN_EP = "known/"
+GRAPH_RANDOMWALK_EP = "graph/randomwalk/"
+
+
+class Client:
+    """Manage requests to the Software Heritage Web API.
+    """
+
+    def __init__(self, api_url: str, session: aiohttp.ClientSession):
+        self.api_url = api_url
+        self.session = session
+
+    async def get_origin(self, swhid: CoreSWHID) -> Optional[Any]:
+        """Walk the compressed graph backward to discover the origin of a given SWHID.
+        """
+        endpoint = (
+            f"{self.api_url}{GRAPH_RANDOMWALK_EP}{str(swhid)}/ori/?direction="
+            f"backward&limit=-1&resolve_origins=true"
+        )
+        res = None
+        async with self.session.get(endpoint) as resp:
+            if resp.status == 200:
+                res = await resp.text()
+                res = res.rstrip()
+                return res
+            if resp.status != 404:
+                error_response(resp.reason, resp.status, endpoint)
+
+        return res
+
+    async def known(self, swhids: List[CoreSWHID]) -> Dict[str, Dict[str, bool]]:
+        """API request to get information about the SoftWare Heritage persistent
+        IDentifiers (SWHIDs) given in input.
+
+        Args:
+            swhids: a list of CoreSWHID instances
+
+        Returns:
+            A dictionary with:
+
+            key:
+                string SWHID searched
+            value:
+                value['known'] = True if the SWHID is found
+                value['known'] = False if the SWHID is not found
+
+        """
+        endpoint = self.api_url + KNOWN_EP
+        requests = []
+
+        def get_chunk(swhids):
+            for i in range(0, len(swhids), QUERY_LIMIT):
+                yield swhids[i : i + QUERY_LIMIT]
+
+        async def make_request(swhids):
+            swhids = [str(swhid) for swhid in swhids]
+            async with self.session.post(endpoint, json=swhids) as resp:
+                if resp.status != 200:
+                    error_response(resp.reason, resp.status, endpoint)
+
+                return await resp.json()
+
+        if len(swhids) > QUERY_LIMIT:
+            for swhids_chunk in get_chunk(swhids):
+                requests.append(asyncio.create_task(make_request(swhids_chunk)))
+
+            res = await asyncio.gather(*requests)
+            # concatenate the list of per-chunk dictionaries
+            return dict(itertools.chain.from_iterable(e.items() for e in res))
+        else:
+            return await make_request(swhids)
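
A minimal usage sketch of the new Client, assuming the public archive Web API base URL; the SWHID is one of the directory IDs used in the test data. Not part of the patch:

    import asyncio

    import aiohttp

    from swh.model.identifiers import CoreSWHID
    from swh.scanner.client import Client


    async def main():
        async with aiohttp.ClientSession() as session:
            # base URL is an assumption; any Web API instance works
            client = Client("https://archive.softwareheritage.org/api/1/", session)
            swhid = CoreSWHID.from_string(
                "swh:1:dir:01fa282bb80be5907505d44b4692d3fa40fad140"
            )
            # known() POSTs to /known/, chunking batches larger than QUERY_LIMIT
            print(await client.known([swhid]))
            # get_origin() walks the graph backward; returns None on a 404
            print(await client.get_origin(swhid))


    asyncio.run(main())
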
+""" + +import asyncio +import itertools +from typing import Any, Dict, List, Optional + +import aiohttp + +from swh.model.identifiers import CoreSWHID + +from .exceptions import error_response + +# Maximum number of SWHIDs that can be requested by a single call to the +# Web API endpoint /known/ +QUERY_LIMIT = 1000 + +KNOWN_EP = "known/" +GRAPH_RANDOMWALK_EP = "graph/randomwalk/" + + +class Client: + """Manage requests to the Software Heritage Web API. + """ + + def __init__(self, api_url: str, session: aiohttp.ClientSession): + self.api_url = api_url + self.session = session + + async def get_origin(self, swhid: CoreSWHID) -> Optional[Any]: + """Walk the compressed graph to discover the origin of a given swhid + """ + endpoint = ( + f"{self.api_url}{GRAPH_RANDOMWALK_EP}{str(swhid)}/ori/?direction=" + f"backward&limit=-1&resolve_origins=true" + ) + res = None + async with self.session.get(endpoint) as resp: + if resp.status == 200: + res = await resp.text() + res = res.rstrip() + return res + if resp.status != 404: + error_response(resp.reason, resp.status, endpoint) + + return res + + async def known(self, swhids: List[CoreSWHID]) -> Dict[str, Dict[str, bool]]: + """API Request to get information about the SoftWare Heritage persistent + IDentifiers (SWHIDs) given in input. + + Args: + swhids: a list of CoreSWHID instances + api_url: url for the API request + + Returns: + A dictionary with: + + key: + string SWHID searched + value: + value['known'] = True if the SWHID is found + value['known'] = False if the SWHID is not found + + """ + endpoint = self.api_url + KNOWN_EP + requests = [] + + def get_chunk(swhids): + for i in range(0, len(swhids), QUERY_LIMIT): + yield swhids[i : i + QUERY_LIMIT] + + async def make_request(swhids): + swhids = [str(swhid) for swhid in swhids] + async with self.session.post(endpoint, json=swhids) as resp: + if resp.status != 200: + error_response(resp.reason, resp.status, endpoint) + + return await resp.json() + + if len(swhids) > QUERY_LIMIT: + for swhids_chunk in get_chunk(swhids): + requests.append(asyncio.create_task(make_request(swhids_chunk))) + + res = await asyncio.gather(*requests) + # concatenate list of dictionaries + return dict(itertools.chain.from_iterable(e.items() for e in res)) + else: + return await make_request(swhids) diff --git a/swh/scanner/data.py b/swh/scanner/data.py --- a/swh/scanner/data.py +++ b/swh/scanner/data.py @@ -4,12 +4,16 @@ # See top-level LICENSE file for more information from pathlib import Path -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple from swh.model.exceptions import ValidationError from swh.model.from_disk import Directory from swh.model.identifiers import CONTENT, DIRECTORY, CoreSWHID +from .client import Client + +SUPPORTED_INFO = {"known", "origin"} + class MerkleNodeInfo(dict): """Store additional information about Merkle DAG nodes, using SWHIDs as keys""" @@ -27,6 +31,45 @@ super(MerkleNodeInfo, self).__setitem__(key, value) +def init_merkle_node_info(source_tree: Directory, data: MerkleNodeInfo, info: set): + """Populate the MerkleNodeInfo with the SWHIDs of the given source tree and the + attributes that will be stored. 
+ """ + if not info: + raise Exception("Data initialization requires node attributes values.") + nodes_info: Dict[str, Optional[str]] = {} + for ainfo in info: + if ainfo in SUPPORTED_INFO: + nodes_info[ainfo] = None + else: + raise Exception(f"Information {ainfo} is not supported.") + + for node in source_tree.iter_tree(): + data[node.swhid()] = nodes_info.copy() # type: ignore + + +async def add_origin(source_tree: Directory, data: MerkleNodeInfo, client: Client): + """Store origin information about software artifacts retrieved from the Software + Heritage graph service. + """ + queue = [] + queue.append(source_tree) + while queue: + for node in queue.copy(): + queue.remove(node) + node_ori = await client.get_origin(node.swhid()) + if node_ori: + data[node.swhid()]["origin"] = node_ori + if node.object_type == DIRECTORY: + for sub_node in node.iter_tree(): + data[sub_node.swhid()]["origin"] = node_ori # type: ignore + else: + if node.object_type == DIRECTORY: + children = [sub_node for sub_node in node.iter_tree()] + children.remove(node) + queue.extend(children) # type: ignore + + def get_directory_data( root_path: str, source_tree: Directory, diff --git a/swh/scanner/exceptions.py b/swh/scanner/exceptions.py --- a/swh/scanner/exceptions.py +++ b/swh/scanner/exceptions.py @@ -3,6 +3,8 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Any, Optional + class InvalidObjectType(TypeError): pass @@ -25,6 +27,6 @@ return '"%s"' % self.args -def error_response(reason: str, status_code: int, api_url: str): +def error_response(reason: Optional[Any], status_code: int, api_url: str): error_msg = f"{status_code} {reason}: '{api_url}'" raise APIError(error_msg) diff --git a/swh/scanner/output.py b/swh/scanner/output.py --- a/swh/scanner/output.py +++ b/swh/scanner/output.py @@ -93,12 +93,13 @@ def data_as_json(self): json = {} for node in self.source_tree.iter_tree(): - node_known = self.nodes_data[node.swhid()]["known"] rel_path = os.path.relpath( node.data[self.get_path_name(node)].decode(), self.source_tree.data["path"].decode(), ) - json[rel_path] = {"swhid": str(node.swhid()), "known": node_known} + json[rel_path] = {"swhid": str(node.swhid())} + for k, v in self.nodes_data[node.swhid()].items(): + json[rel_path][k] = v return json def print_json(self): diff --git a/swh/scanner/policy.py b/swh/scanner/policy.py --- a/swh/scanner/policy.py +++ b/swh/scanner/policy.py @@ -4,68 +4,14 @@ # See top-level LICENSE file for more information import abc -import asyncio -import itertools -from typing import Dict, List, no_type_check - -import aiohttp +from typing import no_type_check from swh.core.utils import grouper from swh.model.from_disk import Directory -from swh.model.identifiers import CONTENT, DIRECTORY, CoreSWHID +from swh.model.identifiers import CONTENT, DIRECTORY +from .client import QUERY_LIMIT, Client from .data import MerkleNodeInfo -from .exceptions import error_response - -# Maximum number of SWHIDs that can be requested by a single call to the -# Web API endpoint /known/ -QUERY_LIMIT = 1000 - - -async def swhids_discovery( - swhids: List[CoreSWHID], session: aiohttp.ClientSession, api_url: str, -) -> Dict[str, Dict[str, bool]]: - """API Request to get information about the SoftWare Heritage persistent - IDentifiers (SWHIDs) given in input. 
diff --git a/swh/scanner/policy.py b/swh/scanner/policy.py
--- a/swh/scanner/policy.py
+++ b/swh/scanner/policy.py
@@ -4,68 +4,14 @@
 # See top-level LICENSE file for more information

 import abc
-import asyncio
-import itertools
-from typing import Dict, List, no_type_check
-
-import aiohttp
+from typing import no_type_check

 from swh.core.utils import grouper
 from swh.model.from_disk import Directory
-from swh.model.identifiers import CONTENT, DIRECTORY, CoreSWHID
+from swh.model.identifiers import CONTENT, DIRECTORY

+from .client import QUERY_LIMIT, Client
 from .data import MerkleNodeInfo
-from .exceptions import error_response
-
-# Maximum number of SWHIDs that can be requested by a single call to the
-# Web API endpoint /known/
-QUERY_LIMIT = 1000
-
-
-async def swhids_discovery(
-    swhids: List[CoreSWHID], session: aiohttp.ClientSession, api_url: str,
-) -> Dict[str, Dict[str, bool]]:
-    """API Request to get information about the SoftWare Heritage persistent
-    IDentifiers (SWHIDs) given in input.
-
-    Args:
-        swhids: a list of CoreSWHID instances
-        api_url: url for the API request
-
-    Returns:
-        A dictionary with:
-
-        key:
-            string SWHID searched
-        value:
-            value['known'] = True if the SWHID is found
-            value['known'] = False if the SWHID is not found
-
-    """
-    endpoint = api_url + "known/"
-    requests = []
-
-    def get_chunk(swhids):
-        for i in range(0, len(swhids), QUERY_LIMIT):
-            yield swhids[i : i + QUERY_LIMIT]
-
-    async def make_request(swhids):
-        swhids = [str(swhid) for swhid in swhids]
-        async with session.post(endpoint, json=swhids) as resp:
-            if resp.status != 200:
-                error_response(resp.reason, resp.status, endpoint)
-
-            return await resp.json()
-
-    if len(swhids) > QUERY_LIMIT:
-        for swhids_chunk in get_chunk(swhids):
-            requests.append(asyncio.create_task(make_request(swhids_chunk)))
-
-        res = await asyncio.gather(*requests)
-        # concatenate list of dictionaries
-        return dict(itertools.chain.from_iterable(e.items() for e in res))
-    else:
-        return await make_request(swhids)


 def source_size(source_tree: Directory):
@@ -83,15 +29,11 @@
     """representation of a source code project directory in the merkle tree"""

     def __init__(self, source_tree: Directory, data: MerkleNodeInfo):
-        self.data = data
         self.source_tree = source_tree
-        for node in source_tree.iter_tree():
-            self.data[node.swhid()] = {"known": None}  # type: ignore
+        self.data = data

     @abc.abstractmethod
-    async def run(
-        self, session: aiohttp.ClientSession, api_url: str,
-    ):
+    async def run(self, client: Client):
         """Scan a source code project"""
         raise NotImplementedError("Must implement run method")

@@ -102,15 +44,13 @@
     contents to known.
     """

-    async def run(
-        self, session: aiohttp.ClientSession, api_url: str,
-    ):
+    async def run(self, client: Client):
         queue = []
         queue.append(self.source_tree)

         while queue:
             swhids = [node.swhid() for node in queue]
-            swhids_res = await swhids_discovery(swhids, session, api_url)
+            swhids_res = await client.known(swhids)
             for node in queue.copy():
                 queue.remove(node)
                 self.data[node.swhid()]["known"] = swhids_res[str(node.swhid())][
@@ -132,13 +72,11 @@
     downstream contents of known directories to known.
     """

-    async def run(
-        self, session: aiohttp.ClientSession, api_url: str,
-    ):
+    async def run(self, client: Client):
         ssize = source_size(self.source_tree)
         seen = []

-        async for nodes_chunk in self.get_nodes_chunks(session, api_url, ssize):
+        async for nodes_chunk in self.get_nodes_chunks(client, ssize):
             for node in nodes_chunk:
                 seen.append(node)
                 if len(seen) == ssize:
@@ -151,9 +89,7 @@
                         self.data[sub_node.swhid()]["known"] = True

     @no_type_check
-    async def get_nodes_chunks(
-        self, session: aiohttp.ClientSession, api_url: str, ssize: int
-    ):
+    async def get_nodes_chunks(self, client: Client, ssize: int):
         """Query chunks of QUERY_LIMIT nodes at once in order to fill the Web API
         rate limit. It query all the nodes in the case the source code contains
         less than QUERY_LIMIT nodes.
@@ -162,7 +98,7 @@
         for nodes_chunk in grouper(nodes, QUERY_LIMIT):
             nodes_chunk = [n for n in nodes_chunk]
             swhids = [node.swhid() for node in nodes_chunk]
-            swhids_res = await swhids_discovery(swhids, session, api_url)
+            swhids_res = await client.known(swhids)
             for node in nodes_chunk:
                 swhid = node.swhid()
                 self.data[swhid]["known"] = swhids_res[str(swhid)]["known"]
@@ -177,9 +113,7 @@
     """

     @no_type_check
-    async def run(
-        self, session: aiohttp.ClientSession, api_url: str,
-    ):
+    async def run(self, client: Client):
         # get all the files
         all_contents = list(
             filter(
@@ -190,7 +124,7 @@

         # query the backend to get all file contents status
         cnt_swhids = [node.swhid() for node in all_contents]
-        cnt_status_res = await swhids_discovery(cnt_swhids, session, api_url)
+        cnt_status_res = await client.known(cnt_swhids)
         # set all the file contents status
         for cnt in all_contents:
             self.data[cnt.swhid()]["known"] = cnt_status_res[str(cnt.swhid())]["known"]
@@ -215,7 +149,7 @@
         for dir_ in unset_dirs:
             if self.data[dir_.swhid()]["known"] is None:
                 # update directory status
-                dir_status = await swhids_discovery([dir_.swhid()], session, api_url)
+                dir_status = await client.known([dir_.swhid()])
                 dir_known = dir_status[str(dir_.swhid())]["known"]
                 self.data[dir_.swhid()]["known"] = dir_known
                 if dir_known:
@@ -239,9 +173,7 @@
     """

     @no_type_check
-    async def run(
-        self, session: aiohttp.ClientSession, api_url: str,
-    ):
+    async def run(self, client: Client):
         # get all directory contents that have at least one file content
         unknown_dirs = list(
             filter(
@@ -253,7 +185,7 @@

         for dir_ in unknown_dirs:
             if self.data[dir_.swhid()]["known"] is None:
-                dir_status = await swhids_discovery([dir_.swhid()], session, api_url)
+                dir_status = await client.known([dir_.swhid()])
                 dir_known = dir_status[str(dir_.swhid())]["known"]
                 self.data[dir_.swhid()]["known"] = dir_known
                 # set all the downstream file contents to known
@@ -277,7 +209,7 @@
             )
         )
         empty_dirs_swhids = [n.swhid() for n in empty_dirs]
-        empty_dir_status = await swhids_discovery(empty_dirs_swhids, session, api_url)
+        empty_dir_status = await client.known(empty_dirs_swhids)

         # update status of directories that have no file contents
         for dir_ in empty_dirs:
@@ -294,9 +226,7 @@
             )
         )
         unknown_cnts_swhids = [n.swhid() for n in unknown_cnts]
-        unknown_cnts_status = await swhids_discovery(
-            unknown_cnts_swhids, session, api_url
-        )
+        unknown_cnts_status = await client.known(unknown_cnts_swhids)

         for cnt in unknown_cnts:
             self.data[cnt.swhid()]["known"] = unknown_cnts_status[str(cnt.swhid())][
@@ -322,11 +252,9 @@
     """

     @no_type_check
-    async def run(
-        self, session: aiohttp.ClientSession, api_url: str,
-    ):
+    async def run(self, client: Client):
         all_nodes = [node for node in self.source_tree.iter_tree()]
         all_swhids = [node.swhid() for node in all_nodes]
-        swhids_res = await swhids_discovery(all_swhids, session, api_url)
+        swhids_res = await client.known(all_swhids)
         for node in all_nodes:
             self.data[node.swhid()]["known"] = swhids_res[str(node.swhid())]["known"]
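
Every policy now receives a Client instead of a raw (session, api_url) pair, and no longer initializes the node attributes itself. A sketch of driving a policy directly, mirroring what the updated tests do (URL and path are assumptions):

    import asyncio

    import aiohttp

    from swh.model.cli import model_of_dir
    from swh.scanner.client import Client
    from swh.scanner.data import MerkleNodeInfo, init_merkle_node_info
    from swh.scanner.policy import LazyBFS


    async def scan_known(root: bytes, api_url: str) -> MerkleNodeInfo:
        source_tree = model_of_dir(root, [])
        nodes_data = MerkleNodeInfo()
        # callers must initialize the attributes before running a policy
        init_merkle_node_info(source_tree, nodes_data, {"known"})
        policy = LazyBFS(source_tree, nodes_data)
        async with aiohttp.ClientSession() as session:
            await policy.run(Client(api_url, session))
        return nodes_data


    nodes_data = asyncio.run(
        scan_known(b"/tmp/my-project", "https://archive.softwareheritage.org/api/1/")
    )
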
diff --git a/swh/scanner/scanner.py b/swh/scanner/scanner.py
--- a/swh/scanner/scanner.py
+++ b/swh/scanner/scanner.py
@@ -11,7 +11,8 @@
 from swh.model.cli import model_of_dir
 from swh.model.from_disk import Directory

-from .data import MerkleNodeInfo
+from .client import Client
+from .data import MerkleNodeInfo, add_origin, init_merkle_node_info
 from .output import Output
 from .policy import (
     QUERY_LIMIT,
@@ -24,13 +25,14 @@
 )


-async def run(config: Dict[str, Any], policy) -> None:
+async def run(
+    config: Dict[str, Any],
+    policy,
+    source_tree: Directory,
+    nodes_data: MerkleNodeInfo,
+    extra_info: set,
+) -> None:
     """Scan a given source code according to the policy given in input.
-
-    Args:
-        root: the root path to scan
-        api_url: url for the API request
-
     """
     api_url = config["web-api"]["url"]

@@ -40,7 +42,14 @@
     headers = {}

     async with aiohttp.ClientSession(headers=headers, trust_env=True) as session:
-        await policy.run(session, api_url)
+        client = Client(api_url, session)
+        for info in extra_info:
+            if info == "known":
+                await policy.run(client)
+            elif info == "origin":
+                await add_origin(source_tree, nodes_data, client)
+            else:
+                raise Exception(f"The information '{info}' cannot be retrieved")


 def get_policy_obj(source_tree: Directory, nodes_data: MerkleNodeInfo, policy: str):
@@ -69,16 +78,21 @@
     out_fmt: str,
     interactive: bool,
     policy: str,
+    extra_info: set,
 ):
     """Scan a source code project to discover files and directories already
     present in the archive"""
     converted_patterns = [pattern.encode() for pattern in exclude_patterns]
     source_tree = model_of_dir(root_path.encode(), converted_patterns)
+
     nodes_data = MerkleNodeInfo()
+    extra_info.add("known")
+    init_merkle_node_info(source_tree, nodes_data, extra_info)
+
     policy = get_policy_obj(source_tree, nodes_data, policy)
     loop = asyncio.get_event_loop()
-    loop.run_until_complete(run(config, policy))
+    loop.run_until_complete(run(config, policy, source_tree, nodes_data, extra_info))

     out = Output(root_path, nodes_data, source_tree)
     if interactive:
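
End to end, the extra-info plumbing adds one argument through scan() and run(). A library-level call roughly equivalent to `swh scanner scan /tmp/my-project -e origin` (config values and the "text" output format are assumptions; scan() adds "known" to the set itself):

    from swh.scanner.scanner import scan

    config = {
        "web-api": {
            "url": "https://archive.softwareheritage.org/api/1/",
            "auth-token": None,
        }
    }
    scan(config, "/tmp/my-project", [], "text", False, "bfs", {"origin"})
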
diff --git a/swh/scanner/tests/data.py b/swh/scanner/tests/data.py
--- a/swh/scanner/tests/data.py
+++ b/swh/scanner/tests/data.py
@@ -3,12 +3,17 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

-correct_api_response = {
+correct_known_api_response = {
     "swh:1:dir:17d207da3804cc60a77cba58e76c3b2f767cb112": {"known": False},
     "swh:1:dir:01fa282bb80be5907505d44b4692d3fa40fad140": {"known": True},
     "swh:1:dir:4b825dc642cb6eb9a060e54bf8d69288fbee4904": {"known": True},
 }

+correct_origin_api_response = "https://bitbucket.org/chubbymaggie/bindead.git"
+
+sample_folder_root_swhid = "swh:1:dir:0a7b61ef5780b03aa274d11069564980246445ce"
+fake_origin = {sample_folder_root_swhid: correct_origin_api_response}
+
 present_swhids = [
     "swh:1:cnt:7c4c57ba9ff496ad179b8f65b1d286edbda34c9a",  # quotes.md
     "swh:1:cnt:68769579c3eaadbe555379b9c3538e6628bae1eb",  # some-binary
diff --git a/swh/scanner/tests/flask_api.py b/swh/scanner/tests/flask_api.py
--- a/swh/scanner/tests/flask_api.py
+++ b/swh/scanner/tests/flask_api.py
@@ -3,12 +3,12 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

-from flask import Flask, request
+from flask import Flask, abort, request

 from swh.scanner.exceptions import LargePayloadExc
 from swh.scanner.policy import QUERY_LIMIT

-from .data import unknown_swhids
+from .data import fake_origin, unknown_swhids


 def create_app(tmp_requests):
@@ -38,4 +38,11 @@

         return res

+    @app.route("/graph/randomwalk/<swhid>/ori/", methods=["GET"])
+    def randomwalk(swhid):
+        if swhid in fake_origin.keys():
+            return fake_origin[swhid]
+        else:
+            abort(404)
+
     return app
diff --git a/swh/scanner/tests/test_client.py b/swh/scanner/tests/test_client.py
new file mode 100644
--- /dev/null
+++ b/swh/scanner/tests/test_client.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import json
+
+import pytest
+
+from swh.model.identifiers import CoreSWHID
+from swh.scanner.client import Client
+from swh.scanner.exceptions import APIError
+
+from .data import correct_known_api_response, correct_origin_api_response
+
+AIO_URL = "http://example.org/api/"
+KNOWN_URL = f"{AIO_URL}known/"
+ORIGIN_URL = f"{AIO_URL}graph/randomwalk/"
+
+
+def test_client_known_correct_api_request(mock_aioresponse, event_loop, aiosession):
+    mock_aioresponse.post(
+        KNOWN_URL,
+        status=200,
+        content_type="application/json",
+        body=json.dumps(correct_known_api_response),
+    )
+
+    client = Client(AIO_URL, aiosession)
+    actual_result = event_loop.run_until_complete(client.known([]))
+
+    assert correct_known_api_response == actual_result
+
+
+def test_client_known_raise_apierror(mock_aioresponse, event_loop, aiosession):
+    mock_aioresponse.post(KNOWN_URL, content_type="application/json", status=413)
+
+    client = Client(AIO_URL, aiosession)
+    with pytest.raises(APIError):
+        event_loop.run_until_complete(client.known([]))
+
+
+def test_client_get_origin_correct_api_request(
+    mock_aioresponse, event_loop, aiosession
+):
+    origin_url = (
+        f"{ORIGIN_URL}swh:1:dir:01fa282bb80be5907505d44b4692d3fa40fad140/ori"
+        f"/?direction=backward&limit=-1&resolve_origins=true"
+    )
+    mock_aioresponse.get(
+        origin_url, status=200, body=correct_origin_api_response,
+    )
+
+    client = Client(AIO_URL, aiosession)
+    swhid = CoreSWHID.from_string("swh:1:dir:01fa282bb80be5907505d44b4692d3fa40fad140")
+    actual_result = event_loop.run_until_complete(client.get_origin(swhid))
+
+    assert correct_origin_api_response == actual_result
diff --git a/swh/scanner/tests/test_data.py b/swh/scanner/tests/test_data.py
--- a/swh/scanner/tests/test_data.py
+++ b/swh/scanner/tests/test_data.py
@@ -5,16 +5,22 @@

 from pathlib import Path

+from flask import url_for
 import pytest

 from swh.model.exceptions import ValidationError
+from swh.scanner.client import Client
 from swh.scanner.data import (
     MerkleNodeInfo,
+    add_origin,
     directory_content,
     get_directory_data,
     has_dirs,
+    init_merkle_node_info,
 )

+from .data import fake_origin
+

 def test_merkle_node_data_wrong_args():
     nodes_data = MerkleNodeInfo()
@@ -26,6 +32,29 @@
         nodes_data["swh:1:dir:17d207da3804cc60a77cba58e76c3b2f767cb112"] = "wrong value"


+def test_init_merkle_supported_node_info(source_tree):
+    nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(source_tree, nodes_data, {"known", "origin"})
+    for _, node_attrs in nodes_data.items():
+        assert "known" in node_attrs.keys() and "origin" in node_attrs.keys()
+
+
+def test_init_merkle_not_supported_node_info(source_tree):
+    nodes_data = MerkleNodeInfo()
+    with pytest.raises(Exception):
+        init_merkle_node_info(source_tree, nodes_data, {"unsupported_info"})
+
+
+def test_add_origin(event_loop, live_server, aiosession, source_tree, nodes_data):
+    api_url = url_for("index", _external=True)
+    init_merkle_node_info(source_tree, nodes_data, {"known", "origin"})
+    client = Client(api_url, aiosession)
+
+    event_loop.run_until_complete(add_origin(source_tree, nodes_data, client))
+    for node, attrs in nodes_data.items():
+        assert attrs["origin"] == fake_origin[str(source_tree.swhid())]
+
+
 def test_get_directory_data(source_tree, nodes_data):
     root = Path(source_tree.data["path"].decode())
     dirs_data = get_directory_data(root, source_tree, nodes_data)
diff --git a/swh/scanner/tests/test_policy.py b/swh/scanner/tests/test_policy.py
--- a/swh/scanner/tests/test_policy.py
+++ b/swh/scanner/tests/test_policy.py
@@ -3,51 +3,21 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

-import json

 from flask import url_for
 import pytest

 from swh.model.identifiers import CONTENT, CoreSWHID, ObjectType
-from swh.scanner.data import MerkleNodeInfo
-from swh.scanner.exceptions import APIError
+from swh.scanner.client import Client
+from swh.scanner.data import MerkleNodeInfo, init_merkle_node_info
 from swh.scanner.policy import (
     DirectoryPriority,
     FilePriority,
     GreedyBFS,
     LazyBFS,
     source_size,
-    swhids_discovery,
 )

-from .data import correct_api_response
-
-aio_url = "http://example.org/api/known/"
-
-
-def test_scanner_correct_api_request(mock_aioresponse, event_loop, aiosession):
-    mock_aioresponse.post(
-        aio_url,
-        status=200,
-        content_type="application/json",
-        body=json.dumps(correct_api_response),
-    )
-
-    actual_result = event_loop.run_until_complete(
-        swhids_discovery([], aiosession, "http://example.org/api/")
-    )
-
-    assert correct_api_response == actual_result
-
-
-def test_scanner_raise_apierror(mock_aioresponse, event_loop, aiosession):
-    mock_aioresponse.post(aio_url, content_type="application/json", status=413)
-
-    with pytest.raises(APIError):
-        event_loop.run_until_complete(
-            swhids_discovery([], aiosession, "http://example.org/api/")
-        )
-

 def test_scanner_directory_priority_has_contents(source_tree):
     nodes_data = MerkleNodeInfo()
@@ -69,8 +39,10 @@
     api_url = url_for("index", _external=True)

     nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(source_tree_policy, nodes_data, {"known"})
     policy = LazyBFS(source_tree_policy, nodes_data)
-    event_loop.run_until_complete(policy.run(aiosession, api_url))
+    client = Client(api_url, aiosession)
+    event_loop.run_until_complete(policy.run(client))

     backend_swhids_requests = get_backend_swhids_order(tmp_requests)

@@ -105,8 +77,10 @@
     api_url = url_for("index", _external=True)

     nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(source_tree_policy, nodes_data, {"known"})
     policy = DirectoryPriority(source_tree_policy, nodes_data)
-    event_loop.run_until_complete(policy.run(aiosession, api_url))
+    client = Client(api_url, aiosession)
+    event_loop.run_until_complete(policy.run(client))

     backend_swhids_requests = get_backend_swhids_order(tmp_requests)

@@ -124,8 +98,10 @@
     api_url = url_for("index", _external=True)

     nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(source_tree_policy, nodes_data, {"known"})
     policy = FilePriority(source_tree_policy, nodes_data)
-    event_loop.run_until_complete(policy.run(aiosession, api_url))
+    client = Client(api_url, aiosession)
+    event_loop.run_until_complete(policy.run(client))

     backend_swhids_requests = get_backend_swhids_order(tmp_requests)

@@ -143,8 +119,10 @@
     api_url = url_for("index", _external=True)

     nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(big_source_tree, nodes_data, {"known"})
     policy = GreedyBFS(big_source_tree, nodes_data)
-    event_loop.run_until_complete(policy.run(aiosession, api_url))
+    client = Client(api_url, aiosession)
+    event_loop.run_until_complete(policy.run(client))

     backend_swhids_requests = get_backend_swhids_order(tmp_requests)

@@ -157,11 +135,13 @@
     api_url = url_for("index", _external=True)

     nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(big_source_tree, nodes_data, {"known"})
     policy = GreedyBFS(big_source_tree, nodes_data)
+    client = Client(api_url, aiosession)
     chunks = [
         n_chunk
         async for n_chunk in policy.get_nodes_chunks(
-            aiosession, api_url, source_size(big_source_tree)
+            client, source_size(big_source_tree)
         )
     ]
     assert len(chunks) == 2
diff --git a/swh/scanner/tests/test_scanner.py b/swh/scanner/tests/test_scanner.py
--- a/swh/scanner/tests/test_scanner.py
+++ b/swh/scanner/tests/test_scanner.py
@@ -6,7 +6,7 @@
 from flask import url_for
 import pytest

-from swh.scanner.data import MerkleNodeInfo
+from swh.scanner.data import MerkleNodeInfo, init_merkle_node_info
 from swh.scanner.policy import DirectoryPriority, FilePriority, LazyBFS, QueryAll
 from swh.scanner.scanner import get_policy_obj, run

@@ -33,8 +33,11 @@
     config = {"web-api": {"url": api_url, "auth-token": None}}

     nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(source_tree, nodes_data, {"known"})
     policy = LazyBFS(source_tree, nodes_data)
-    event_loop.run_until_complete(run(config, policy))
+    event_loop.run_until_complete(
+        run(config, policy, source_tree, nodes_data, {"known"})
+    )
     for node in source_tree.iter_tree():
         if str(node.swhid()) in unknown_swhids:
             assert nodes_data[node.swhid()]["known"] is False
@@ -47,8 +50,11 @@
     config = {"web-api": {"url": api_url, "auth-token": None}}

     nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(source_tree, nodes_data, {"known"})
     policy = FilePriority(source_tree, nodes_data)
-    event_loop.run_until_complete(run(config, policy))
+    event_loop.run_until_complete(
+        run(config, policy, source_tree, nodes_data, {"known"})
+    )
     for node in source_tree.iter_tree():
         if str(node.swhid()) in unknown_swhids:
             assert nodes_data[node.swhid()]["known"] is False
@@ -61,8 +67,11 @@
     config = {"web-api": {"url": api_url, "auth-token": None}}

     nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(source_tree, nodes_data, {"known"})
     policy = DirectoryPriority(source_tree, nodes_data)
-    event_loop.run_until_complete(run(config, policy))
+    event_loop.run_until_complete(
+        run(config, policy, source_tree, nodes_data, {"known"})
+    )
     for node in source_tree.iter_tree():
         if str(node.swhid()) in unknown_swhids:
             assert nodes_data[node.swhid()]["known"] is False
@@ -75,8 +84,11 @@
     config = {"web-api": {"url": api_url, "auth-token": None}}

     nodes_data = MerkleNodeInfo()
+    init_merkle_node_info(source_tree, nodes_data, {"known"})
     policy = QueryAll(source_tree, nodes_data)
-    event_loop.run_until_complete(run(config, policy))
+    event_loop.run_until_complete(
+        run(config, policy, source_tree, nodes_data, {"known"})
+    )
     for node in source_tree.iter_tree():
         if str(node.swhid()) in unknown_swhids:
             assert nodes_data[node.swhid()]["known"] is False