@swh_cli_group.group(
    name="scheduler",
    context_settings=CONTEXT_SETTINGS,
    cls=AliasedGroup,
)
@click.option(
    "--config-file",
    "-C",
    default=None,
    type=click.Path(exists=True, dir_okay=False),
    help="Configuration file.",
)
@click.option(
    "--database",
    "-d",
    default=None,
    help="Scheduling database DSN (imply cls is 'local')",
)
@click.option(
    "--url",
    "-u",
    default=None,
    help="Scheduler's url access (imply cls is 'remote')",
)
@click.option(
    "--no-stdout",
    is_flag=True,
    default=False,
    help="Do NOT output logs on the console",
)
@click.pass_context
def cli(ctx, config_file, database, url, no_stdout):
    """Software Heritage Scheduler tools.

    Use a local scheduler instance by default (plugged to the
    main scheduler db).
    """
    # NOTE: the docstring above is what click prints as the command help, so
    # implementation notes live in comments instead.
    #
    # On return, ctx.obj holds two entries:
    #   - "scheduler": the instance built by get_scheduler(), or None when it
    #     could not be instantiated;
    #   - "config": the loaded configuration, with the --database / --url
    #     override (if any) applied to its "scheduler" section.
    # `no_stdout` is accepted but not read anywhere in this function.

    # Imports are kept local to keep the CLI startup time under control.
    from swh.core import config
    from swh.scheduler import DEFAULT_CONFIG, get_scheduler

    # psycopg2 may not be installed; fall back on a placeholder exception type
    # so that the `except` clause below stays valid either way.
    try:
        from psycopg2 import OperationalError
    except ImportError:

        class OperationalError(Exception):
            pass

    ctx.ensure_object(dict)
    log = logging.getLogger(__name__)

    settings = config.read(config_file, DEFAULT_CONFIG)
    if "scheduler" not in settings:
        raise ValueError("missing 'scheduler' configuration")

    # Command-line overrides; --database takes precedence over --url.
    backend_settings = settings["scheduler"]
    if database:
        backend_settings["cls"] = "local"
        backend_settings["db"] = database
    elif url:
        backend_settings["cls"] = "remote"
        backend_settings["url"] = url

    instance = None
    try:
        log.debug("Instantiating scheduler with %s", backend_settings)
        instance = get_scheduler(**backend_settings)
    except (ValueError, OperationalError):
        # Deliberately ignored: each subcommand decides on its own whether
        # the lack of a proper scheduler instance is a problem.
        pass

    ctx.obj["scheduler"] = instance
    ctx.obj["config"] = settings
@cli.group("origin")
@click.pass_context
def origin(ctx):
    """Manipulate listed origins."""
    # Refuse to run any subcommand of this group when the context carries no
    # usable scheduler instance.
    if not ctx.obj["scheduler"]:
        raise ValueError("Scheduler class (local/remote) must be instantiated")


def format_origins(
    origins: List[ListedOrigin],
    fields: Optional[List[str]] = None,
    with_header: bool = True,
) -> Iterable[str]:
    """Format a list of origins as CSV.

    This is a generator: nothing happens (not even the validation of
    ``fields``) until the first line is requested.

    Arguments:
      origins: list of origins to output
      fields: optional list of fields to output (defaults to all fields)
      with_header: if True, output a CSV header.

    Yields:
      one CSV-formatted row at a time, without its line terminator

    Raises:
      ValueError: if ``fields`` contains names that are not ListedOrigin
        attributes
    """
    import csv
    from io import StringIO

    import attr

    from ..model import ListedOrigin

    expected_fields = [field.name for field in attr.fields(ListedOrigin)]
    if not fields:
        fields = expected_fields

    unknown_fields = set(fields) - set(expected_fields)
    if unknown_fields:
        # Sets are unordered: sort to keep the error message deterministic.
        raise ValueError(
            "Unknown ListedOrigin field(s): %s" % ", ".join(sorted(unknown_fields))
        )

    output = StringIO()
    writer = csv.writer(output)

    def csv_row(data):
        """Return a single CSV-formatted row. We clear the output buffer after
        we're done to keep it reasonably sized."""
        writer.writerow(data)
        # Only drop the line terminator appended by the csv writer; a bare
        # rstrip() would also eat meaningful trailing whitespace belonging to
        # the last field. Fields that themselves contain line breaks are
        # quoted by the writer, so they are left untouched.
        ret = output.getvalue().rstrip("\r\n")
        output.seek(0)
        output.truncate()
        return ret

    if with_header:
        yield csv_row(fields)

    for origin in origins:
        # Values are stringified explicitly, so None is output as "None".
        yield csv_row(str(getattr(origin, field)) for field in fields)


@origin.command("grab-next")
@click.option(
    "--policy", "-p", default="oldest_scheduled_first", help="Scheduling policy"
)
@click.option(
    "--fields", "-f", default=None, help="Listed origin fields to print on output"
)
@click.option(
    "--with-header/--without-header",
    is_flag=True,
    default=True,
    help="Print the CSV header?",
)
@click.argument("count", type=int)
@click.pass_context
def grab_next(ctx, policy: str, fields: Optional[str], with_header: bool, count: int):
    """Grab the next COUNT origins to visit from the listed origins table."""

    parsed_fields: Optional[List[str]] = None
    if fields:
        # --fields is a comma-separated list; tolerate blanks around the
        # commas so that `-f "url, visit_type"` works as expected.
        parsed_fields = [name.strip() for name in fields.split(",")]

    scheduler = ctx.obj["scheduler"]

    origins = scheduler.grab_next_visits(count, policy=policy)
    for line in format_origins(origins, fields=parsed_fields, with_header=with_header):
        click.echo(line)
def test_cli_origin(swh_scheduler):
    """Check that swh scheduler origin returns its help text"""
    outcome = invoke(swh_scheduler)
    assert "Commands:" in outcome.stdout


def test_format_origins_basic(listed_origins):
    """The header adds exactly one line in front of the per-origin rows."""
    sample = listed_origins[:100]

    with_header = [*format_origins(sample)]
    # one row per origin, plus the header
    assert len(with_header) == len(sample) + 1

    without_header = [*format_origins(sample, with_header=False)]
    assert with_header[1:] == without_header


def test_format_origins_fields_unknown(listed_origins):
    """An unknown field name is reported when the first row is requested."""
    rows = format_origins(listed_origins[:10], fields=["unknown_field"])

    # format_origins is a generator: the check only fires on first iteration
    with pytest.raises(ValueError, match="unknown_field"):
        next(rows)


def test_format_origins_fields(listed_origins):
    """Only the requested fields are output, in the requested order."""
    sample = listed_origins[:10]
    columns = ["lister_id", "url", "visit_type"]

    header, *rows = format_origins(sample, fields=columns)
    assert header == ",".join(columns)
    for index, listed in enumerate(sample):
        assert rows[index] == f"{listed.lister_id},{listed.url},{listed.visit_type}"


def test_grab_next(swh_scheduler, listed_origins):
    """`origin grab-next N` prints a header followed by N known origins."""
    wanted = 10
    assert len(listed_origins) >= wanted

    swh_scheduler.record_listed_origins(listed_origins)

    outcome = invoke(swh_scheduler, args=("grab-next", str(wanted)))
    assert outcome.exit_code == 0

    lines = outcome.stdout.splitlines()
    assert len(lines) == wanted + 1

    columns = lines[0].split(",")
    returned_urls = {dict(zip(columns, row.split(",")))["url"] for row in lines[1:]}
    known_urls = {listed.url for listed in listed_origins}

    # Check that we've received origins we had listed in the first place
    assert returned_urls <= known_urls