diff --git a/swh/scanner/cli.py b/swh/scanner/cli.py --- a/swh/scanner/cli.py +++ b/swh/scanner/cli.py @@ -5,18 +5,29 @@ # WARNING: do not import unnecessary things here to keep cli startup time under # control +import os +from typing import Any, Dict + import click from pathlib import PosixPath from typing import Tuple +from swh.core import config from swh.core.cli import CONTEXT_SETTINGS -@click.group(name="scanner", context_settings=CONTEXT_SETTINGS) -@click.pass_context -def scanner(ctx): - """Software Heritage Scanner tools.""" - pass +# All generic config code should reside in swh.core.config +DEFAULT_CONFIG_PATH = os.environ.get( + "SWH_CONFIG_FILE", os.path.join(click.get_app_dir("swh"), "global.yml") +) + + +DEFAULT_CONFIG: Dict[str, Any] = { + "web-api": { + "url": "https://archive.softwareheritage.org/api/1/", + "auth-token": None, + } +} def parse_url(url): @@ -53,12 +64,32 @@ yield re.compile(regex) +@click.group(name="scanner", context_settings=CONTEXT_SETTINGS) +@click.option( + "-C", + "--config-file", + default=DEFAULT_CONFIG_PATH, + type=click.Path(exists=True, dir_okay=False, path_type=str), + help="YAML configuration file", +) +@click.pass_context +def scanner(ctx, config_file: str): + """Software Heritage Scanner tools.""" + + # recursive merge not done by config.read + conf = config.read_raw_config(config.config_basepath(config_file)) + conf = config.merge_configs(DEFAULT_CONFIG, conf) + + ctx.ensure_object(dict) + ctx.obj["config"] = conf + + @scanner.command(name="scan") @click.argument("root_path", required=True, type=click.Path(exists=True)) @click.option( "-u", "--api-url", - default="https://archive.softwareheritage.org/api/1", + default=None, metavar="API_URL", show_default=True, help="URL for the api request", @@ -93,16 +124,19 @@ from .plot import generate_sunburst from .dashboard.dashboard import run_app + config = ctx.obj["config"] + if api_url: + config["web-api"]["url"] = parse_url(api_url) + sre_patterns = set() if patterns: sre_patterns = { reg_obj for reg_obj in extract_regex_objs(PosixPath(root_path), patterns) } - api_url = parse_url(api_url) source_tree = Tree(PosixPath(root_path)) loop = asyncio.get_event_loop() - loop.run_until_complete(run(root_path, api_url, source_tree, sre_patterns)) + loop.run_until_complete(run(config, root_path, source_tree, sre_patterns)) if interactive: root = PosixPath(root_path) @@ -113,5 +147,9 @@ source_tree.show(format) +def main(): + return scanner(auto_envvar_prefix="SWH_SCANNER") + + if __name__ == "__main__": - scan() + main() diff --git a/swh/scanner/scanner.py b/swh/scanner/scanner.py --- a/swh/scanner/scanner.py +++ b/swh/scanner/scanner.py @@ -6,9 +6,10 @@ import os import itertools import asyncio -import aiohttp -from typing import List, Dict, Tuple, Iterator, Union, Set, Any from pathlib import PosixPath +from typing import List, Dict, Tuple, Iterator, Union, Iterable, Pattern, Any + +import aiohttp from .exceptions import error_response from .model import Tree @@ -66,7 +67,9 @@ return await make_request(swhids) -def directory_filter(path_name: Union[str, bytes], exclude_patterns: Set[Any]) -> bool: +def directory_filter( + path_name: Union[str, bytes], exclude_patterns: Iterable[Pattern[str]] +) -> bool: """It checks if the path_name is matching with the patterns given in input. It is also used as a `dir_filter` function when generating the directory @@ -84,7 +87,7 @@ def get_subpaths( - path: PosixPath, exclude_patterns: Set[Any] + path: PosixPath, exclude_patterns: Iterable[Pattern[str]] ) -> Iterator[Tuple[PosixPath, str]]: """Find the SoftWare Heritage persistent IDentifier (SWHID) of the directories and files under a given path. @@ -126,7 +129,7 @@ path: PosixPath, session: aiohttp.ClientSession, api_url: str, - exclude_patterns: Set[Any], + exclude_patterns: Iterable[Pattern[str]], ) -> Iterator[Tuple[str, str, bool]]: """Check if the sub paths of the given path are present in the archive or not. @@ -153,7 +156,10 @@ async def run( - root: PosixPath, api_url: str, source_tree: Tree, exclude_patterns: Set[Any] + config: Dict[str, Any], + root: str, + source_tree: Tree, + exclude_patterns: Iterable[Pattern[str]], ) -> None: """Start scanning from the given root. @@ -164,6 +170,7 @@ api_url: url for the API request """ + api_url = config["web-api"]["url"] async def _scan(root, session, api_url, source_tree, exclude_patterns): for path, obj_swhid, known in await parse_path( @@ -178,5 +185,10 @@ if not known: await _scan(path, session, api_url, source_tree, exclude_patterns) - async with aiohttp.ClientSession() as session: + if config["web-api"]["auth-token"]: + headers = {"Authorization": f"Bearer {config['web-api']['auth-token']}"} + else: + headers = {} + + async with aiohttp.ClientSession(headers=headers) as session: await _scan(root, session, api_url, source_tree, exclude_patterns) diff --git a/swh/scanner/tests/test_scanner.py b/swh/scanner/tests/test_scanner.py --- a/swh/scanner/tests/test_scanner.py +++ b/swh/scanner/tests/test_scanner.py @@ -71,9 +71,10 @@ def test_scanner_result(live_server, event_loop, test_sample_folder): api_url = live_server.url() + "/" + config = {"web-api": {"url": api_url, "auth-token": None}} source_tree = Tree(test_sample_folder) - event_loop.run_until_complete(run(test_sample_folder, api_url, source_tree, set())) + event_loop.run_until_complete(run(config, test_sample_folder, source_tree, set())) for child_node in source_tree.iterate(): node_info = list(child_node.attributes.values())[0] @@ -87,6 +88,7 @@ live_server, event_loop, test_sample_folder ): api_url = live_server.url() + "/" + config = {"web-api": {"url": api_url, "auth-token": None}} patterns = (str(test_sample_folder) + "/toexclude",) exclude_pattern = { @@ -95,7 +97,7 @@ source_tree = Tree(test_sample_folder) event_loop.run_until_complete( - run(test_sample_folder, api_url, source_tree, exclude_pattern) + run(config, test_sample_folder, source_tree, exclude_pattern) ) for child_node in source_tree.iterate():