def format_origins(
    origins: List[ListedOrigin],
    fields: Optional[List[str]] = None,
    with_header: bool = True,
) -> Iterable[str]:
    """Format a list of origins as CSV.

    This is a generator: the field names are only checked once the first row
    is requested, so an invalid ``fields`` value raises on the first
    iteration step rather than when the function is called.

    Arguments:
        origins: list of origins to output
        fields: optional list of fields to output (defaults to all fields)
        with_header: if True, output a CSV header.

    Yields:
        One CSV-formatted row per step (the header first when requested),
        without its trailing line terminator.

    Raises:
        ValueError: if ``fields`` contains names that are not attributes of
            ListedOrigin.
    """
    import csv
    from io import StringIO

    import attr

    from ..model import ListedOrigin

    expected_fields = [field.name for field in attr.fields(ListedOrigin)]
    if not fields:
        fields = expected_fields

    unknown_fields = set(fields) - set(expected_fields)
    if unknown_fields:
        # Sets are unordered: sort so that the message is deterministic.
        raise ValueError(
            "Unknown ListedOrigin field(s): %s" % ", ".join(sorted(unknown_fields))
        )

    output = StringIO()
    writer = csv.writer(output)

    def csv_row(data):
        """Return a single CSV-formatted row.

        We clear the output buffer after we're done to keep it reasonably
        sized.
        """
        writer.writerow(data)
        # Only drop the line terminator added by the csv writer. A bare
        # rstrip() would also eat trailing whitespace that belongs to the
        # last field, silently altering the data.
        row = output.getvalue().rstrip("\r\n")
        output.seek(0)
        output.truncate()
        return row

    if with_header:
        yield csv_row(fields)

    for origin in origins:
        yield csv_row(str(getattr(origin, field)) for field in fields)
@attr.s
class BaseSchedulerModel:
    """Base class for database-backed objects.

    These database-backed objects are defined through attrs-based attributes
    that match the columns of the database 1:1. This is a (very) lightweight
    ORM.

    These attrs-based attributes have metadata specific to the functionality
    expected from these fields in the database:

     - `primary_key`: the column is a primary key; it should be filtered out
       when doing an `update` of the object
     - `auto_primary_key`: the column is a primary key, which is automatically
       handled by the database. It will not be inserted to. This must be
       matched with a database-side default value.
     - `auto_now_add`: the column is a timestamp that is set to the current
       time when the object is inserted, and never updated afterwards. This
       must be matched with a database-side default value.
     - `auto_now`: the column is a timestamp that is set to the current time
       when the object is inserted or updated.

    The column lists computed by the classmethods below are cached on the
    class they were computed for. The cache is always read from the class'
    own namespace (``cls.__dict__``), never through attribute inheritance, so
    that a model deriving from another model computes its own columns instead
    of reusing the ones cached by its parent.
    """

    _pk_cols: Optional[Tuple[str, ...]] = None
    _select_cols: Optional[Tuple[str, ...]] = None
    _insert_cols_and_metavars: Optional[Tuple[Tuple[str, ...], Tuple[str, ...]]] = None

    @classmethod
    def primary_key_columns(cls) -> Tuple[str, ...]:
        """Get the primary key columns for this object type

        Returns:
            the sorted names of the fields flagged as `primary_key` or
            `auto_primary_key`.
        """
        cached = cls.__dict__.get("_pk_cols")
        if cached is None:
            cached = tuple(
                sorted(
                    field.name
                    for field in attr.fields(cls)
                    if any(
                        field.metadata.get(flag)
                        for flag in ("auto_primary_key", "primary_key")
                    )
                )
            )
            cls._pk_cols = cached
        return cached

    @classmethod
    def select_columns(cls) -> Tuple[str, ...]:
        """Get all the database columns needed for a `select` on this object type

        Returns:
            the sorted names of all the attrs fields of the class.
        """
        cached = cls.__dict__.get("_select_cols")
        if cached is None:
            cached = tuple(sorted(field.name for field in attr.fields(cls)))
            cls._select_cols = cached
        return cached

    @classmethod
    def insert_columns_and_metavars(cls) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
        """Get the database columns and metavars needed for an `insert` or `update` on
        this object type.

        This implements support for the `auto_*` field metadata attributes.

        Returns:
            a pair of same-length tuples: the column names (sorted), and for
            each column the matching SQL placeholder. Both tuples are empty
            when no field of the class can be inserted.
        """
        cached = cls.__dict__.get("_insert_cols_and_metavars")
        if cached is None:
            zipped_cols_and_metavars: List[Tuple[str, str]] = []

            for field in attr.fields(cls):
                if any(
                    field.metadata.get(flag)
                    for flag in ("auto_now_add", "auto_primary_key")
                ):
                    # Filled in by the database-side default value
                    continue
                elif field.metadata.get("auto_now"):
                    zipped_cols_and_metavars.append((field.name, "now()"))
                else:
                    zipped_cols_and_metavars.append((field.name, f"%({field.name})s"))

            zipped_cols_and_metavars.sort()

            # Build both tuples explicitly: unpacking zip(*pairs) raises a
            # ValueError when there is no insertable column at all.
            cols = tuple(col for col, _ in zipped_cols_and_metavars)
            metavars = tuple(metavar for _, metavar in zipped_cols_and_metavars)
            cached = (cols, metavars)
            cls._insert_cols_and_metavars = cached

        return cached
ListedOriginPageToken = Tuple[UUID, str]


def convert_listed_origin_page_token(
    input: Union[None, ListedOriginPageToken, List[Union[UUID, str]]]
) -> Optional[ListedOriginPageToken]:
    """Normalize a page token to a ``(UUID, str)`` tuple.

    Arguments:
        input: None, an already-built tuple, or a two-item sequence holding a
            UUID followed by a string.

    Returns:
        None when ``input`` is None; ``input`` itself, unchecked, when it is
        already a tuple; otherwise a new ``(UUID, str)`` tuple built from the
        two items.

    Raises:
        TypeError: if the two items are not a UUID and a string, in that
            order.
        ValueError: if ``input`` does not unpack to exactly two items.
    """
    if input is None:
        return None

    if isinstance(input, tuple):
        return input

    first, second = input
    # Explicit checks rather than `assert`, which is stripped when Python
    # runs with -O and would let malformed tokens through.
    if not isinstance(first, UUID):
        raise TypeError(
            "first page token item must be a UUID, not %s" % type(first).__name__
        )
    if not isinstance(second, str):
        raise TypeError(
            "second page token item must be a str, not %s" % type(second).__name__
        )
    return (first, second)
""" url = attr.ib( type=str, validator=[type_validator()], metadata={"primary_key": True} ) visit_type = attr.ib( type=str, validator=[type_validator()], metadata={"primary_key": True} ) last_eventful = attr.ib( type=Optional[datetime.datetime], validator=type_validator() ) last_uneventful = attr.ib( type=Optional[datetime.datetime], validator=type_validator() ) last_failed = attr.ib(type=Optional[datetime.datetime], validator=type_validator()) @last_eventful.validator def check_last_eventful(self, attribute, value): check_timestamptz(value) @last_uneventful.validator def check_last_uneventful(self, attribute, value): check_timestamptz(value) @last_failed.validator def check_last_failed(self, attribute, value): check_timestamptz(value) diff --git a/swh/scheduler/tests/test_cli_origin.py b/swh/scheduler/tests/test_cli_origin.py index 87812ad..b570484 100644 --- a/swh/scheduler/tests/test_cli_origin.py +++ b/swh/scheduler/tests/test_cli_origin.py @@ -1,76 +1,103 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Tuple import pytest from swh.scheduler.cli.origin import format_origins +from swh.scheduler.tests.common import TASK_TYPES from swh.scheduler.tests.test_cli import invoke as basic_invoke def invoke(scheduler, args: Tuple[str, ...] 
def test_grab_next(swh_scheduler, listed_origins):
    """Check that `grab-next` prints a CSV header followed by COUNT known origins."""
    import csv

    num_origins = 10
    assert len(listed_origins) >= num_origins

    swh_scheduler.record_listed_origins(listed_origins)

    result = invoke(swh_scheduler, args=("grab-next", str(num_origins)))
    assert result.exit_code == 0

    out_lines = result.stdout.splitlines()
    # 1 header line + the requested number of origins
    assert len(out_lines) == num_origins + 1

    # Parse the output with the csv module: splitting on "," breaks quoted
    # fields that contain a comma, after which zip() silently misaligns the
    # values with the header.
    returned_origins = list(csv.DictReader(out_lines))
    assert len(returned_origins) == num_origins

    # Check that we've received origins we had listed in the first place
    assert set(origin["url"] for origin in returned_origins) <= set(
        origin.url for origin in listed_origins
    )
def test_select_columns():
    """Check that select_columns lists the attrs fields, sorted, and nothing else."""

    @attr.s
    class TestModel(model.BaseSchedulerModel):
        id = attr.ib(type=str)
        test1 = attr.ib(type=str)
        a_first_attr = attr.ib(type=str)

        @property
        def test2(self):
            """This property should not show up in the extracted columns"""
            return self.test1

    expected_columns = ("a_first_attr", "id", "test1")
    selected_columns = TestModel.select_columns()

    assert selected_columns == expected_columns
def test_listed_origin_as_task_dict():
    """Check the task dict built from a listed origin, without and with
    extra loader arguments."""
    # Origin without extra loader arguments: the url is the only kwarg
    plain_origin = model.ListedOrigin(
        lister_id=uuid.uuid4(), url="http://example.com/", visit_type="git",
    )
    expected_plain_task = {
        "type": "load-git",
        "arguments": {"args": [], "kwargs": {"url": "http://example.com/"}},
    }
    assert plain_origin.as_task_dict() == expected_plain_task

    # Origin with extra loader arguments: they are merged into the kwargs
    origin_with_extras = model.ListedOrigin(
        lister_id=uuid.uuid4(),
        url="http://example.com/svn/",
        visit_type="svn",
        extra_loader_arguments={"foo": "bar"},
    )
    expected_task_with_extras = {
        "type": "load-svn",
        "arguments": {
            "args": [],
            "kwargs": {"url": "http://example.com/svn/", "foo": "bar"},
        },
    }
    assert origin_with_extras.as_task_dict() == expected_task_with_extras