# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

# NOTE: the commands below attach themselves to the ``cli`` group imported from
# the package as an import side effect. This module is therefore imported at
# the *bottom* of ``swh/scheduler/cli/__init__.py``
# (``from . import admin, celery_monitor, origin, task, task_type``), after
# ``cli`` has been defined, to avoid a circular import.

from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Iterator, List, Optional

import click

from . import cli

if TYPE_CHECKING:
    from ..model import ListedOrigin


@cli.group("origin")
@click.pass_context
def origin(ctx):
    """Manipulate listed origins."""
    # The docstring above is what click prints as the group's help text.
    # Every subcommand reads the scheduler instance from ``ctx.obj``, so fail
    # early when it is missing/falsy.
    if not ctx.obj["scheduler"]:
        raise ValueError("Scheduler class (local/remote) must be instantiated")


def format_origins(
    origins: Iterable[ListedOrigin],
    fields: Optional[List[str]] = None,
    with_header: bool = True,
) -> Iterator[str]:
    """Format a list of origins as CSV.

    This is a generator: nothing (including field validation) runs until the
    first row is requested.

    Arguments:
        origins: origins to output; only iterated over, once
        fields: optional list of fields to output (defaults to all fields)
        with_header: if True, output a CSV header.

    Yields:
        one CSV-formatted row per item, without the trailing line terminator;
        the header row comes first when ``with_header`` is true.

    Raises:
        ValueError: if ``fields`` names an attribute that is not a field of
            ``ListedOrigin`` (raised when the first row is requested).
    """
    # Imports are kept local, as in the rest of this function's dependencies,
    # so that importing the CLI module stays cheap.
    import csv
    from io import StringIO

    import attr

    from ..model import ListedOrigin

    expected_fields = [field.name for field in attr.fields(ListedOrigin)]
    if not fields:
        fields = expected_fields

    unknown_fields = set(fields) - set(expected_fields)
    if unknown_fields:
        # Sort so that the error message is deterministic.
        raise ValueError(
            "Unknown ListedOrigin field(s): %s" % ", ".join(sorted(unknown_fields))
        )

    output = StringIO()
    writer = csv.writer(output)

    def csv_row(data):
        """Return a single CSV-formatted row. We clear the output buffer after we're
        done to keep it reasonably sized."""
        writer.writerow(data)
        # Only drop the line terminator appended by the csv writer. A bare
        # rstrip() would also eat trailing spaces/tabs that belong to the last
        # field. A last field ending in CR/LF is quoted by the writer, so the
        # closing quote protects it from this strip.
        row = output.getvalue().rstrip("\r\n")
        output.seek(0)
        output.truncate()
        return row

    if with_header:
        yield csv_row(fields)

    for listed_origin in origins:
        yield csv_row(str(getattr(listed_origin, field)) for field in fields)


@origin.command("grab-next")
@click.option(
    "--policy", "-p", default="oldest_scheduled_first", help="Scheduling policy"
)
@click.option(
    "--fields", "-f", default=None, help="Listed origin fields to print on output"
)
@click.option(
    "--with-header/--without-header",
    is_flag=True,
    default=True,
    help="Print the CSV header?",
)
@click.argument("count", type=int)
@click.pass_context
def grab_next(ctx, policy: str, fields: Optional[str], with_header: bool, count: int):
    """Grab the next COUNT origins to visit from the listed origins table."""
    # The docstring above is what click prints as the command's help text.

    parsed_fields: Optional[List[str]] = None
    if fields:
        # ``--fields`` is a comma-separated list; tolerate whitespace around
        # the names so that ``-f "url, visit_type"`` works as expected.
        parsed_fields = [name.strip() for name in fields.split(",")]

    scheduler = ctx.obj["scheduler"]

    origins = scheduler.grab_next_visits(count, policy=policy)
    for line in format_origins(origins, fields=parsed_fields, with_header=with_header):
        click.echo(line)
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import csv
from typing import Tuple

import pytest

from swh.scheduler.cli.origin import format_origins
from swh.scheduler.tests.test_cli import invoke as basic_invoke


def invoke(scheduler, args: Tuple[str, ...] = (), catch_exceptions: bool = False):
    """Call the shared CLI test helper with ``origin`` prepended to ``args``."""
    return basic_invoke(
        scheduler, args=["origin", *args], catch_exceptions=catch_exceptions
    )


def test_cli_origin(swh_scheduler):
    """Check that swh scheduler origin returns its help text"""

    result = invoke(swh_scheduler)

    assert "Commands:" in result.stdout


def test_format_origins_basic(listed_origins):
    """The header adds exactly one line in front of the per-origin rows."""
    listed_origins = listed_origins[:100]

    basic_output = list(format_origins(listed_origins))
    # 1 header line + all origins
    assert len(basic_output) == len(listed_origins) + 1

    no_header_output = list(format_origins(listed_origins, with_header=False))
    assert basic_output[1:] == no_header_output


def test_format_origins_fields_unknown(listed_origins):
    """Unknown field names are rejected with a ValueError naming them."""
    listed_origins = listed_origins[:10]

    it = format_origins(listed_origins, fields=["unknown_field"])

    # format_origins is a generator, so the validation error only surfaces
    # once the first row is requested.
    with pytest.raises(ValueError, match="unknown_field"):
        next(it)


def test_format_origins_fields(listed_origins):
    """Only the requested fields are output, in the requested order."""
    listed_origins = listed_origins[:10]
    fields = ["lister_id", "url", "visit_type"]

    output = list(format_origins(listed_origins, fields=fields))
    assert output[0] == ",".join(fields)
    for i, origin in enumerate(listed_origins):
        assert output[i + 1] == f"{origin.lister_id},{origin.url},{origin.visit_type}"


def test_grab_next(swh_scheduler, listed_origins):
    """``grab-next COUNT`` prints a header plus COUNT previously listed origins."""
    num_origins = 10
    assert len(listed_origins) >= num_origins

    swh_scheduler.record_listed_origins(listed_origins)

    result = invoke(swh_scheduler, args=("grab-next", str(num_origins)))
    assert result.exit_code == 0

    out_lines = result.stdout.splitlines()
    assert len(out_lines) == num_origins + 1

    # Parse with the csv module rather than str.split(","): a quoted value
    # containing a comma would otherwise shift every following column.
    returned_origins = list(csv.DictReader(out_lines))

    # Check that we've received origins we had listed in the first place
    known_urls = {origin.url for origin in listed_origins}
    assert {origin["url"] for origin in returned_origins} <= known_urls