diff --git a/swh/scheduler/__init__.py b/swh/scheduler/__init__.py
index fbb52cd..b1f98ba 100644
--- a/swh/scheduler/__init__.py
+++ b/swh/scheduler/__init__.py
@@ -1,73 +1,76 @@
# Copyright (C) 2018-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from __future__ import annotations
from importlib import import_module
from typing import TYPE_CHECKING, Any, Dict
import warnings
DEFAULT_CONFIG = {
"scheduler": (
"dict",
- {"cls": "local", "db": "dbname=softwareheritage-scheduler-dev",},
+ {
+ "cls": "local",
+ "db": "dbname=softwareheritage-scheduler-dev",
+ },
)
}
# current configuration. To be set by the config loading mechanism
CONFIG = {} # type: Dict[str, Any]
if TYPE_CHECKING:
from swh.scheduler.interface import SchedulerInterface
BACKEND_TYPES: Dict[str, str] = {
"postgresql": ".backend.SchedulerBackend",
"remote": ".api.client.RemoteScheduler",
# deprecated
"local": ".backend.SchedulerBackend",
}
def get_scheduler(cls: str, **kwargs) -> SchedulerInterface:
"""
Get a scheduler object of class `cls` with arguments `**kwargs`.
Args:
cls: scheduler's class, either 'postgresql', 'remote', or 'local' (deprecated)
kwargs: arguments to pass to the class' constructor
Returns:
an instance of swh.scheduler, either local or remote:
local: swh.scheduler.backend.SchedulerBackend
remote: swh.scheduler.api.client.RemoteScheduler
Raises:
ValueError if passed an unknown scheduler class.
"""
if "args" in kwargs:
warnings.warn(
'Explicit "args" key is deprecated, use keys directly instead.',
DeprecationWarning,
)
kwargs = kwargs["args"]
class_path = BACKEND_TYPES.get(cls)
if class_path is None:
raise ValueError(
f"Unknown Scheduler class `{cls}`. "
f"Supported: {', '.join(BACKEND_TYPES)}"
)
(module_path, class_name) = class_path.rsplit(".", 1)
module = import_module(module_path, package=__package__)
BackendClass = getattr(module, class_name)
return BackendClass(**kwargs)
get_datastore = get_scheduler
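# A minimal sketch of instantiating a backend through get_scheduler(); the
# DSN below is a placeholder, mirroring the DEFAULT_CONFIG shape above.
from swh.scheduler import get_scheduler

scheduler = get_scheduler(cls="postgresql", db="dbname=softwareheritage-scheduler-dev")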
diff --git a/swh/scheduler/api/client.py b/swh/scheduler/api/client.py
index 8a36d05..6760750 100644
--- a/swh/scheduler/api/client.py
+++ b/swh/scheduler/api/client.py
@@ -1,24 +1,22 @@
# Copyright (C) 2018-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from swh.core.api import RPCClient
from .. import exc
from ..interface import SchedulerInterface
from .serializers import DECODERS, ENCODERS
class RemoteScheduler(RPCClient):
- """Proxy to a remote scheduler API
-
- """
+ """Proxy to a remote scheduler API"""
backend_class = SchedulerInterface
reraise_exceptions = [getattr(exc, a) for a in exc.__all__]
extra_type_decoders = DECODERS
extra_type_encoders = ENCODERS
diff --git a/swh/scheduler/api/server.py b/swh/scheduler/api/server.py
index f1654a6..5854095 100644
--- a/swh/scheduler/api/server.py
+++ b/swh/scheduler/api/server.py
@@ -1,150 +1,150 @@
# Copyright (C) 2018-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import logging
import os
from swh.core import config
from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp
from swh.core.api import encode_data_server as encode_data
from swh.core.api import error_handler, negotiate
from swh.scheduler import get_scheduler
from swh.scheduler.exc import SchedulerException
from swh.scheduler.interface import SchedulerInterface
from .serializers import DECODERS, ENCODERS
scheduler = None
def get_global_scheduler():
global scheduler
if not scheduler:
scheduler = get_scheduler(**app.config["scheduler"])
return scheduler
class SchedulerServerApp(RPCServerApp):
extra_type_decoders = DECODERS
extra_type_encoders = ENCODERS
app = SchedulerServerApp(
__name__, backend_class=SchedulerInterface, backend_factory=get_global_scheduler
)
@app.errorhandler(SchedulerException)
def argument_error_handler(exception):
return error_handler(exception, encode_data, status_code=400)
@app.errorhandler(Exception)
def my_error_handler(exception):
return error_handler(exception, encode_data)
def has_no_empty_params(rule):
return len(rule.defaults or ()) >= len(rule.arguments or ())
@app.route("/")
def index():
return """
Software Heritage scheduler RPC server
You have reached the
Software Heritage
scheduler RPC server.
See its
documentation
and API for more information
"""
@app.route("/site-map")
@negotiate(MsgpackFormatter)
@negotiate(JSONFormatter)
def site_map():
links = []
for rule in app.url_map.iter_rules():
if has_no_empty_params(rule) and hasattr(SchedulerInterface, rule.endpoint):
links.append(
dict(
rule=rule.rule,
description=getattr(SchedulerInterface, rule.endpoint).__doc__,
)
)
# links is now a list of url, endpoint tuples
return links
def load_and_check_config(config_path, type="local"):
"""Check the minimal configuration is set to run the api or raise an
error explanation.
Args:
config_path (str): Path to the configuration file to load
type (str): configuration type. For 'local' type, more
checks are done.
Raises:
Error if the setup is not as expected
Returns:
configuration as a dict
"""
if not config_path:
raise EnvironmentError("Configuration file must be defined")
if not os.path.exists(config_path):
raise FileNotFoundError(f"Configuration file {config_path} does not exist")
cfg = config.read(config_path)
vcfg = cfg.get("scheduler")
if not vcfg:
raise KeyError("Missing '%scheduler' configuration")
if type == "local":
cls = vcfg.get("cls")
if cls != "local":
raise ValueError(
"The scheduler backend can only be started with a 'local' "
"configuration"
)
db = vcfg.get("db")
if not db:
raise KeyError("Invalid configuration; missing 'db' config entry")
return cfg
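# A minimal configuration file accepted by load_and_check_config for the
# "local" type could look like this (YAML, as read by swh.core.config; the
# DSN is a placeholder):
#
#   scheduler:
#     cls: local
#     db: "dbname=softwareheritage-scheduler-dev"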
api_cfg = None
def make_app_from_configfile():
"""Run the WSGI app from the webserver, loading the configuration from
- a configuration file.
+ a configuration file.
- SWH_CONFIG_FILENAME environment variable defines the
- configuration path to load.
+ SWH_CONFIG_FILENAME environment variable defines the
+ configuration path to load.
"""
global api_cfg
if not api_cfg:
config_path = os.environ.get("SWH_CONFIG_FILENAME")
api_cfg = load_and_check_config(config_path)
app.config.update(api_cfg)
handler = logging.StreamHandler()
app.logger.addHandler(handler)
return app
if __name__ == "__main__":
print('Please use the "swh-scheduler api-server" command')
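# One possible deployment sketch: make_app_from_configfile() reads its path
# from the SWH_CONFIG_FILENAME environment variable, so any WSGI server can
# serve the app; gunicorn is shown only as an example, the path is a placeholder:
#
#   SWH_CONFIG_FILENAME=/etc/softwareheritage/scheduler.yml \
#   gunicorn 'swh.scheduler.api.server:make_app_from_configfile()'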
diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py
index 0b3c3f8..91a6572 100644
--- a/swh/scheduler/backend.py
+++ b/swh/scheduler/backend.py
@@ -1,1119 +1,1125 @@
# Copyright (C) 2015-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
import json
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import UUID
import attr
from psycopg2.errors import CardinalityViolation
from psycopg2.extensions import AsIs
import psycopg2.extras
import psycopg2.pool
from swh.core.db import BaseDb
from swh.core.db.common import db_transaction
from swh.scheduler.utils import utcnow
from .exc import SchedulerException, StaleData, UnknownPolicy
from .interface import ListedOriginPageToken, PaginatedListedOriginList
from .model import (
LastVisitStatus,
ListedOrigin,
Lister,
OriginVisitStats,
SchedulerMetrics,
)
logger = logging.getLogger(__name__)
def adapt_LastVisitStatus(v: LastVisitStatus):
return AsIs(f"'{v.value}'::last_visit_status")
psycopg2.extensions.register_adapter(dict, psycopg2.extras.Json)
psycopg2.extensions.register_adapter(LastVisitStatus, adapt_LastVisitStatus)
psycopg2.extras.register_uuid()
def format_query(query, keys):
"""Format a query with the given keys"""
query_keys = ", ".join(keys)
placeholders = ", ".join(["%s"] * len(keys))
return query.format(keys=query_keys, placeholders=placeholders)
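# For example, format_query("select {keys} from task_type where type=%s",
# ["type", "description"]) returns
# "select type, description from task_type where type=%s"; one "%s" per key
# is generated for queries that use the {placeholders} slot.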
class SchedulerBackend:
- """Backend for the Software Heritage scheduling database.
-
- """
+ """Backend for the Software Heritage scheduling database."""
current_version = 33
def __init__(self, db, min_pool_conns=1, max_pool_conns=10):
"""
Args:
db: either a libpq connection string, or a psycopg2 connection
"""
if isinstance(db, psycopg2.extensions.connection):
self._pool = None
self._db = BaseDb(db)
else:
self._pool = psycopg2.pool.ThreadedConnectionPool(
min_pool_conns,
max_pool_conns,
db,
cursor_factory=psycopg2.extras.RealDictCursor,
)
self._db = None
def get_current_version(self):
return self.current_version
def get_db(self):
if self._db:
return self._db
return BaseDb.from_pool(self._pool)
def put_db(self, db):
if db is not self._db:
db.put_conn()
task_type_keys = [
"type",
"description",
"backend_name",
"default_interval",
"min_interval",
"max_interval",
"backoff_factor",
"max_queue_length",
"num_retries",
"retry_delay",
]
@db_transaction()
def create_task_type(self, task_type, db=None, cur=None):
"""Create a new task type ready for scheduling.
Args:
task_type (dict): a dictionary with the following keys:
- type (str): an identifier for the task type
- description (str): a human-readable description of what the
task does
- backend_name (str): the name of the task in the
job-scheduling backend
- default_interval (datetime.timedelta): the default interval
between two task runs
- min_interval (datetime.timedelta): the minimum interval
between two task runs
- max_interval (datetime.timedelta): the maximum interval
between two task runs
- backoff_factor (float): the factor by which the interval
changes at each run
- max_queue_length (int): the maximum length of the task queue
for this task type
"""
keys = [key for key in self.task_type_keys if key in task_type]
query = format_query(
"""insert into task_type ({keys}) values ({placeholders})
on conflict do nothing""",
keys,
)
cur.execute(query, [task_type[key] for key in keys])
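# An illustrative task_type value for create_task_type, assuming a
# SchedulerBackend instance named `scheduler`; every field below is an
# example, not a prescribed value.
import datetime

scheduler.create_task_type(
    {
        "type": "load-git",
        "description": "Load a git repository",
        "backend_name": "swh.loader.git.tasks.UpdateGitRepository",
        "default_interval": datetime.timedelta(days=64),
        "min_interval": datetime.timedelta(hours=12),
        "max_interval": datetime.timedelta(days=64),
        "backoff_factor": 2.0,
        "max_queue_length": 1000,
    }
)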
@db_transaction()
def get_task_type(self, task_type_name, db=None, cur=None):
"""Retrieve the task type with id task_type_name"""
query = format_query(
- "select {keys} from task_type where type=%s", self.task_type_keys,
+ "select {keys} from task_type where type=%s",
+ self.task_type_keys,
)
cur.execute(query, (task_type_name,))
return cur.fetchone()
@db_transaction()
def get_task_types(self, db=None, cur=None):
"""Retrieve all registered task types"""
- query = format_query("select {keys} from task_type", self.task_type_keys,)
+ query = format_query(
+ "select {keys} from task_type",
+ self.task_type_keys,
+ )
cur.execute(query)
return cur.fetchall()
@db_transaction()
def get_listers(self, db=None, cur=None) -> List[Lister]:
- """Retrieve information about all listers from the database.
- """
+ """Retrieve information about all listers from the database."""
select_cols = ", ".join(Lister.select_columns())
query = f"""
select {select_cols} from listers
"""
cur.execute(query)
return [Lister(**ret) for ret in cur.fetchall()]
@db_transaction()
def get_lister(
self, name: str, instance_name: Optional[str] = None, db=None, cur=None
) -> Optional[Lister]:
"""Retrieve information about the given instance of the lister from the
database.
"""
if instance_name is None:
instance_name = ""
select_cols = ", ".join(Lister.select_columns())
query = f"""
select {select_cols} from listers
where (name, instance_name) = (%s, %s)
"""
cur.execute(query, (name, instance_name))
ret = cur.fetchone()
if not ret:
return None
return Lister(**ret)
@db_transaction()
def get_or_create_lister(
self, name: str, instance_name: Optional[str] = None, db=None, cur=None
) -> Lister:
"""Retrieve information about the given instance of the lister from the
database, or create the entry if it did not exist.
"""
if instance_name is None:
instance_name = ""
select_cols = ", ".join(Lister.select_columns())
insert_cols, insert_meta = (
", ".join(tup) for tup in Lister.insert_columns_and_metavars()
)
query = f"""
with added as (
insert into listers ({insert_cols}) values ({insert_meta})
on conflict do nothing
returning {select_cols}
)
select {select_cols} from added
union all
select {select_cols} from listers
where (name, instance_name) = (%(name)s, %(instance_name)s);
"""
cur.execute(query, attr.asdict(Lister(name=name, instance_name=instance_name)))
return Lister(**cur.fetchone())
@db_transaction()
def update_lister(self, lister: Lister, db=None, cur=None) -> Lister:
"""Update the state for the given lister instance in the database.
Returns:
a new Lister object, with all fields updated from the database
Raises:
StaleData if the `updated` timestamp for the lister instance in
database doesn't match the one passed by the user.
"""
select_cols = ", ".join(Lister.select_columns())
set_vars = ", ".join(
f"{col} = {meta}"
for col, meta in zip(*Lister.insert_columns_and_metavars())
)
query = f"""update listers
set {set_vars}
where id=%(id)s and updated=%(updated)s
returning {select_cols}"""
cur.execute(query, attr.asdict(lister))
updated = cur.fetchone()
if not updated:
raise StaleData("Stale data; Lister state not updated")
return Lister(**updated)
@db_transaction()
def record_listed_origins(
self, listed_origins: Iterable[ListedOrigin], db=None, cur=None
) -> List[ListedOrigin]:
"""Record a set of origins that a lister has listed.
This performs an "upsert": origins with the same (lister_id, url,
visit_type) values are updated with new values for
extra_loader_arguments, last_update and last_seen.
"""
pk_cols = ListedOrigin.primary_key_columns()
select_cols = ListedOrigin.select_columns()
insert_cols, insert_meta = ListedOrigin.insert_columns_and_metavars()
deduplicated_origins = {
tuple(getattr(origin, k) for k in pk_cols): origin
for origin in listed_origins
}
upsert_cols = [col for col in insert_cols if col not in pk_cols]
upsert_set = ", ".join(f"{col} = EXCLUDED.{col}" for col in upsert_cols)
query = f"""INSERT into listed_origins ({", ".join(insert_cols)})
VALUES %s
ON CONFLICT ({", ".join(pk_cols)}) DO UPDATE
SET {upsert_set}
RETURNING {", ".join(select_cols)}
"""
ret = psycopg2.extras.execute_values(
cur=cur,
sql=query,
argslist=(attr.asdict(origin) for origin in deduplicated_origins.values()),
template=f"({', '.join(insert_meta)})",
page_size=1000,
fetch=True,
)
return [ListedOrigin(**d) for d in ret]
@db_transaction()
def get_listed_origins(
self,
lister_id: Optional[UUID] = None,
url: Optional[str] = None,
limit: int = 1000,
page_token: Optional[ListedOriginPageToken] = None,
db=None,
cur=None,
) -> PaginatedListedOriginList:
"""Get information on the listed origins matching either the `url` or
`lister_id`, or both arguments.
"""
query_filters: List[str] = []
query_params: List[Union[int, str, UUID, Tuple[UUID, str]]] = []
if lister_id:
query_filters.append("lister_id = %s")
query_params.append(lister_id)
if url is not None:
query_filters.append("url = %s")
query_params.append(url)
if page_token is not None:
query_filters.append("(lister_id, url) > %s")
# the typeshed annotation for tuple() is too strict.
query_params.append(tuple(page_token)) # type: ignore
query_params.append(limit)
select_cols = ", ".join(ListedOrigin.select_columns())
if query_filters:
where_clause = "where %s" % (" and ".join(query_filters))
else:
where_clause = ""
query = f"""SELECT {select_cols}
from listed_origins
{where_clause}
ORDER BY lister_id, url
LIMIT %s"""
cur.execute(query, tuple(query_params))
origins = [ListedOrigin(**d) for d in cur]
if len(origins) == limit:
page_token = (str(origins[-1].lister_id), origins[-1].url)
else:
page_token = None
return PaginatedListedOriginList(origins, page_token)
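# An illustrative pagination loop; the field names (results, next_page_token)
# follow the PagedResult base class this return type builds on, and
# `scheduler` / `lister` are assumed to exist.
page_token = None
while True:
    page = scheduler.get_listed_origins(lister_id=lister.id, page_token=page_token)
    for origin in page.results:
        ...  # consume each ListedOrigin
    page_token = page.next_page_token
    if page_token is None:
        break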
@db_transaction()
def grab_next_visits(
self,
visit_type: str,
count: int,
policy: str,
enabled: bool = True,
lister_uuid: Optional[str] = None,
timestamp: Optional[datetime.datetime] = None,
scheduled_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=7),
failed_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=14),
not_found_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=31),
tablesample: Optional[float] = None,
db=None,
cur=None,
) -> List[ListedOrigin]:
if timestamp is None:
timestamp = utcnow()
origin_select_cols = ", ".join(ListedOrigin.select_columns())
query_args: List[Any] = []
where_clauses = []
# list of (name, query) handled as CTEs before the main query
common_table_expressions: List[Tuple[str, str]] = []
# "NOT enabled" = the lister said the origin no longer exists
where_clauses.append("enabled" if enabled else "not enabled")
# Only schedule visits of the given type
where_clauses.append("visit_type = %s")
query_args.append(visit_type)
if scheduled_cooldown:
# Don't re-schedule visits if they're already scheduled but we haven't
# recorded a result yet, unless they've been scheduled more than a week
# ago (it probably means we've lost them in flight somewhere).
where_clauses.append(
"""origin_visit_stats.last_scheduled IS NULL
OR origin_visit_stats.last_scheduled < GREATEST(
%s,
origin_visit_stats.last_visit
)
"""
)
query_args.append(timestamp - scheduled_cooldown)
if failed_cooldown:
# Don't retry failed origins too often
where_clauses.append(
"origin_visit_stats.last_visit_status is distinct from 'failed' "
"or origin_visit_stats.last_visit < %s"
)
query_args.append(timestamp - failed_cooldown)
if not_found_cooldown:
# Don't retry not found origins too often
where_clauses.append(
"origin_visit_stats.last_visit_status is distinct from 'not_found' "
"or origin_visit_stats.last_visit < %s"
)
query_args.append(timestamp - not_found_cooldown)
if policy == "oldest_scheduled_first":
order_by = "origin_visit_stats.last_scheduled NULLS FIRST"
elif policy == "never_visited_oldest_update_first":
# never visited origins have a NULL last_snapshot
where_clauses.append("origin_visit_stats.last_snapshot IS NULL")
# order by increasing last_update (oldest first)
where_clauses.append("listed_origins.last_update IS NOT NULL")
order_by = "listed_origins.last_update"
elif policy == "already_visited_order_by_lag":
# TODO: store "visit lag" in a materialized view?
# visited origins have a NOT NULL last_snapshot
where_clauses.append("origin_visit_stats.last_snapshot IS NOT NULL")
# ignore origins we have visited after the known last update
where_clauses.append("listed_origins.last_update IS NOT NULL")
where_clauses.append(
"listed_origins.last_update > origin_visit_stats.last_successful"
)
# order by decreasing visit lag
order_by = (
"listed_origins.last_update - origin_visit_stats.last_successful DESC"
)
elif policy == "origins_without_last_update":
where_clauses.append("last_update IS NULL")
order_by = ", ".join(
[
# By default, sort using the queue position. If the queue
# position is null, then the origin has never been visited,
# which we want to handle first
"origin_visit_stats.next_visit_queue_position nulls first",
# Schedule unknown origins in the order we've seen them
"listed_origins.first_seen",
]
)
# fmt: off
# This policy requires updating the global queue position for this
# visit type
common_table_expressions.append(("update_queue_position", """
INSERT INTO
visit_scheduler_queue_position(visit_type, position)
SELECT
visit_type, COALESCE(MAX(next_visit_queue_position), 0)
FROM selected_origins
GROUP BY visit_type
ON CONFLICT(visit_type) DO UPDATE
SET position=GREATEST(
visit_scheduler_queue_position.position, EXCLUDED.position
)
"""))
# fmt: on
else:
raise UnknownPolicy(f"Unknown scheduling policy {policy}")
if tablesample:
table = "listed_origins tablesample SYSTEM (%s)"
query_args.insert(0, tablesample)
else:
table = "listed_origins"
if lister_uuid:
where_clauses.append("lister_id = %s")
query_args.append(lister_uuid)
# fmt: off
common_table_expressions.insert(0, ("selected_origins", f"""
SELECT
{origin_select_cols}, next_visit_queue_position
FROM
{table}
LEFT JOIN
origin_visit_stats USING (url, visit_type)
WHERE
({") AND (".join(where_clauses)})
ORDER BY
{order_by}
LIMIT %s
"""))
# fmt: on
query_args.append(count)
# fmt: off
common_table_expressions.append(("deduplicated_selected_origins", """
SELECT DISTINCT
url, visit_type
FROM
selected_origins
"""))
# fmt: on
# fmt: off
common_table_expressions.append(("update_stats", """
INSERT INTO
origin_visit_stats (url, visit_type, last_scheduled)
SELECT
url, visit_type, %s
FROM
deduplicated_selected_origins
ON CONFLICT (url, visit_type) DO UPDATE
SET last_scheduled = GREATEST(
origin_visit_stats.last_scheduled,
EXCLUDED.last_scheduled
)
"""))
# fmt: on
query_args.append(timestamp)
formatted_ctes = ",\n".join(
f"{name} AS (\n{cte}\n)" for name, cte in common_table_expressions
)
query = f"""
WITH
{formatted_ctes}
SELECT
{origin_select_cols}
FROM
selected_origins
"""
cur.execute(query, tuple(query_args))
return [ListedOrigin(**d) for d in cur]
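# An illustrative call, assuming a scheduler instance; the policy is one of
# those handled above and the values are examples.
origins = scheduler.grab_next_visits(
    visit_type="git", count=100, policy="oldest_scheduled_first"
)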
task_create_keys = [
"type",
"arguments",
"next_run",
"policy",
"status",
"retries_left",
"priority",
]
task_keys = task_create_keys + ["id", "current_interval"]
@db_transaction()
def create_tasks(self, tasks, policy="recurring", db=None, cur=None):
"""Create new tasks.
Args:
tasks (list): each task is a dictionary with the following keys:
- type (str): the task type
- arguments (dict): the arguments for the task runner, keys:
- args (list of str): arguments
- kwargs (dict str -> str): keyword arguments
- next_run (datetime.datetime): the next scheduled run for the
task
Returns:
a list of created tasks.
"""
cur.execute("select swh_scheduler_mktemp_task()")
db.copy_to(
tasks,
"tmp_task",
self.task_create_keys,
default_values={"policy": policy, "status": "next_run_not_scheduled"},
cur=cur,
)
query = format_query(
- "select {keys} from swh_scheduler_create_tasks_from_temp()", self.task_keys,
+ "select {keys} from swh_scheduler_create_tasks_from_temp()",
+ self.task_keys,
)
cur.execute(query)
return cur.fetchall()
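# An illustrative one-shot task for create_tasks, assuming a scheduler
# instance; the type and URL are examples. Missing policy/status keys are
# filled in by the copy_to() defaults above.
from swh.scheduler.utils import utcnow

scheduler.create_tasks(
    [
        {
            "type": "load-git",
            "arguments": {"args": [], "kwargs": {"url": "https://example.org/repo.git"}},
            "next_run": utcnow(),
        }
    ],
    policy="oneshot",
)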
@db_transaction()
def set_status_tasks(
self,
task_ids: List[int],
status: str = "disabled",
next_run: Optional[datetime.datetime] = None,
db=None,
cur=None,
):
"""Set the tasks' status whose ids are listed.
If given, also set the next_run date.
"""
if not task_ids:
return
query = ["UPDATE task SET status = %s"]
args: List[Any] = [status]
if next_run:
query.append(", next_run = %s")
args.append(next_run)
query.append(" WHERE id IN %s")
args.append(tuple(task_ids))
cur.execute("".join(query), args)
@db_transaction()
def disable_tasks(self, task_ids, db=None, cur=None):
"""Disable the tasks whose ids are listed."""
return self.set_status_tasks(task_ids, db=db, cur=cur)
@db_transaction()
def search_tasks(
self,
task_id=None,
task_type=None,
status=None,
priority=None,
policy=None,
before=None,
after=None,
limit=None,
db=None,
cur=None,
):
"""Search tasks from selected criterions"""
where = []
args = []
if task_id:
if isinstance(task_id, (str, int)):
where.append("id = %s")
else:
where.append("id in %s")
task_id = tuple(task_id)
args.append(task_id)
if task_type:
if isinstance(task_type, str):
where.append("type = %s")
else:
where.append("type in %s")
task_type = tuple(task_type)
args.append(task_type)
if status:
if isinstance(status, str):
where.append("status = %s")
else:
where.append("status in %s")
status = tuple(status)
args.append(status)
if priority:
if isinstance(priority, str):
where.append("priority = %s")
else:
priority = tuple(priority)
where.append("priority in %s")
args.append(priority)
if policy:
where.append("policy = %s")
args.append(policy)
if before:
where.append("next_run <= %s")
args.append(before)
if after:
where.append("next_run >= %s")
args.append(after)
query = "select * from task"
if where:
query += " where " + " and ".join(where)
if limit:
query += " limit %s :: bigint"
args.append(limit)
cur.execute(query, args)
return cur.fetchall()
@db_transaction()
def get_tasks(self, task_ids, db=None, cur=None):
"""Retrieve the info of tasks whose ids are listed."""
query = format_query("select {keys} from task where id in %s", self.task_keys)
cur.execute(query, (tuple(task_ids),))
return cur.fetchall()
@db_transaction()
def peek_ready_tasks(
self,
task_type: str,
timestamp: Optional[datetime.datetime] = None,
num_tasks: Optional[int] = None,
db=None,
cur=None,
) -> List[Dict]:
if timestamp is None:
timestamp = utcnow()
cur.execute(
"""select * from swh_scheduler_peek_no_priority_tasks(
%s, %s, %s :: bigint)""",
(task_type, timestamp, num_tasks),
)
logger.debug("PEEK %s => %s" % (task_type, cur.rowcount))
return cur.fetchall()
@db_transaction()
def grab_ready_tasks(
self,
task_type: str,
timestamp: Optional[datetime.datetime] = None,
num_tasks: Optional[int] = None,
db=None,
cur=None,
) -> List[Dict]:
if timestamp is None:
timestamp = utcnow()
cur.execute(
"""select * from swh_scheduler_grab_ready_tasks(
%s, %s, %s :: bigint)""",
(task_type, timestamp, num_tasks),
)
logger.debug("GRAB %s => %s" % (task_type, cur.rowcount))
return cur.fetchall()
@db_transaction()
def peek_ready_priority_tasks(
self,
task_type: str,
timestamp: Optional[datetime.datetime] = None,
num_tasks: Optional[int] = None,
db=None,
cur=None,
) -> List[Dict]:
if timestamp is None:
timestamp = utcnow()
cur.execute(
"""select * from swh_scheduler_peek_any_ready_priority_tasks(
%s, %s, %s :: bigint)""",
(task_type, timestamp, num_tasks),
)
logger.debug("PEEK %s => %s", task_type, cur.rowcount)
return cur.fetchall()
@db_transaction()
def grab_ready_priority_tasks(
self,
task_type: str,
timestamp: Optional[datetime.datetime] = None,
num_tasks: Optional[int] = None,
db=None,
cur=None,
) -> List[Dict]:
if timestamp is None:
timestamp = utcnow()
cur.execute(
"""select * from swh_scheduler_grab_any_ready_priority_tasks(
%s, %s, %s :: bigint)""",
(task_type, timestamp, num_tasks),
)
logger.debug("GRAB %s => %s", task_type, cur.rowcount)
return cur.fetchall()
task_run_create_keys = ["task", "backend_id", "scheduled", "metadata"]
@db_transaction()
def schedule_task_run(
self, task_id, backend_id, metadata=None, timestamp=None, db=None, cur=None
):
"""Mark a given task as scheduled, adding a task_run entry in the database.
Args:
task_id (int): the identifier for the task being scheduled
backend_id (str): the identifier of the job in the backend
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
a fresh task_run entry
"""
if metadata is None:
metadata = {}
if timestamp is None:
timestamp = utcnow()
cur.execute(
"select * from swh_scheduler_schedule_task_run(%s, %s, %s, %s)",
(task_id, backend_id, metadata, timestamp),
)
return cur.fetchone()
@db_transaction()
def mass_schedule_task_runs(self, task_runs, db=None, cur=None):
"""Schedule a bunch of task runs.
Args:
task_runs (list): a list of dicts with keys:
- task (int): the identifier for the task being scheduled
- backend_id (str): the identifier of the job in the backend
- metadata (dict): metadata to add to the task_run entry
- scheduled (datetime.datetime): the instant the event occurred
Returns:
None
"""
cur.execute("select swh_scheduler_mktemp_task_run()")
db.copy_to(task_runs, "tmp_task_run", self.task_run_create_keys, cur=cur)
cur.execute("select swh_scheduler_schedule_task_run_from_temp()")
@db_transaction()
def start_task_run(
self, backend_id, metadata=None, timestamp=None, db=None, cur=None
):
"""Mark a given task as started, updating the corresponding task_run
entry in the database.
Args:
backend_id (str): the identifier of the job in the backend
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
the updated task_run entry
"""
if metadata is None:
metadata = {}
if timestamp is None:
timestamp = utcnow()
cur.execute(
"select * from swh_scheduler_start_task_run(%s, %s, %s)",
(backend_id, metadata, timestamp),
)
return cur.fetchone()
@db_transaction()
def end_task_run(
self,
backend_id,
status,
metadata=None,
timestamp=None,
result=None,
db=None,
cur=None,
):
"""Mark a given task as ended, updating the corresponding task_run entry in the
database.
Args:
backend_id (str): the identifier of the job in the backend
status (str): how the task ended; one of: 'eventful', 'uneventful',
'failed'
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
the updated task_run entry
"""
if metadata is None:
metadata = {}
if timestamp is None:
timestamp = utcnow()
cur.execute(
"select * from swh_scheduler_end_task_run(%s, %s, %s, %s)",
(backend_id, status, metadata, timestamp),
)
return cur.fetchone()
@db_transaction()
def filter_task_to_archive(
self,
after_ts: str,
before_ts: str,
limit: int = 10,
page_token: Optional[str] = None,
db=None,
cur=None,
) -> Dict[str, Any]:
"""Compute the tasks to archive within the datetime interval
[after_ts, before_ts). The method returns a paginated result.
Returns:
dict with the following keys:
- **next_page_token**: opaque token to be used as
`page_token` to retrieve the next page of results. If absent,
there are no more pages to gather.
- **tasks**: list of task dictionaries with the following keys:
**id** (str): origin task id
**started** (Optional[datetime]): started date
**scheduled** (datetime): scheduled date
**arguments** (json dict): task's arguments
...
"""
assert not page_token or isinstance(page_token, str)
last_id = -1 if page_token is None else int(page_token)
tasks = []
cur.execute(
"select * from swh_scheduler_task_to_archive(%s, %s, %s, %s)",
(after_ts, before_ts, last_id, limit + 1),
)
for row in cur:
task = dict(row)
# the nested type index does not accept bare values;
# transform them into a dict to comply with this
task["arguments"]["args"] = {
i: v for i, v in enumerate(task["arguments"]["args"])
}
kwargs = task["arguments"]["kwargs"]
task["arguments"]["kwargs"] = json.dumps(kwargs)
tasks.append(task)
if len(tasks) >= limit + 1: # remains data, add pagination information
result = {
"tasks": tasks[:limit],
"next_page_token": str(tasks[-1]["task_id"]),
}
else:
result = {"tasks": tasks}
return result
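# An illustrative pagination loop, assuming a scheduler instance and ISO
# timestamp strings; archive_page() is hypothetical.
page_token = None
while True:
    result = scheduler.filter_task_to_archive(
        after_ts, before_ts, limit=100, page_token=page_token
    )
    archive_page(result["tasks"])
    page_token = result.get("next_page_token")
    if page_token is None:
        break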
@db_transaction()
def delete_archived_tasks(self, task_ids, db=None, cur=None):
"""Delete archived tasks as much as possible. Only the task_ids whose
- complete associated task_run have been cleaned up will be.
+ associated task_run entries have all been cleaned up will be deleted.
"""
# use two distinct lists; a chained `a = b = []` would alias them
_task_ids = []
_task_run_ids = []
for task_id in task_ids:
_task_ids.append(task_id["task_id"])
_task_run_ids.append(task_id["task_run_id"])
cur.execute(
"select * from swh_scheduler_delete_archived_tasks(%s, %s)",
(_task_ids, _task_run_ids),
)
task_run_keys = [
"id",
"task",
"backend_id",
"scheduled",
"started",
"ended",
"metadata",
"status",
]
@db_transaction()
def get_task_runs(self, task_ids, limit=None, db=None, cur=None):
"""Search task run for a task id"""
where = []
args = []
if task_ids:
if isinstance(task_ids, (str, int)):
where.append("task = %s")
else:
where.append("task in %s")
task_ids = tuple(task_ids)
args.append(task_ids)
else:
return ()
query = "select * from task_run where " + " and ".join(where)
if limit:
query += " limit %s :: bigint"
args.append(limit)
cur.execute(query, args)
return cur.fetchall()
@db_transaction()
def origin_visit_stats_upsert(
self, origin_visit_stats: Iterable[OriginVisitStats], db=None, cur=None
) -> None:
pk_cols = OriginVisitStats.primary_key_columns()
insert_cols, insert_meta = OriginVisitStats.insert_columns_and_metavars()
upsert_cols = [col for col in insert_cols if col not in pk_cols]
upsert_set = ", ".join(
f"{col} = coalesce(EXCLUDED.{col}, ovi.{col})" for col in upsert_cols
)
query = f"""
INSERT into origin_visit_stats AS ovi ({", ".join(insert_cols)})
VALUES %s
ON CONFLICT ({", ".join(pk_cols)}) DO UPDATE
SET {upsert_set}
"""
try:
psycopg2.extras.execute_values(
cur=cur,
sql=query,
argslist=(
attr.asdict(visit_stats) for visit_stats in origin_visit_stats
),
template=f"({', '.join(insert_meta)})",
page_size=1000,
fetch=False,
)
except CardinalityViolation as e:
raise SchedulerException(repr(e))
@db_transaction()
def origin_visit_stats_get(
self, ids: Iterable[Tuple[str, str]], db=None, cur=None
) -> List[OriginVisitStats]:
if not ids:
return []
primary_keys = tuple((origin, visit_type) for (origin, visit_type) in ids)
query = format_query(
"""
SELECT {keys}
FROM (VALUES %s) as stats(url, visit_type)
INNER JOIN origin_visit_stats USING (url, visit_type)
""",
OriginVisitStats.select_columns(),
)
rows = psycopg2.extras.execute_values(
cur=cur, sql=query, argslist=primary_keys, fetch=True
)
return [OriginVisitStats(**row) for row in rows]
@db_transaction()
def visit_scheduler_queue_position_get(self, db=None, cur=None) -> Dict[str, int]:
cur.execute("SELECT visit_type, position FROM visit_scheduler_queue_position")
return {row["visit_type"]: row["position"] for row in cur}
@db_transaction()
def visit_scheduler_queue_position_set(
- self, visit_type: str, position: int, db=None, cur=None,
+ self,
+ visit_type: str,
+ position: int,
+ db=None,
+ cur=None,
) -> None:
query = """
INSERT INTO visit_scheduler_queue_position(visit_type, position)
VALUES(%s, %s)
ON CONFLICT(visit_type) DO UPDATE SET position=EXCLUDED.position
"""
cur.execute(query, (visit_type, position))
@db_transaction()
def update_metrics(
self,
lister_id: Optional[UUID] = None,
timestamp: Optional[datetime.datetime] = None,
db=None,
cur=None,
) -> List[SchedulerMetrics]:
"""Update the performance metrics of this scheduler instance.
Returns the updated metrics.
Args:
lister_id: if passed, update the metrics only for this lister instance
timestamp: if passed, the date at which we're updating the metrics,
defaults to the database NOW()
"""
query = format_query(
"SELECT {keys} FROM update_metrics(%s, %s)",
SchedulerMetrics.select_columns(),
)
cur.execute(query, (lister_id, timestamp))
return [SchedulerMetrics(**row) for row in cur.fetchall()]
@db_transaction()
def get_metrics(
self,
lister_id: Optional[UUID] = None,
visit_type: Optional[str] = None,
db=None,
cur=None,
) -> List[SchedulerMetrics]:
"""Retrieve the performance metrics of this scheduler instance.
Args:
lister_id: filter the metrics for this lister instance only
visit_type: filter the metrics for this visit type only
"""
where_filters = []
where_args = []
if lister_id:
where_filters.append("lister_id = %s")
where_args.append(str(lister_id))
if visit_type:
where_filters.append("visit_type = %s")
where_args.append(visit_type)
where_clause = ""
if where_filters:
where_clause = f"where {' and '.join(where_filters)}"
query = format_query(
"SELECT {keys} FROM scheduler_metrics %s" % where_clause,
SchedulerMetrics.select_columns(),
)
cur.execute(query, tuple(where_args))
return [SchedulerMetrics(**row) for row in cur.fetchall()]
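# An illustrative sketch, assuming a scheduler instance: refresh, then read,
# the metrics for a single lister instance (the lister name is an example).
lister = scheduler.get_or_create_lister(name="github")
scheduler.update_metrics(lister_id=lister.id)
metrics = scheduler.get_metrics(lister_id=lister.id, visit_type="git")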
diff --git a/swh/scheduler/celery_backend/pika_listener.py b/swh/scheduler/celery_backend/pika_listener.py
index d895d5d..a5c48f2 100644
--- a/swh/scheduler/celery_backend/pika_listener.py
+++ b/swh/scheduler/celery_backend/pika_listener.py
@@ -1,107 +1,110 @@
# Copyright (C) 2020-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
"""This is the scheduler listener. It is in charge of listening to rabbitmq events (the
task result) and flushes the "oneshot" tasks' status in the scheduler backend. It's the
final step after a task is done.
The scheduler runner :mod:`swh.scheduler.celery_backend.runner` is the module in charge
of pushing tasks in the queue.
"""
import json
import logging
import sys
import pika
from swh.core.statsd import statsd
from swh.scheduler import get_scheduler
from swh.scheduler.utils import utcnow
logger = logging.getLogger(__name__)
def get_listener(broker_url, queue_name, scheduler_backend):
connection = pika.BlockingConnection(pika.URLParameters(broker_url))
channel = connection.channel()
channel.queue_declare(queue=queue_name, durable=True)
exchange = "celeryev"
routing_key = "#"
channel.queue_bind(queue=queue_name, exchange=exchange, routing_key=routing_key)
channel.basic_qos(prefetch_count=1000)
channel.basic_consume(
- queue=queue_name, on_message_callback=get_on_message(scheduler_backend),
+ queue=queue_name,
+ on_message_callback=get_on_message(scheduler_backend),
)
return channel
def get_on_message(scheduler_backend):
def on_message(channel, method_frame, properties, body):
try:
events = json.loads(body)
except Exception:
logger.warning("Could not parse body %r", body)
events = []
if not isinstance(events, list):
events = [events]
for event in events:
logger.debug("Received event %r", event)
process_event(event, scheduler_backend)
channel.basic_ack(delivery_tag=method_frame.delivery_tag)
return on_message
def process_event(event, scheduler_backend):
uuid = event.get("uuid")
if not uuid:
return
event_type = event["type"]
statsd.increment(
"swh_scheduler_listener_handled_event_total", tags={"event_type": event_type}
)
if event_type == "task-started":
scheduler_backend.start_task_run(
- uuid, timestamp=utcnow(), metadata={"worker": event.get("hostname")},
+ uuid,
+ timestamp=utcnow(),
+ metadata={"worker": event.get("hostname")},
)
elif event_type == "task-result":
result = event["result"]
status = None
if isinstance(result, dict) and "status" in result:
status = result["status"]
if status == "success":
status = "eventful" if result.get("eventful") else "uneventful"
if status is None:
status = "eventful" if result else "uneventful"
scheduler_backend.end_task_run(
uuid, timestamp=utcnow(), status=status, result=result
)
elif event_type == "task-failed":
scheduler_backend.end_task_run(uuid, timestamp=utcnow(), status="failed")
if __name__ == "__main__":
url = sys.argv[1]
logging.basicConfig(level=logging.DEBUG)
scheduler_backend = get_scheduler("local", args={"db": "service=swh-scheduler"})
channel = get_listener(url, "celeryev.test", scheduler_backend)
logger.info("Start consuming")
channel.start_consuming()
diff --git a/swh/scheduler/celery_backend/recurrent_visits.py b/swh/scheduler/celery_backend/recurrent_visits.py
index 2a9ebe6..d58a569 100644
--- a/swh/scheduler/celery_backend/recurrent_visits.py
+++ b/swh/scheduler/celery_backend/recurrent_visits.py
@@ -1,323 +1,329 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
"""This schedules the recurrent visits, for listed origins, in Celery.
For "oneshot" (save code now, lister) tasks, check the
:mod:`swh.scheduler.celery_backend.runner` and
:mod:`swh.scheduler.celery_backend.pika_listener` modules.
"""
from __future__ import annotations
from itertools import chain
import logging
from queue import Empty, Queue
import random
from threading import Thread
import time
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from kombu.utils.uuid import uuid
from swh.scheduler.celery_backend.config import get_available_slots
if TYPE_CHECKING:
from ..interface import SchedulerInterface
from ..model import ListedOrigin
logger = logging.getLogger(__name__)
_VCS_POLICY_WEIGHTS: Dict[str, float] = {
"already_visited_order_by_lag": 49,
"never_visited_oldest_update_first": 49,
"origins_without_last_update": 2,
}
POLICY_WEIGHTS: Dict[str, Dict[str, float]] = {
"default": {
"already_visited_order_by_lag": 50,
"never_visited_oldest_update_first": 50,
},
"git": _VCS_POLICY_WEIGHTS,
"hg": _VCS_POLICY_WEIGHTS,
"svn": _VCS_POLICY_WEIGHTS,
"cvs": _VCS_POLICY_WEIGHTS,
"bzr": _VCS_POLICY_WEIGHTS,
}
POLICY_ADDITIONAL_PARAMETERS: Dict[str, Dict[str, Any]] = {
"git": {
"already_visited_order_by_lag": {"tablesample": 0.1},
"never_visited_oldest_update_first": {"tablesample": 0.1},
"origins_without_last_update": {"tablesample": 0.1},
}
}
"""Scheduling policies to use to retrieve visits for the given visit types, with their
relative weights"""
MIN_SLOTS_RATIO = 0.05
"""Quantity of slots that need to be available (with respect to max_queue_length) for
:py:func:`~swh.scheduler.interface.SchedulerInterface.grab_next_visits` to trigger"""
QUEUE_FULL_BACKOFF = 60
"""Backoff time (in seconds) if there's fewer than :py:data:`MIN_SLOTS_RATIO` slots
available in the queue."""
NO_ORIGINS_SCHEDULED_BACKOFF = 20 * 60
"""Backoff time (in seconds) if no origins have been scheduled in the current
iteration"""
BACKOFF_SPLAY = 5.0
"""Amplitude of the fuzziness between backoffs"""
TERMINATE = object()
"""Termination request received from command queue (singleton used for identity
comparison)"""
def grab_next_visits_policy_weights(
scheduler: SchedulerInterface, visit_type: str, num_visits: int
) -> List[ListedOrigin]:
"""Get the next ``num_visits`` for the given ``visit_type`` using the corresponding
set of scheduling policies.
The :py:data:`POLICY_WEIGHTS` dict sets, for each visit type, the scheduling
policies used to pull the next tasks, and what proportion of the available
num_visits they take.
This function logs a message if the ratio of retrieved origins deviates
from the requested ratio by more than 5%.
Returns:
at most ``num_visits`` :py:class:`~swh.scheduler.model.ListedOrigin` objects
"""
policy_weights = POLICY_WEIGHTS.get(visit_type, POLICY_WEIGHTS["default"])
total_weight = sum(policy_weights.values())
if not total_weight:
raise ValueError(f"No policy weights set for visit type {visit_type}")
policy_ratio = {
policy: weight / total_weight for policy, weight in policy_weights.items()
}
fetched_origins: Dict[str, List[ListedOrigin]] = {}
for policy, ratio in policy_ratio.items():
num_tasks_to_send = int(num_visits * ratio)
fetched_origins[policy] = scheduler.grab_next_visits(
visit_type,
num_tasks_to_send,
policy=policy,
**POLICY_ADDITIONAL_PARAMETERS.get(visit_type, {}).get(policy, {}),
)
all_origins: List[ListedOrigin] = list(
chain.from_iterable(fetched_origins.values())
)
if not all_origins:
return []
# Check whether the ratios of origins fetched are skewed with respect to the
# ones we requested
fetched_origin_ratios = {
policy: len(origins) / len(all_origins)
for policy, origins in fetched_origins.items()
}
for policy, expected_ratio in policy_ratio.items():
# 5% of skew with respect to request
if abs(fetched_origin_ratios[policy] - expected_ratio) / expected_ratio > 0.05:
logger.info(
"Skewed fetch for visit type %s with policy %s: fetched %s, "
"requested %s",
visit_type,
policy,
fetched_origin_ratios[policy],
expected_ratio,
)
return all_origins
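# Worked example: for "git", the 49/49/2 weights above normalize to ratios
# 0.49/0.49/0.02, so a call with num_visits=100 requests 49, 49 and 2 origins
# from the three policies respectively.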
def splay():
"""Return a random short interval by which to vary the backoffs for the visit
scheduling threads"""
return random.uniform(0, BACKOFF_SPLAY)
def send_visits_for_visit_type(
- scheduler: SchedulerInterface, app, visit_type: str, task_type: Dict,
+ scheduler: SchedulerInterface,
+ app,
+ visit_type: str,
+ task_type: Dict,
) -> float:
"""Schedule the next batch of visits for the given ``visit_type``.
First, we determine the number of available slots by introspecting the RabbitMQ
queue.
If there's fewer than :py:data:`MIN_SLOTS_RATIO` slots available in the queue, we
wait for :py:data:`QUEUE_FULL_BACKOFF` seconds. This avoids running the expensive
:py:func:`~swh.scheduler.interface.SchedulerInterface.grab_next_visits` queries when
there aren't many jobs to queue.
Once there's more than :py:data:`MIN_SLOTS_RATIO` slots available, we run
:py:func:`grab_next_visits_policy_weights` to retrieve the next set of origin visits
to schedule, and we send them to celery.
If the last scheduling attempt didn't return any origins, we sleep for
:py:data:`NO_ORIGINS_SCHEDULED_BACKOFF` seconds. This avoids running the expensive
:py:func:`~swh.scheduler.interface.SchedulerInterface.grab_next_visits` queries too
often if there's nothing left to schedule.
Returns:
the earliest :py:func:`time.monotonic` value at which to run the next iteration
of the loop.
"""
queue_name = task_type["backend_name"]
max_queue_length = task_type.get("max_queue_length") or 0
min_available_slots = max_queue_length * MIN_SLOTS_RATIO
current_iteration_start = time.monotonic()
# Check queue level
available_slots = get_available_slots(app, queue_name, max_queue_length)
logger.debug(
"%s available slots for visit type %s in queue %s",
available_slots,
visit_type,
queue_name,
)
if available_slots < min_available_slots:
return current_iteration_start + QUEUE_FULL_BACKOFF
origins = grab_next_visits_policy_weights(scheduler, visit_type, available_slots)
if not origins:
logger.debug("No origins to visit for type %s", visit_type)
return current_iteration_start + NO_ORIGINS_SCHEDULED_BACKOFF
# Try to smooth the ingestion load, origins pulled by different
# scheduling policies have different resource usage patterns
random.shuffle(origins)
for origin in origins:
task_dict = origin.as_task_dict()
app.send_task(
queue_name,
task_id=uuid(),
args=task_dict["arguments"]["args"],
kwargs=task_dict["arguments"]["kwargs"],
queue=queue_name,
)
logger.info(
- "%s: %s visits scheduled in queue %s", visit_type, len(origins), queue_name,
+ "%s: %s visits scheduled in queue %s",
+ visit_type,
+ len(origins),
+ queue_name,
)
# When everything worked, we can try to schedule origins again ASAP.
return time.monotonic()
def visit_scheduler_thread(
config: Dict,
visit_type: str,
command_queue: Queue[object],
exc_queue: Queue[Tuple[str, BaseException]],
):
"""Target function for the visit sending thread, which initializes local connections
and handles exceptions by sending them back to the main thread."""
from swh.scheduler import get_scheduler
from swh.scheduler.celery_backend.config import build_app
try:
# We need to reinitialize these connections because they're not generally
# thread-safe
app = build_app(config.get("celery"))
scheduler = get_scheduler(**config["scheduler"])
task_type = scheduler.get_task_type(f"load-{visit_type}")
if task_type is None:
raise ValueError(f"Unknown task type: load-{visit_type}")
next_iteration = time.monotonic()
while True:
# vary the next iteration time a little bit
next_iteration = next_iteration + splay()
while time.monotonic() < next_iteration:
# Wait for next iteration to start. Listen for termination message.
try:
msg = command_queue.get(block=True, timeout=1)
except Empty:
continue
if msg is TERMINATE:
return
else:
logger.warn("Received unexpected message %s in command queue", msg)
next_iteration = send_visits_for_visit_type(
scheduler, app, visit_type, task_type
)
except BaseException as e:
exc_queue.put((visit_type, e))
VisitSchedulerThreads = Dict[str, Tuple[Thread, Queue]]
"""Dict storing the visit scheduler threads and their command queues"""
def spawn_visit_scheduler_thread(
threads: VisitSchedulerThreads,
exc_queue: Queue[Tuple[str, BaseException]],
config: Dict[str, Any],
visit_type: str,
):
"""Spawn a new thread to schedule the visits of type ``visit_type``."""
command_queue: Queue[object] = Queue()
thread = Thread(
target=visit_scheduler_thread,
kwargs={
"config": config,
"visit_type": visit_type,
"command_queue": command_queue,
"exc_queue": exc_queue,
},
)
threads[visit_type] = (thread, command_queue)
thread.start()
def terminate_visit_scheduler_threads(threads: VisitSchedulerThreads) -> List[str]:
"""Terminate all visit scheduler threads"""
logger.info("Termination requested...")
for _, command_queue in threads.values():
command_queue.put(TERMINATE)
loops = 0
while threads and loops < 10:
logger.info(
"Terminating visit scheduling threads: %s", ", ".join(sorted(threads))
)
loops += 1
for visit_type, (thread, _) in list(threads.items()):
thread.join(timeout=1)
if not thread.is_alive():
logger.debug("Thread %s terminated", visit_type)
del threads[visit_type]
if threads:
logger.warning(
"Could not reap the following threads after 10 attempts: %s",
", ".join(sorted(threads)),
)
return list(sorted(threads))
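# An illustrative wiring of the helpers above, assuming `config` carries
# "scheduler" and "celery" sections; the visit types are examples.
from queue import Queue

exc_queue: Queue = Queue()
threads: VisitSchedulerThreads = {}
for visit_type in ("git", "hg"):
    spawn_visit_scheduler_thread(threads, exc_queue, config, visit_type)
# ... run until shutdown is requested ...
terminate_visit_scheduler_threads(threads)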
diff --git a/swh/scheduler/celery_backend/runner.py b/swh/scheduler/celery_backend/runner.py
index 53b2e25..be59acf 100644
--- a/swh/scheduler/celery_backend/runner.py
+++ b/swh/scheduler/celery_backend/runner.py
@@ -1,180 +1,184 @@
# Copyright (C) 2015-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
"""This is the first scheduler runner. It is in charge of scheduling "oneshot" tasks
(e.g. save code now, indexer, vault, deposit, ...). To do this, it reads tasks out of the
scheduler backend and pushes those to their associated rabbitmq queues.
The scheduler listener :mod:`swh.scheduler.celery_backend.pika_listener` is the module
in charge of finalizing the task results.
"""
import logging
from typing import Dict, List, Tuple
from deprecated import deprecated
from kombu.utils.uuid import uuid
from swh.core.statsd import statsd
from swh.scheduler import get_scheduler
from swh.scheduler.celery_backend.config import get_available_slots
from swh.scheduler.interface import SchedulerInterface
from swh.scheduler.utils import utcnow
logger = logging.getLogger(__name__)
# Max batch size for tasks
MAX_NUM_TASKS = 10000
def run_ready_tasks(
backend: SchedulerInterface,
app,
task_types: List[Dict] = [],
with_priority: bool = False,
) -> List[Dict]:
"""Schedule tasks ready to be scheduled.
This looks up tasks per task type and mass-schedules them accordingly (sends
messages to rabbitmq and marks the corresponding tasks as scheduled in the
scheduler backend).
If tasks (per task type) with priority exist, they will get redirected to a
dedicated high-priority queue (standard queue name prefixed with `save_code_now:`).
Args:
backend: scheduler backend to interact with (read/update tasks)
app (App): Celery application to send tasks to
task_types: The list of task type dicts to iterate over. By default, empty.
When empty, the full list of task types referenced in the scheduler will be
used.
with_priority: If True, only tasks with priority set will be fetched and
scheduled. By default, False.
Returns:
A list of dictionaries::
{
'task': the scheduler's task id,
'backend_id': Celery's task id,
'scheduled': utcnow()
}
The result can be used to block-wait for the tasks' results::
backend_tasks = run_ready_tasks(self.scheduler, app)
for task in backend_tasks:
AsyncResult(id=task['backend_id']).get()
"""
all_backend_tasks: List[Dict] = []
while True:
if not task_types:
task_types = backend.get_task_types()
task_types_d = {}
pending_tasks = []
for task_type in task_types:
task_type_name = task_type["type"]
task_types_d[task_type_name] = task_type
max_queue_length = task_type["max_queue_length"]
if max_queue_length is None:
max_queue_length = 0
backend_name = task_type["backend_name"]
if with_priority:
# grab max_queue_length (or 10) potential tasks with any priority for
# the same type (limit the result to avoid too long running queries)
grabbed_priority_tasks = backend.grab_ready_priority_tasks(
task_type_name, num_tasks=max_queue_length or 10
)
if grabbed_priority_tasks:
pending_tasks.extend(grabbed_priority_tasks)
logger.info(
"Grabbed %s tasks %s (priority)",
len(grabbed_priority_tasks),
task_type_name,
)
statsd.increment(
"swh_scheduler_runner_scheduled_task_total",
len(grabbed_priority_tasks),
tags={"task_type": task_type_name},
)
else:
num_tasks = get_available_slots(app, backend_name, max_queue_length)
# only pull tasks if the buffer is at least 1/5th empty (= 80%
# full), to help postgresql use properly indexed queries.
if num_tasks > min(MAX_NUM_TASKS, max_queue_length) // 5:
# Only grab num_tasks tasks with no priority
grabbed_tasks = backend.grab_ready_tasks(
task_type_name, num_tasks=num_tasks
)
if grabbed_tasks:
pending_tasks.extend(grabbed_tasks)
logger.info(
"Grabbed %s tasks %s", len(grabbed_tasks), task_type_name
)
statsd.increment(
"swh_scheduler_runner_scheduled_task_total",
len(grabbed_tasks),
tags={"task_type": task_type_name},
)
if not pending_tasks:
return all_backend_tasks
backend_tasks = []
celery_tasks: List[Tuple[bool, str, str, List, Dict]] = []
for task in pending_tasks:
args = task["arguments"]["args"]
kwargs = task["arguments"]["kwargs"]
backend_name = task_types_d[task["type"]]["backend_name"]
backend_id = uuid()
celery_tasks.append(
(
task.get("priority") is not None,
backend_name,
backend_id,
args,
kwargs,
)
)
data = {
"task": task["id"],
"backend_id": backend_id,
"scheduled": utcnow(),
}
backend_tasks.append(data)
logger.debug("Sent %s celery tasks", len(backend_tasks))
backend.mass_schedule_task_runs(backend_tasks)
for with_priority, backend_name, backend_id, args, kwargs in celery_tasks:
- kw = dict(task_id=backend_id, args=args, kwargs=kwargs,)
+ kw = dict(
+ task_id=backend_id,
+ args=args,
+ kwargs=kwargs,
+ )
if with_priority:
kw["queue"] = f"save_code_now:{backend_name}"
app.send_task(backend_name, **kw)
all_backend_tasks.extend(backend_tasks)
@deprecated(version="0.18", reason="Use `swh scheduler start-runner` instead")
def main():
from .config import app as main_app
for module in main_app.conf.CELERY_IMPORTS:
__import__(module)
main_backend = get_scheduler("local")
try:
run_ready_tasks(main_backend, main_app)
except Exception:
main_backend.rollback()
raise
if __name__ == "__main__":
main()
diff --git a/swh/scheduler/cli/__init__.py b/swh/scheduler/cli/__init__.py
index 9a2ab19..cda7259 100644
--- a/swh/scheduler/cli/__init__.py
+++ b/swh/scheduler/cli/__init__.py
@@ -1,99 +1,102 @@
# Copyright (C) 2016-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
# WARNING: do not import unnecessary things here to keep cli startup time under
# control
import logging
import click
from swh.core.cli import CONTEXT_SETTINGS, AliasedGroup
from swh.core.cli import swh as swh_cli_group
# If you're looking for subcommand imports, they are further down this file to
# avoid a circular import!
@swh_cli_group.group(
name="scheduler", context_settings=CONTEXT_SETTINGS, cls=AliasedGroup
)
@click.option(
"--config-file",
"-C",
default=None,
- type=click.Path(exists=True, dir_okay=False,),
+ type=click.Path(
+ exists=True,
+ dir_okay=False,
+ ),
help="Configuration file.",
)
@click.option(
"--database",
"-d",
default=None,
help="Scheduling database DSN (imply cls is 'local')",
)
@click.option(
"--url", "-u", default=None, help="Scheduler's url access (imply cls is 'remote')"
)
@click.option(
"--no-stdout", is_flag=True, default=False, help="Do NOT output logs on the console"
)
@click.pass_context
def cli(ctx, config_file, database, url, no_stdout):
"""Software Heritage Scheduler tools.
Use a local scheduler instance by default (plugged to the
main scheduler db).
"""
try:
from psycopg2 import OperationalError
except ImportError:
class OperationalError(Exception):
pass
from swh.core import config
from swh.scheduler import DEFAULT_CONFIG, get_scheduler
ctx.ensure_object(dict)
logger = logging.getLogger(__name__)
scheduler = None
conf = config.read(config_file, DEFAULT_CONFIG)
if "scheduler" not in conf:
raise ValueError("missing 'scheduler' configuration")
if database:
conf["scheduler"]["cls"] = "local"
conf["scheduler"]["db"] = database
elif url:
conf["scheduler"]["cls"] = "remote"
conf["scheduler"]["url"] = url
sched_conf = conf["scheduler"]
try:
logger.debug("Instantiating scheduler with %s", sched_conf)
scheduler = get_scheduler(**sched_conf)
except (ValueError, OperationalError):
# it's the subcommand to decide whether not having a proper
# scheduler instance is a problem.
pass
ctx.obj["scheduler"] = scheduler
ctx.obj["config"] = conf
from . import admin, celery_monitor, journal, origin, simulator, task, task_type # noqa
def main():
import click.core
click.core.DEPRECATED_HELP_NOTICE = """
DEPRECATED! Please use the command 'swh scheduler'."""
cli.deprecated = True
return cli(auto_envvar_prefix="SWH_SCHEDULER")
if __name__ == "__main__":
main()
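# Illustrative invocations of the group defined above; paths and DSNs are
# placeholders, and the subcommands shown are defined in this package:
#
#   swh scheduler --config-file /etc/softwareheritage/scheduler.yml start-listener
#   swh scheduler --database "dbname=softwareheritage-scheduler-dev" \
#       start-runner --period 10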
diff --git a/swh/scheduler/cli/admin.py b/swh/scheduler/cli/admin.py
index 861651e..73daab3 100644
--- a/swh/scheduler/cli/admin.py
+++ b/swh/scheduler/cli/admin.py
@@ -1,226 +1,225 @@
# Copyright (C) 2016-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from __future__ import annotations
# WARNING: do not import unnecessary things here to keep cli startup time under
# control
import logging
import time
from typing import List, Tuple
import click
from . import cli
@cli.command("start-runner")
@click.option(
"--period",
"-p",
default=0,
help=(
"Period (in s) at witch pending tasks are checked and "
"executed. Set to 0 (default) for a one shot."
),
)
@click.option(
"--task-type",
"task_type_names",
multiple=True,
default=[],
help=(
"Task types to schedule. If not provided, this iterates over every "
"task types referenced in the scheduler backend."
),
)
@click.option(
"--with-priority/--without-priority",
is_flag=True,
default=False,
help=(
"Determine if those tasks should be the ones with priority or not."
"By default, this deals with tasks without any priority."
),
)
@click.pass_context
def runner(ctx, period, task_type_names, with_priority):
"""Starts a swh-scheduler runner service.
This process is responsible for checking for ready-to-run tasks and
schedule them."""
from swh.scheduler.celery_backend.config import build_app
from swh.scheduler.celery_backend.runner import run_ready_tasks
config = ctx.obj["config"]
app = build_app(config.get("celery"))
app.set_current()
logger = logging.getLogger(__name__ + ".runner")
scheduler = ctx.obj["scheduler"]
logger.debug("Scheduler %s", scheduler)
task_types = []
for task_type_name in task_type_names:
task_type = scheduler.get_task_type(task_type_name)
if not task_type:
raise ValueError(f"Unknown {task_type_name}")
task_types.append(task_type)
try:
while True:
logger.debug("Run ready tasks")
try:
ntasks = len(run_ready_tasks(scheduler, app, task_types, with_priority))
if ntasks:
logger.info("Scheduled %s tasks", ntasks)
except Exception:
logger.exception("Unexpected error in run_ready_tasks()")
if not period:
break
time.sleep(period)
except KeyboardInterrupt:
ctx.exit(0)
@cli.command("start-listener")
@click.pass_context
def listener(ctx):
"""Starts a swh-scheduler listener service.
This service is responsible for listening for task lifecycle events and
handling their workflow status in the database."""
scheduler_backend = ctx.obj["scheduler"]
if not scheduler_backend:
raise ValueError("Scheduler class (local/remote) must be instantiated")
broker = (
ctx.obj["config"]
.get("celery", {})
.get("task_broker", "amqp://guest@localhost/%2f")
)
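# %2f is the url-encoded form of "/", RabbitMQ's default virtual host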
from swh.scheduler.celery_backend.pika_listener import get_listener
listener = get_listener(broker, "celeryev.listener", scheduler_backend)
try:
listener.start_consuming()
finally:
listener.stop_consuming()
@cli.command("schedule-recurrent")
@click.option(
"--visit-type",
"visit_types",
multiple=True,
default=[],
help=(
"Visit types to schedule. If not provided, this iterates over every "
"corresponding load task types referenced in the scheduler backend."
),
)
@click.pass_context
def schedule_recurrent(ctx, visit_types: List[str]):
"""Starts the scheduler for recurrent visits.
This runs one thread for each visit type, which regularly sends new visits
to celery.
"""
from queue import Queue
from swh.scheduler.celery_backend.recurrent_visits import (
VisitSchedulerThreads,
logger,
spawn_visit_scheduler_thread,
terminate_visit_scheduler_threads,
)
config = ctx.obj["config"]
scheduler = ctx.obj["scheduler"]
if not visit_types:
visit_types = []
# Figure out which visit types exist in the scheduler
all_task_types = scheduler.get_task_types()
for task_type in all_task_types:
if not task_type["type"].startswith("load-"):
# only consider loading tasks as recurring ones, the rest is dismissed
continue
# strip the "load-" prefix to get the visit type name
visit_types.append(task_type["type"][5:])
else:
# Check that the passed visit types exist in the scheduler
for visit_type in visit_types:
task_type_name = f"load-{visit_type}"
task_type = scheduler.get_task_type(task_type_name)
if not task_type:
raise ValueError(f"Unknown task type: {task_type_name}")
exc_queue: Queue[Tuple[str, BaseException]] = Queue()
threads: VisitSchedulerThreads = {}
try:
# Spawn initial threads
for visit_type in visit_types:
spawn_visit_scheduler_thread(threads, exc_queue, config, visit_type)
# Handle exceptions from child threads
while True:
visit_type, exc_info = exc_queue.get(block=True)
logger.exception(
"Thread %s died with exception; respawning",
visit_type,
exc_info=exc_info,
)
dead_thread = threads[visit_type][0]
dead_thread.join(timeout=1)
if dead_thread.is_alive():
logger.warning(
"The thread for %s is still alive after sending an exception?! "
"Respawning anyway.",
visit_type,
)
spawn_visit_scheduler_thread(threads, exc_queue, config, visit_type)
except SystemExit:
remaining_threads = terminate_visit_scheduler_threads(threads)
if remaining_threads:
ctx.exit(1)
ctx.exit(0)
@cli.command("rpc-serve")
@click.option("--host", default="0.0.0.0", help="Host to run the scheduler server api")
@click.option("--port", default=5008, type=click.INT, help="Binding port of the server")
@click.option(
"--debug/--nodebug",
default=None,
help=(
"Indicates if the server should run in debug mode. "
"Defaults to True if log-level is DEBUG, False otherwise."
),
)
@click.pass_context
def rpc_server(ctx, host, port, debug):
- """Starts a swh-scheduler API HTTP server.
- """
+ """Starts a swh-scheduler API HTTP server."""
if ctx.obj["config"]["scheduler"]["cls"] == "remote":
click.echo(
"The API server can only be started with a 'local' " "configuration",
err=True,
)
ctx.exit(1)
from swh.scheduler.api import server
server.app.config.update(ctx.obj["config"])
if debug is None:
debug = ctx.obj["log_level"] <= logging.DEBUG
server.app.run(host, port=port, debug=bool(debug))
diff --git a/swh/scheduler/cli/journal.py b/swh/scheduler/cli/journal.py
index 9551164..58579ae 100644
--- a/swh/scheduler/cli/journal.py
+++ b/swh/scheduler/cli/journal.py
@@ -1,59 +1,61 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import click
from . import cli as cli_scheduler_group
@cli_scheduler_group.command("journal-client")
@click.pass_context
@click.option(
"--stop-after-objects",
"-m",
default=None,
type=int,
help="Maximum number of objects to replay. Default is to run forever.",
)
def visit_stats_journal_client(ctx, stop_after_objects):
- """Keep the the origin visits stats table up to date from a swh kafka journal
- """
+ """Keep the the origin visits stats table up to date from a swh kafka journal"""
from functools import partial
from swh.journal.client import get_journal_client
from swh.scheduler.journal_client import process_journal_objects
if not ctx.obj["scheduler"]:
raise ValueError("Scheduler class (local/remote) must be instantiated")
scheduler = ctx.obj["scheduler"]
config = ctx.obj["config"]
if "journal" not in config:
raise ValueError("Missing 'journal' configuration key")
journal_cfg = config["journal"]
journal_cfg["stop_after_objects"] = stop_after_objects or journal_cfg.get(
"stop_after_objects"
)
client = get_journal_client(
cls="kafka",
object_types=["origin_visit_status"],
prefix="swh.journal.objects",
**journal_cfg,
)
- worker_fn = partial(process_journal_objects, scheduler=scheduler,)
+ worker_fn = partial(
+ process_journal_objects,
+ scheduler=scheduler,
+ )
nb_messages = 0
try:
nb_messages = client.process(worker_fn)
print(f"Processed {nb_messages} message(s).")
except KeyboardInterrupt:
ctx.exit(0)
else:
print("Done.")
finally:
client.close()
diff --git a/swh/scheduler/cli/origin.py b/swh/scheduler/cli/origin.py
index 5f5d5d6..987ecc3 100644
--- a/swh/scheduler/cli/origin.py
+++ b/swh/scheduler/cli/origin.py
@@ -1,253 +1,258 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, List, Optional
import click
from . import cli
if TYPE_CHECKING:
from uuid import UUID
from ..interface import SchedulerInterface
from ..model import ListedOrigin
@cli.group("origin")
@click.pass_context
def origin(ctx):
"""Manipulate listed origins."""
if not ctx.obj["scheduler"]:
raise ValueError("Scheduler class (local/remote) must be instantiated")
def format_origins(
origins: List[ListedOrigin],
fields: Optional[List[str]] = None,
with_header: bool = True,
) -> Iterable[str]:
"""Format a list of origins as CSV.
Arguments:
origins: list of origins to output
fields: optional list of fields to output (defaults to all fields)
with_header: if True, output a CSV header.
"""
import csv
from io import StringIO
import attr
from ..model import ListedOrigin
expected_fields = [field.name for field in attr.fields(ListedOrigin)]
if not fields:
fields = expected_fields
unknown_fields = set(fields) - set(expected_fields)
if unknown_fields:
raise ValueError(
"Unknown ListedOrigin field(s): %s" % ", ".join(unknown_fields)
)
output = StringIO()
writer = csv.writer(output)
def csv_row(data):
"""Return a single CSV-formatted row. We clear the output buffer after we're
done to keep it reasonably sized."""
writer.writerow(data)
output.seek(0)
ret = output.read().rstrip()
output.seek(0)
output.truncate()
return ret
if with_header:
yield csv_row(fields)
for origin in origins:
yield csv_row(str(getattr(origin, field)) for field in fields)
@origin.command("grab-next")
@click.option(
"--policy", "-p", default="oldest_scheduled_first", help="Scheduling policy"
)
@click.option(
"--fields", "-f", default=None, help="Listed origin fields to print on output"
)
@click.option(
"--with-header/--without-header",
is_flag=True,
default=True,
help="Print the CSV header?",
)
@click.argument("type", type=str)
@click.argument("count", type=int)
@click.pass_context
def grab_next(
ctx, policy: str, fields: Optional[str], with_header: bool, type: str, count: int
):
"""Grab the next COUNT origins to visit using the TYPE loader from the
listed origins table."""
if fields:
parsed_fields: Optional[List[str]] = fields.split(",")
else:
parsed_fields = None
scheduler = ctx.obj["scheduler"]
origins = scheduler.grab_next_visits(type, count, policy=policy)
for line in format_origins(origins, fields=parsed_fields, with_header=with_header):
click.echo(line)
@origin.command("schedule-next")
@click.option(
"--policy", "-p", default="oldest_scheduled_first", help="Scheduling policy"
)
@click.argument("type", type=str)
@click.argument("count", type=int)
@click.pass_context
def schedule_next(ctx, policy: str, type: str, count: int):
"""Send the next COUNT origin visits of the TYPE loader to the scheduler as
one-shot tasks."""
from ..utils import utcnow
from .task import pretty_print_task
scheduler = ctx.obj["scheduler"]
origins = scheduler.grab_next_visits(type, count, policy=policy)
created = scheduler.create_tasks(
[
{
**origin.as_task_dict(),
"policy": "oneshot",
"next_run": utcnow(),
"retries_left": 1,
}
for origin in origins
]
)
output = ["Created %d tasks\n" % len(created)]
for task in created:
output.append(pretty_print_task(task))
click.echo_via_pager("\n".join(output))
@origin.command("send-to-celery")
@click.option(
"--policy", "-p", default="oldest_scheduled_first", help="Scheduling policy"
)
@click.option(
- "--queue", "-q", help="Target celery queue", type=str,
+ "--queue",
+ "-q",
+ help="Target celery queue",
+ type=str,
)
@click.option(
- "--tablesample", help="Table sampling percentage", type=float,
+ "--tablesample",
+ help="Table sampling percentage",
+ type=float,
)
@click.option(
"--only-enabled/--only-disabled",
"enabled",
is_flag=True,
default=True,
help="""Determine whether we want to scheduled enabled or disabled origins. As default, we
want to reasonably deal with enabled origins. For some edge case though, we
might want the disabled ones.""",
)
@click.option(
"--lister-uuid",
default=None,
help="Limit origins to those listed from such lister",
)
@click.argument("type", type=str)
@click.pass_context
def send_to_celery(
ctx,
policy: str,
queue: Optional[str],
tablesample: Optional[float],
type: str,
enabled: bool,
lister_uuid: Optional[str] = None,
):
"""Send the next origin visits of the TYPE loader to celery, filling the queue."""
from kombu.utils.uuid import uuid
from swh.scheduler.celery_backend.config import app, get_available_slots
scheduler = ctx.obj["scheduler"]
task_type = scheduler.get_task_type(f"load-{type}")
task_name = task_type["backend_name"]
queue_name = queue or task_name
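# do not exceed the queue's max length: only fill the available slots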
num_tasks = get_available_slots(app, queue_name, task_type["max_queue_length"])
click.echo(f"{num_tasks} slots available in celery queue")
origins = scheduler.grab_next_visits(
type,
num_tasks,
policy=policy,
tablesample=tablesample,
enabled=enabled,
lister_uuid=lister_uuid,
)
click.echo(f"{len(origins)} visits to send to celery")
for origin in origins:
task_dict = origin.as_task_dict()
app.send_task(
task_name,
task_id=uuid(),
args=task_dict["arguments"]["args"],
kwargs=task_dict["arguments"]["kwargs"],
queue=queue_name,
)
@origin.command("update-metrics")
@click.option("--lister", default=None, help="Only update metrics for this lister")
@click.option(
"--instance", default=None, help="Only update metrics for this lister instance"
)
@click.pass_context
def update_metrics(ctx, lister: Optional[str], instance: Optional[str]):
"""Update the scheduler metrics on listed origins.
Examples:
swh scheduler origin update-metrics
swh scheduler origin update-metrics --lister github
swh scheduler origin update-metrics --lister phabricator --instance llvm
"""
import json
import attr
scheduler: SchedulerInterface = ctx.obj["scheduler"]
lister_id: Optional[UUID] = None
if lister is not None:
lister_instance = scheduler.get_lister(name=lister, instance_name=instance)
if not lister_instance:
click.echo(f"Lister not found: {lister} instance={instance}")
ctx.exit(2)
assert False # for mypy
lister_id = lister_instance.id
def dictify_metrics(d):
return {k: str(v) for (k, v) in attr.asdict(d).items()}
ret = scheduler.update_metrics(lister_id=lister_id)
click.echo(json.dumps(list(map(dictify_metrics, ret)), indent=4, sort_keys=True))
diff --git a/swh/scheduler/cli/task.py b/swh/scheduler/cli/task.py
index e003985..65ed2a7 100644
--- a/swh/scheduler/cli/task.py
+++ b/swh/scheduler/cli/task.py
@@ -1,590 +1,595 @@
# Copyright (C) 2016-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from __future__ import annotations
# WARNING: do not import unnecessary things here to keep cli startup time under
# control
import locale
from typing import TYPE_CHECKING, Iterator, List, Optional
import click
from . import cli
if TYPE_CHECKING:
import datetime
# importing swh.storage.interface triggers the load of 300+ modules, so...
from swh.model.model import Origin
from swh.storage.interface import StorageInterface
locale.setlocale(locale.LC_ALL, "")
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
DATETIME = click.DateTime()
def format_dict(d):
"""Recursively format date objects in the dict passed as argument"""
import datetime
ret = {}
for k, v in d.items():
if isinstance(v, (datetime.date, datetime.datetime)):
v = v.isoformat()
elif isinstance(v, dict):
v = format_dict(v)
ret[k] = v
return ret
def pretty_print_list(list, indent=0):
"""Pretty-print a list"""
return "".join("%s%r\n" % (" " * indent, item) for item in list)
def pretty_print_dict(dict, indent=0):
"""Pretty-print a list"""
return "".join(
"%s%s: %r\n" % (" " * indent, click.style(key, bold=True), value)
for key, value in sorted(dict.items())
)
def pretty_print_run(run, indent=4):
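"""Pretty-print a task run"""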
fmt = (
"{indent}{backend_id} [{status}]\n"
"{indent} scheduled: {scheduled} [{started}:{ended}]"
)
return fmt.format(indent=" " * indent, **format_dict(run))
def pretty_print_task(task, full=False):
"""Pretty-print a task
If 'full' is True, also print the status and priority fields.
>>> import datetime
>>> task = {
... 'id': 1234,
... 'arguments': {
... 'args': ['foo', 'bar', True],
... 'kwargs': {'key': 'value', 'key2': 42},
... },
... 'current_interval': datetime.timedelta(hours=1),
... 'next_run': datetime.datetime(2019, 2, 21, 13, 52, 35, 407818),
... 'policy': 'oneshot',
... 'priority': None,
... 'status': 'next_run_not_scheduled',
... 'type': 'test_task',
... }
>>> print(click.unstyle(pretty_print_task(task)))
Task 1234
Next run: ... (2019-02-21T13:52:35.407818)
Interval: 1:00:00
Type: test_task
Policy: oneshot
Args:
'foo'
'bar'
True
Keyword args:
key: 'value'
key2: 42
>>> print(click.unstyle(pretty_print_task(task, full=True)))
Task 1234
Next run: ... (2019-02-21T13:52:35.407818)
Interval: 1:00:00
Type: test_task
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
'foo'
'bar'
True
Keyword args:
key: 'value'
key2: 42
"""
import humanize
next_run = task["next_run"]
lines = [
"%s %s\n" % (click.style("Task", bold=True), task["id"]),
click.style(" Next run: ", bold=True),
"%s (%s)" % (humanize.naturaldate(next_run), next_run.isoformat()),
"\n",
click.style(" Interval: ", bold=True),
str(task["current_interval"]),
"\n",
click.style(" Type: ", bold=True),
task["type"] or "",
"\n",
click.style(" Policy: ", bold=True),
task["policy"] or "",
"\n",
]
if full:
lines += [
click.style(" Status: ", bold=True),
task["status"] or "",
"\n",
click.style(" Priority: ", bold=True),
task["priority"] or "",
"\n",
]
lines += [
click.style(" Args:\n", bold=True),
pretty_print_list(task["arguments"]["args"], indent=4),
click.style(" Keyword args:\n", bold=True),
pretty_print_dict(task["arguments"]["kwargs"], indent=4),
]
return "".join(lines)
@cli.group("task")
@click.pass_context
def task(ctx):
"""Manipulate tasks."""
pass
@task.command("schedule")
@click.option(
"--columns",
"-c",
multiple=True,
default=["type", "args", "kwargs", "next_run"],
type=click.Choice(["type", "args", "kwargs", "policy", "next_run"]),
help="columns present in the CSV file",
)
@click.option("--delimiter", "-d", default=",")
@click.argument("file", type=click.File(encoding="utf-8"))
@click.pass_context
def schedule_tasks(ctx, columns, delimiter, file):
"""Schedule tasks from a CSV input file.
The following columns are expected, and can be set through the -c option:
- type: the type of the task to be scheduled (mandatory)
- args: the arguments passed to the task (JSON list, defaults to an empty
list)
- kwargs: the keyword arguments passed to the task (JSON object, defaults
to an empty dict)
- next_run: the date at which the task should run (datetime, defaults to
now)
The CSV can be read either from a named file, or from stdin (use - as
filename).
Usage sample:
cat scheduling-task.txt | \
python3 -m swh.scheduler.cli \
--database 'service=swh-scheduler-dev' \
task schedule \
--columns type --columns kwargs --columns policy \
--delimiter ';' -
"""
import csv
import json
from swh.scheduler.utils import utcnow
tasks = []
now = utcnow()
scheduler = ctx.obj["scheduler"]
if not scheduler:
raise ValueError("Scheduler class (local/remote) must be instantiated")
reader = csv.reader(file, delimiter=delimiter)
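# each CSV row maps the selected columns to values; args and kwargs are JSON-encoded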
for line in reader:
task = dict(zip(columns, line))
args = json.loads(task.pop("args", "[]"))
kwargs = json.loads(task.pop("kwargs", "{}"))
task["arguments"] = {
"args": args,
"kwargs": kwargs,
}
task["next_run"] = task.get("next_run", now)
tasks.append(task)
created = scheduler.create_tasks(tasks)
output = [
"Created %d tasks\n" % len(created),
]
for task in created:
output.append(pretty_print_task(task))
click.echo_via_pager("\n".join(output))
@task.command("add")
@click.argument("type", nargs=1, required=True)
@click.argument("options", nargs=-1)
@click.option(
"--policy", "-p", default="recurring", type=click.Choice(["recurring", "oneshot"])
)
@click.option(
"--priority", "-P", default=None, type=click.Choice(["low", "normal", "high"])
)
@click.option("--next-run", "-n", default=None)
@click.pass_context
def schedule_task(ctx, type, options, policy, priority, next_run):
"""Schedule one task from arguments.
The first argument is the name of the task type, further ones are
positional and keyword argument(s) of the task, in YAML format.
Keyword args are of the form key=value.
Usage sample:
swh-scheduler --database 'service=swh-scheduler' \
task add list-pypi
swh-scheduler --database 'service=swh-scheduler' \
task add list-debian-distribution --policy=oneshot distribution=stretch
Note: if the priority is not given, the task won't have the priority set,
which is considered the lowest priority level.
"""
from swh.scheduler.utils import utcnow
from .utils import parse_options
scheduler = ctx.obj["scheduler"]
if not scheduler:
raise ValueError("Scheduler class (local/remote) must be instantiated")
now = utcnow()
(args, kw) = parse_options(options)
task = {
"type": type,
"policy": policy,
"priority": priority,
- "arguments": {"args": args, "kwargs": kw,},
+ "arguments": {
+ "args": args,
+ "kwargs": kw,
+ },
"next_run": next_run or now,
}
created = scheduler.create_tasks([task])
output = [
"Created %d tasks\n" % len(created),
]
for task in created:
output.append(pretty_print_task(task))
click.echo("\n".join(output))
def iter_origins( # use string annotations to prevent some pkg loading
- storage: "StorageInterface", page_token: "Optional[str]" = None,
+ storage: "StorageInterface",
+ page_token: "Optional[str]" = None,
) -> "Iterator[Origin]":
"""Iterate over origins in the storage. Optionally starting from page_token.
This regularly logs an info message with the current page_token during
pagination, in order to feed it back to the cli if the process is interrupted.
Yields:
origin model objects from the storage
"""
while True:
page_result = storage.origin_list(page_token=page_token)
page_token = page_result.next_page_token
yield from page_result.results
if not page_token:
break
click.echo(f"page_token: {page_token}\n")
@task.command("schedule_origins")
@click.argument("type", nargs=1, required=True)
@click.argument("options", nargs=-1)
@click.option(
"--batch-size",
"-b",
"origin_batch_size",
default=10,
show_default=True,
type=int,
help="Number of origins per task",
)
@click.option(
"--page-token",
default=0,
show_default=True,
type=str,
help="Only schedule tasks for origins whose ID is greater",
)
@click.option(
"--limit",
default=None,
type=int,
help="Limit the tasks scheduling up to this number of tasks",
)
@click.option("--storage-url", "-g", help="URL of the (graph) storage API")
@click.option(
"--dry-run/--no-dry-run",
is_flag=True,
default=False,
help="List only what would be scheduled.",
)
@click.pass_context
def schedule_origin_metadata_index(
ctx, type, options, storage_url, origin_batch_size, page_token, limit, dry_run
):
"""Schedules tasks for origins that are already known.
The first argument is the name of the task type, further ones are
keyword argument(s) of the task in the form key=value, where value is
in YAML format.
Usage sample:
swh-scheduler --database 'service=swh-scheduler' \
task schedule_origins index-origin-metadata
"""
from itertools import islice
from swh.storage import get_storage
from .utils import parse_options, schedule_origin_batches
scheduler = ctx.obj["scheduler"]
storage = get_storage("remote", url=storage_url)
if dry_run:
scheduler = None
(args, kw) = parse_options(options)
if args:
raise click.ClickException("Only keywords arguments are allowed.")
origins = iter_origins(storage, page_token=page_token)
if limit:
origins = islice(origins, limit)
origin_urls = (origin.url for origin in origins)
schedule_origin_batches(scheduler, type, origin_urls, origin_batch_size, kw)
@task.command("list-pending")
@click.argument("task-types", required=True, nargs=-1)
@click.option(
"--limit",
"-l",
"num_tasks",
required=False,
type=click.INT,
help="The maximum number of tasks to fetch",
)
@click.option(
"--before",
"-b",
required=False,
type=DATETIME,
help="List all jobs supposed to run before the given date",
)
@click.pass_context
def list_pending_tasks(ctx, task_types, num_tasks, before):
"""List tasks with no priority that are going to be run.
You can override the number of tasks to fetch with the --limit flag.
"""
scheduler = ctx.obj["scheduler"]
if not scheduler:
raise ValueError("Scheduler class (local/remote) must be instantiated")
output = []
for task_type in task_types:
pending = scheduler.peek_ready_tasks(
- task_type, timestamp=before, num_tasks=num_tasks,
+ task_type,
+ timestamp=before,
+ num_tasks=num_tasks,
)
output.append("Found %d %s tasks\n" % (len(pending), task_type))
for task in pending:
output.append(pretty_print_task(task))
click.echo("\n".join(output))
@task.command("list")
@click.option(
"--task-id",
"-i",
default=None,
multiple=True,
metavar="ID",
help="List only tasks whose id is ID.",
)
@click.option(
"--task-type",
"-t",
default=None,
multiple=True,
metavar="TYPE",
help="List only tasks of type TYPE",
)
@click.option(
"--limit",
"-l",
required=False,
type=click.INT,
help="The maximum number of tasks to fetch.",
)
@click.option(
"--status",
"-s",
multiple=True,
metavar="STATUS",
type=click.Choice(
("next_run_not_scheduled", "next_run_scheduled", "completed", "disabled")
),
default=None,
help="List tasks whose status is STATUS.",
)
@click.option(
"--policy",
"-p",
default=None,
type=click.Choice(["recurring", "oneshot"]),
help="List tasks whose policy is POLICY.",
)
@click.option(
"--priority",
"-P",
default=None,
multiple=True,
type=click.Choice(["all", "low", "normal", "high"]),
help="List tasks whose priority is PRIORITY.",
)
@click.option(
"--before",
"-b",
required=False,
type=DATETIME,
metavar="DATETIME",
help="Limit to tasks supposed to run before the given date.",
)
@click.option(
"--after",
"-a",
required=False,
type=DATETIME,
metavar="DATETIME",
help="Limit to tasks supposed to run after the given date.",
)
@click.option(
"--list-runs",
"-r",
is_flag=True,
default=False,
help="Also list past executions of each task.",
)
@click.pass_context
def list_tasks(
ctx, task_id, task_type, limit, status, policy, priority, before, after, list_runs
):
- """List tasks.
- """
+ """List tasks."""
from operator import itemgetter
scheduler = ctx.obj["scheduler"]
if not scheduler:
raise ValueError("Scheduler class (local/remote) must be instantiated")
if not task_type:
task_type = [x["type"] for x in scheduler.get_task_types()]
# if task_id is not given, default value for status is
# 'next_run_not_scheduled'
# if task_id is given, default status is 'all'
if task_id is None and status is None:
status = ["next_run_not_scheduled"]
if status and "all" in status:
status = None
if priority and "all" in priority:
priority = None
output = []
tasks = scheduler.search_tasks(
task_id=task_id,
task_type=task_type,
status=status,
priority=priority,
policy=policy,
before=before,
after=after,
limit=limit,
)
if list_runs:
runs = {t["id"]: [] for t in tasks}
for r in scheduler.get_task_runs([task["id"] for task in tasks]):
runs[r["task"]].append(r)
else:
runs = {}
output.append("Found %d tasks\n" % (len(tasks)))
for task in sorted(tasks, key=itemgetter("id")):
output.append(pretty_print_task(task, full=True))
if runs.get(task["id"]):
output.append(click.style(" Executions:", bold=True))
for run in sorted(runs[task["id"]], key=itemgetter("id")):
output.append(pretty_print_run(run, indent=4))
click.echo("\n".join(output))
@task.command("respawn")
@click.argument("task-ids", required=True, nargs=-1)
@click.option(
"--next-run",
"-n",
required=False,
type=DATETIME,
metavar="DATETIME",
default=None,
help="Re spawn the selected tasks at this date",
)
@click.pass_context
def respawn_tasks(ctx, task_ids: List[str], next_run: datetime.datetime):
"""Respawn tasks.
Respawn tasks given by their ids (see the 'task list' command to
find task ids) at the given date (immediately by default).
E.g.
swh-scheduler task respawn 1 3 12
"""
from swh.scheduler.utils import utcnow
scheduler = ctx.obj["scheduler"]
if not scheduler:
raise ValueError("Scheduler class (local/remote) must be instantiated")
if next_run is None:
next_run = utcnow()
output = []
task_ids_int = [int(id_) for id_ in task_ids]
scheduler.set_status_tasks(
task_ids_int, status="next_run_not_scheduled", next_run=next_run
)
output.append("Respawn tasks %s\n" % (task_ids_int,))
click.echo("\n".join(output))
diff --git a/swh/scheduler/cli/task_type.py b/swh/scheduler/cli/task_type.py
index 1557007..ba56e32 100644
--- a/swh/scheduler/cli/task_type.py
+++ b/swh/scheduler/cli/task_type.py
@@ -1,231 +1,230 @@
# Copyright (C) 2016-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from importlib import import_module
import logging
from typing import Mapping
# WARNING: do not import unnecessary things here to keep cli startup time under
# control
import click
from pkg_resources import iter_entry_points
from . import cli
logger = logging.getLogger(__name__)
DEFAULT_TASK_TYPE = {
"full": { # for tasks like 'list_xxx_full()'
"default_interval": "90 days",
"min_interval": "90 days",
"max_interval": "90 days",
"backoff_factor": 1,
},
"*": { # value if not suffix matches
"default_interval": "1 day",
"min_interval": "1 day",
"max_interval": "1 day",
"backoff_factor": 1,
},
}
PLUGIN_WORKER_DESCRIPTIONS = {
entry_point.name: entry_point for entry_point in iter_entry_points("swh.workers")
}
@cli.group("task-type")
@click.pass_context
def task_type(ctx):
"""Manipulate task types."""
scheduler = ctx.obj["scheduler"]
if not scheduler:
raise ValueError("Scheduler class (local/remote) must be instantiated")
@task_type.command("list")
@click.option("--verbose", "-v", is_flag=True, default=False, help="Verbose mode")
@click.option(
"--task_type",
"-t",
multiple=True,
default=None,
help="List task types of given type",
)
@click.option(
"--task_name",
"-n",
multiple=True,
default=None,
help="List task types of given backend task name",
)
@click.pass_context
def list_task_types(ctx, verbose, task_type, task_name):
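"""List the task types known to the scheduler backend."""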
click.echo("Known task types:")
if verbose:
tmpl = (
click.style("{type}: ", bold=True)
+ """{backend_name}
{description}
interval: {default_interval} [{min_interval}, {max_interval}]
backoff_factor: {backoff_factor}
max_queue_length: {max_queue_length}
num_retries: {num_retries}
retry_delay: {retry_delay}
"""
)
else:
tmpl = "{type}:\n {description}"
for tasktype in sorted(
ctx.obj["scheduler"].get_task_types(), key=lambda x: x["type"]
):
if task_type and tasktype["type"] not in task_type:
continue
if task_name and tasktype["backend_name"] not in task_name:
continue
click.echo(tmpl.format(**tasktype))
@task_type.command("register")
@click.option(
"--plugins",
"-p",
"plugins",
multiple=True,
default=("all",),
type=click.Choice(["all"] + list(PLUGIN_WORKER_DESCRIPTIONS)),
help="Registers task-types for provided plugins. " "Defaults to all",
)
@click.pass_context
def register_task_types(ctx, plugins):
"""Register missing task-type entries in the scheduler.
According to the tasks declared by each loaded worker plugin (e.g. lister,
loader, ...).
"""
import celery.app.task
scheduler = ctx.obj["scheduler"]
if plugins == ("all",):
plugins = list(PLUGIN_WORKER_DESCRIPTIONS)
for plugin in plugins:
entrypoint = PLUGIN_WORKER_DESCRIPTIONS[plugin]
logger.info("Loading entrypoint for plugin %s", plugin)
registry_entry = entrypoint.load()()
for task_module in registry_entry["task_modules"]:
mod = import_module(task_module)
for task_name in (x for x in dir(mod) if not x.startswith("_")):
logger.debug("Loading task name %s", task_name)
taskobj = getattr(mod, task_name)
if isinstance(taskobj, celery.app.task.Task):
tt_name = task_name.replace("_", "-")
task_cfg = registry_entry.get("task_types", {}).get(tt_name, {})
ensure_task_type(task_module, tt_name, taskobj, task_cfg, scheduler)
def ensure_task_type(
task_module: str, task_type: str, swhtask, task_config: Mapping, scheduler
):
"""Ensure a given task-type (for the task_module) exists in the scheduler.
Args:
task_module: task module we are currently checking for task type
consistency
task_type: the type of the task to check/insert (corresponds to
the 'type' field in the db)
swhtask (SWHTask): the SWHTask instance the task-type corresponds to
task_config: a dict with specific/overloaded values for the
task-type to be created
scheduler: the scheduler object used to access the scheduler db
"""
for suffix, defaults in DEFAULT_TASK_TYPE.items():
if task_type.endswith("-" + suffix):
task_type_dict = defaults.copy()
break
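# no suffix matched: fall back to the catch-all defaults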
else:
task_type_dict = DEFAULT_TASK_TYPE["*"].copy()
task_type_dict["type"] = task_type
task_type_dict["backend_name"] = swhtask.name
if swhtask.__doc__:
task_type_dict["description"] = swhtask.__doc__.splitlines()[0]
task_type_dict.update(task_config)
current_task_type = scheduler.get_task_type(task_type)
if current_task_type:
# Ensure the existing task_type is consistent in the scheduler
if current_task_type["backend_name"] != task_type_dict["backend_name"]:
logger.warning(
"Existing task type %s for module %s has a "
"different backend name than current "
"code version provides (%s vs. %s)",
task_type,
task_module,
current_task_type["backend_name"],
task_type_dict["backend_name"],
)
else:
logger.info("Create task type %s in scheduler", task_type)
logger.debug(" %s", task_type_dict)
scheduler.create_task_type(task_type_dict)
@task_type.command("add")
@click.argument("type", required=True)
@click.argument("task-name", required=True)
@click.argument("description", required=True)
@click.option(
"--default-interval",
"-i",
default="90 days",
help='Default interval ("90 days" by default)',
)
@click.option(
"--min-interval",
default=None,
help="Minimum interval (default interval if not set)",
)
@click.option(
"--max-interval",
"-i",
default=None,
help="Maximal interval (default interval if not set)",
)
@click.option("--backoff-factor", "-f", type=float, default=1, help="Backoff factor")
@click.pass_context
def add_task_type(
ctx,
type,
task_name,
description,
default_interval,
min_interval,
max_interval,
backoff_factor,
):
- """Create a new task type
- """
+ """Create a new task type"""
task_type = dict(
type=type,
backend_name=task_name,
description=description,
default_interval=default_interval,
min_interval=min_interval,
max_interval=max_interval,
backoff_factor=backoff_factor,
max_queue_length=None,
num_retries=None,
retry_delay=None,
)
ctx.obj["scheduler"].create_task_type(task_type)
click.echo("OK")
diff --git a/swh/scheduler/interface.py b/swh/scheduler/interface.py
index 033235f..93e12cf 100644
--- a/swh/scheduler/interface.py
+++ b/swh/scheduler/interface.py
@@ -1,499 +1,502 @@
# Copyright (C) 2015-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import UUID
from typing_extensions import Protocol, runtime_checkable
from swh.core.api import remote_api_endpoint
from swh.core.api.classes import PagedResult
from swh.scheduler.model import ListedOrigin, Lister, OriginVisitStats, SchedulerMetrics
ListedOriginPageToken = Tuple[str, str]
class PaginatedListedOriginList(PagedResult[ListedOrigin, ListedOriginPageToken]):
"""A list of listed origins, with a continuation token"""
def __init__(
self,
results: List[ListedOrigin],
next_page_token: Union[None, ListedOriginPageToken, List[str]],
):
parsed_next_page_token: Optional[Tuple[str, str]] = None
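# page tokens may arrive as plain lists rather than tuples (e.g. after an
# RPC serialization round-trip), so accept both and normalize to a tuple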
if next_page_token is not None:
if len(next_page_token) != 2:
raise TypeError("Expected Tuple[str, str] or list of size 2.")
parsed_next_page_token = tuple(next_page_token) # type: ignore
super().__init__(results, parsed_next_page_token)
@runtime_checkable
class SchedulerInterface(Protocol):
@remote_api_endpoint("task_type/create")
def create_task_type(self, task_type):
"""Create a new task type ready for scheduling.
Args:
task_type (dict): a dictionary with the following keys:
- type (str): an identifier for the task type
- description (str): a human-readable description of what the
task does
- backend_name (str): the name of the task in the
job-scheduling backend
- default_interval (datetime.timedelta): the default interval
between two task runs
- min_interval (datetime.timedelta): the minimum interval
between two task runs
- max_interval (datetime.timedelta): the maximum interval
between two task runs
- backoff_factor (float): the factor by which the interval
changes at each run
- max_queue_length (int): the maximum length of the task queue
for this task type
"""
...
@remote_api_endpoint("task_type/get")
def get_task_type(self, task_type_name):
"""Retrieve the task type with id task_type_name"""
...
@remote_api_endpoint("task_type/get_all")
def get_task_types(self):
"""Retrieve all registered task types"""
...
@remote_api_endpoint("task/create")
def create_tasks(self, tasks, policy="recurring"):
"""Create new tasks.
Args:
tasks (list): each task is a dictionary with the following keys:
- type (str): the task type
- arguments (dict): the arguments for the task runner, keys:
- args (list of str): arguments
- kwargs (dict str -> str): keyword arguments
- next_run (datetime.datetime): the next scheduled run for the
task
Returns:
a list of created tasks.
"""
...
@remote_api_endpoint("task/set_status")
def set_status_tasks(
self,
task_ids: List[int],
status: str = "disabled",
next_run: Optional[datetime.datetime] = None,
):
"""Set the tasks' status whose ids are listed.
If given, also set the next_run date.
"""
...
@remote_api_endpoint("task/disable")
def disable_tasks(self, task_ids):
"""Disable the tasks whose ids are listed."""
...
@remote_api_endpoint("task/search")
def search_tasks(
self,
task_id=None,
task_type=None,
status=None,
priority=None,
policy=None,
before=None,
after=None,
limit=None,
):
"""Search tasks from selected criterions"""
...
@remote_api_endpoint("task/get")
def get_tasks(self, task_ids):
"""Retrieve the info of tasks whose ids are listed."""
...
@remote_api_endpoint("task/peek_ready")
def peek_ready_tasks(
self,
task_type: str,
timestamp: Optional[datetime.datetime] = None,
num_tasks: Optional[int] = None,
) -> List[Dict]:
"""Fetch the list of tasks (with no priority) to be scheduled.
Args:
task_type: filtering task per their type
timestamp: peek tasks that need to be executed
before that timestamp
num_tasks: only peek at num_tasks tasks (with no priority)
Returns:
the list of tasks which would be scheduled
"""
...
@remote_api_endpoint("task/grab_ready")
def grab_ready_tasks(
self,
task_type: str,
timestamp: Optional[datetime.datetime] = None,
num_tasks: Optional[int] = None,
) -> List[Dict]:
"""Fetch and schedule the list of tasks (with no priority) ready to be scheduled.
Args:
task_type: filtering task per their type
timestamp: grab tasks that need to be executed
before that timestamp
num_tasks: only grab num_tasks tasks (with no priority)
Returns:
the list of scheduled tasks
"""
...
@remote_api_endpoint("task/peek_ready_with_priority")
def peek_ready_priority_tasks(
self,
task_type: str,
timestamp: Optional[datetime.datetime] = None,
num_tasks: Optional[int] = None,
) -> List[Dict]:
"""Fetch list of tasks (with any priority) ready to be scheduled.
Args:
task_type: filtering task per their type
timestamp: peek tasks that need to be executed before that timestamp
num_tasks: only peek at num_tasks tasks (with any priority)
Returns:
a list of tasks
"""
...
@remote_api_endpoint("task/grab_ready_with_priority")
def grab_ready_priority_tasks(
self,
task_type: str,
timestamp: Optional[datetime.datetime] = None,
num_tasks: Optional[int] = None,
) -> List[Dict]:
"""Fetch and schedule the list of tasks (with any priority) ready to be scheduled.
Args:
task_type: filtering task per their type
timestamp: grab tasks that need to be executed
before that timestamp
num_tasks: only grab num_tasks tasks (with any priority)
Returns:
a list of tasks
"""
...
@remote_api_endpoint("task_run/schedule_one")
def schedule_task_run(self, task_id, backend_id, metadata=None, timestamp=None):
"""Mark a given task as scheduled, adding a task_run entry in the database.
Args:
task_id (int): the identifier for the task being scheduled
backend_id (str): the identifier of the job in the backend
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
a fresh task_run entry
"""
...
@remote_api_endpoint("task_run/schedule")
def mass_schedule_task_runs(self, task_runs):
"""Schedule a bunch of task runs.
Args:
task_runs (list): a list of dicts with keys:
- task (int): the identifier for the task being scheduled
- backend_id (str): the identifier of the job in the backend
- metadata (dict): metadata to add to the task_run entry
- scheduled (datetime.datetime): the instant the event occurred
Returns:
None
"""
...
@remote_api_endpoint("task_run/start")
def start_task_run(self, backend_id, metadata=None, timestamp=None):
"""Mark a given task as started, updating the corresponding task_run
entry in the database.
Args:
backend_id (str): the identifier of the job in the backend
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
the updated task_run entry
"""
...
@remote_api_endpoint("task_run/end")
def end_task_run(
- self, backend_id, status, metadata=None, timestamp=None, result=None,
+ self,
+ backend_id,
+ status,
+ metadata=None,
+ timestamp=None,
+ result=None,
):
"""Mark a given task as ended, updating the corresponding task_run entry in the
database.
Args:
backend_id (str): the identifier of the job in the backend
status (str): how the task ended; one of: 'eventful', 'uneventful',
'failed'
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
the updated task_run entry
"""
...
@remote_api_endpoint("task/filter_for_archive")
def filter_task_to_archive(
self,
after_ts: str,
before_ts: str,
limit: int = 10,
page_token: Optional[str] = None,
) -> Dict[str, Any]:
"""Compute the tasks to archive within the datetime interval
[after_ts, before_ts). The method returns a paginated result.
Returns:
dict with the following keys:
- **next_page_token**: opaque token to be used as
`page_token` to retrieve the next page of results. If absent,
there are no more pages to gather.
- **tasks**: list of task dictionaries with the following keys:
**id** (str): origin task id
**started** (Optional[datetime]): started date
**scheduled** (datetime): scheduled date
**arguments** (json dict): task's arguments
...
"""
...
@remote_api_endpoint("task/delete_archived")
def delete_archived_tasks(self, task_ids):
"""Delete archived tasks as much as possible. Only the task_ids whose
- complete associated task_run have been cleaned up will be.
+ associated task_run entries have been completely cleaned up will be deleted.
"""
...
@remote_api_endpoint("task_run/get")
def get_task_runs(self, task_ids, limit=None):
"""Search task run for a task id"""
...
@remote_api_endpoint("listers/get")
def get_listers(self) -> List[Lister]:
- """Retrieve information about all listers from the database.
- """
+ """Retrieve information about all listers from the database."""
...
@remote_api_endpoint("lister/get")
def get_lister(
self, name: str, instance_name: Optional[str] = None
) -> Optional[Lister]:
"""Retrieve information about the given instance of the lister from the
database.
"""
...
@remote_api_endpoint("lister/get_or_create")
def get_or_create_lister(
self, name: str, instance_name: Optional[str] = None
) -> Lister:
"""Retrieve information about the given instance of the lister from the
database, or create the entry if it did not exist.
"""
...
@remote_api_endpoint("lister/update")
def update_lister(self, lister: Lister) -> Lister:
"""Update the state for the given lister instance in the database.
Returns:
a new Lister object, with all fields updated from the database
Raises:
StaleData if the `updated` timestamp for the lister instance in
database doesn't match the one passed by the user.
"""
...
@remote_api_endpoint("origins/record")
def record_listed_origins(
self, listed_origins: Iterable[ListedOrigin]
) -> List[ListedOrigin]:
"""Record a set of origins that a lister has listed.
This performs an "upsert": origins with the same (lister_id, url,
visit_type) values are updated with new values for
extra_loader_arguments, last_update and last_seen.
"""
...
@remote_api_endpoint("origins/get")
def get_listed_origins(
self,
lister_id: Optional[UUID] = None,
url: Optional[str] = None,
limit: int = 1000,
page_token: Optional[ListedOriginPageToken] = None,
) -> PaginatedListedOriginList:
"""Get information on the listed origins matching either the `url` or
`lister_id`, or both arguments.
Use the `limit` and `page_token` arguments for continuation. The next
page token, if any, is returned in the PaginatedListedOriginList object.
"""
...
@remote_api_endpoint("origins/grab_next")
def grab_next_visits(
self,
visit_type: str,
count: int,
policy: str,
enabled: bool = True,
lister_uuid: Optional[str] = None,
timestamp: Optional[datetime.datetime] = None,
scheduled_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=7),
failed_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=14),
not_found_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=31),
tablesample: Optional[float] = None,
) -> List[ListedOrigin]:
"""Get at most the `count` next origins that need to be visited with
the `visit_type` loader according to the given scheduling `policy`.
This will mark the origins as scheduled in the origin_visit_stats
table, to avoid scheduling multiple visits to the same origin.
Arguments:
visit_type: type of visits to schedule
count: number of visits to schedule
policy: the scheduling policy used to select which visits to schedule
enabled: Determine whether we want to list enabled or disabled origins. By
default, we want enabled origins. For some edge cases, we might
want the disabled ones.
lister_uuid: if given, restrict to origins listed by the lister with this uuid
timestamp: the mocked timestamp at which we're recording that the visits are
being scheduled (defaults to the current time)
scheduled_cooldown: the minimal interval before which we can schedule
the same origin again
failed_cooldown: the minimal interval before which we can reschedule a
failed origin
not_found_cooldown: the minimal interval before which we can reschedule a
not_found origin
tablesample: the percentage of the table on which we run the query
(None: no sampling)
"""
...
@remote_api_endpoint("visit_stats/upsert")
def origin_visit_stats_upsert(
self, origin_visit_stats: Iterable[OriginVisitStats]
) -> None:
- """Create a new origin visit stats
- """
+ """Create a new origin visit stats"""
...
@remote_api_endpoint("visit_stats/get")
def origin_visit_stats_get(
self, ids: Iterable[Tuple[str, str]]
) -> List[OriginVisitStats]:
"""Retrieve the stats for an origin with a given visit type
If some visit_stats are not found, they are filtered out of the result, so
the output list may be shorter than the input list.
"""
...
@remote_api_endpoint("visit_scheduler/get")
- def visit_scheduler_queue_position_get(self,) -> Dict[str, int]:
+ def visit_scheduler_queue_position_get(
+ self,
+ ) -> Dict[str, int]:
"""Retrieve all current queue positions for the recurrent visit scheduler.
Returns:
    Mapping of visit type to its current queue position
"""
...
@remote_api_endpoint("visit_scheduler/set")
def visit_scheduler_queue_position_set(
self, visit_type: str, position: int
) -> None:
- """Set the current queue position of the recurrent visit scheduler for `visit_type`.
-
- """
+ """Set the current queue position of the recurrent visit scheduler for `visit_type`."""
...
@remote_api_endpoint("scheduler_metrics/update")
def update_metrics(
self,
lister_id: Optional[UUID] = None,
timestamp: Optional[datetime.datetime] = None,
) -> List[SchedulerMetrics]:
"""Update the performance metrics of this scheduler instance.
Returns the updated metrics.
Args:
lister_id: if passed, update the metrics only for this lister instance
timestamp: if passed, the date at which we're updating the metrics,
defaults to the database NOW()
"""
...
@remote_api_endpoint("scheduler_metrics/get")
def get_metrics(
self, lister_id: Optional[UUID] = None, visit_type: Optional[str] = None
) -> List[SchedulerMetrics]:
"""Retrieve the performance metrics of this scheduler instance.
Args:
lister_id: filter the metrics for this lister instance only
visit_type: filter the metrics for this visit type only
"""
...
diff --git a/swh/scheduler/model.py b/swh/scheduler/model.py
index 61a52c1..4a64ffe 100644
--- a/swh/scheduler/model.py
+++ b/swh/scheduler/model.py
@@ -1,269 +1,274 @@
# Copyright (C) 2020-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID
import attr
import attr.converters
from attrs_strict import type_validator
def check_timestamptz(value) -> None:
"""Checks the date has a timezone."""
if value is not None and value.tzinfo is None:
raise ValueError("date must be a timezone-aware datetime.")
@attr.s
class BaseSchedulerModel:
"""Base class for database-backed objects.
These database-backed objects are defined through attrs-based attributes
that match the columns of the database 1:1. This is a (very) lightweight
ORM.
These attrs-based attributes have metadata specific to the functionality
expected from these fields in the database:
- `primary_key`: the column is a primary key; it should be filtered out
when doing an `update` of the object
- `auto_primary_key`: the column is a primary key, which is automatically handled
by the database. It will not be inserted to. This must be matched with a
database-side default value.
- `auto_now_add`: the column is a timestamp that is set to the current time when
the object is inserted, and never updated afterwards. This must be matched with
a database-side default value.
- `auto_now`: the column is a timestamp that is set to the current time when
the object is inserted or updated.
"""
_pk_cols: Optional[Tuple[str, ...]] = None
_select_cols: Optional[Tuple[str, ...]] = None
_insert_cols_and_metavars: Optional[Tuple[Tuple[str, ...], Tuple[str, ...]]] = None
@classmethod
def primary_key_columns(cls) -> Tuple[str, ...]:
"""Get the primary key columns for this object type"""
if cls._pk_cols is None:
columns: List[str] = []
for field in attr.fields(cls):
if any(
field.metadata.get(flag)
for flag in ("auto_primary_key", "primary_key")
):
columns.append(field.name)
cls._pk_cols = tuple(sorted(columns))
return cls._pk_cols
@classmethod
def select_columns(cls) -> Tuple[str, ...]:
"""Get all the database columns needed for a `select` on this object type"""
if cls._select_cols is None:
columns: List[str] = []
for field in attr.fields(cls):
columns.append(field.name)
cls._select_cols = tuple(sorted(columns))
return cls._select_cols
@classmethod
def insert_columns_and_metavars(cls) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
"""Get the database columns and metavars needed for an `insert` or `update` on
this object type.
This implements support for the `auto_*` field metadata attributes.
"""
if cls._insert_cols_and_metavars is None:
zipped_cols_and_metavars: List[Tuple[str, str]] = []
for field in attr.fields(cls):
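# skip columns that are entirely database-managed (automatic
# primary keys and creation timestamps)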
if any(
field.metadata.get(flag)
for flag in ("auto_now_add", "auto_primary_key")
):
continue
elif field.metadata.get("auto_now"):
zipped_cols_and_metavars.append((field.name, "now()"))
else:
zipped_cols_and_metavars.append((field.name, f"%({field.name})s"))
zipped_cols_and_metavars.sort()
cols, metavars = zip(*zipped_cols_and_metavars)
cls._insert_cols_and_metavars = cols, metavars
return cls._insert_cols_and_metavars
@attr.s
class Lister(BaseSchedulerModel):
name = attr.ib(type=str, validator=[type_validator()])
instance_name = attr.ib(type=str, validator=[type_validator()])
# Populated by database
id = attr.ib(
type=Optional[UUID],
validator=type_validator(),
default=None,
metadata={"auto_primary_key": True},
)
current_state = attr.ib(
type=Dict[str, Any], validator=[type_validator()], factory=dict
)
created = attr.ib(
type=Optional[datetime.datetime],
validator=[type_validator()],
default=None,
metadata={"auto_now_add": True},
)
updated = attr.ib(
type=Optional[datetime.datetime],
validator=[type_validator()],
default=None,
metadata={"auto_now": True},
)
@attr.s
class ListedOrigin(BaseSchedulerModel):
"""Basic information about a listed origin, output by a lister"""
lister_id = attr.ib(
type=UUID, validator=[type_validator()], metadata={"primary_key": True}
)
url = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
visit_type = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
extra_loader_arguments = attr.ib(
type=Dict[str, Any], validator=[type_validator()], factory=dict
)
last_update = attr.ib(
- type=Optional[datetime.datetime], validator=[type_validator()], default=None,
+ type=Optional[datetime.datetime],
+ validator=[type_validator()],
+ default=None,
)
enabled = attr.ib(type=bool, validator=[type_validator()], default=True)
first_seen = attr.ib(
type=Optional[datetime.datetime],
validator=[type_validator()],
default=None,
metadata={"auto_now_add": True},
)
last_seen = attr.ib(
type=Optional[datetime.datetime],
validator=[type_validator()],
default=None,
metadata={"auto_now": True},
)
def as_task_dict(self):
return {
"type": f"load-{self.visit_type}",
"arguments": {
"args": [],
"kwargs": {"url": self.url, **self.extra_loader_arguments},
},
}
class LastVisitStatus(Enum):
successful = "successful"
failed = "failed"
not_found = "not_found"
def convert_last_visit_status(
s: Union[None, str, LastVisitStatus]
) -> Optional[LastVisitStatus]:
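"""Normalize a str to the LastVisitStatus enum; None and enum values pass through.

>>> convert_last_visit_status("failed")
<LastVisitStatus.failed: 'failed'>
>>> convert_last_visit_status(None) is None
True
"""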
if not isinstance(s, str):
return s
return LastVisitStatus(s)
@attr.s(frozen=True, slots=True)
class OriginVisitStats(BaseSchedulerModel):
- """Represents an aggregated origin visits view.
- """
+ """Represents an aggregated origin visits view."""
url = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
visit_type = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
last_successful = attr.ib(
type=Optional[datetime.datetime], validator=type_validator(), default=None
)
last_visit = attr.ib(
type=Optional[datetime.datetime], validator=type_validator(), default=None
)
last_visit_status = attr.ib(
type=Optional[LastVisitStatus],
validator=type_validator(),
default=None,
converter=convert_last_visit_status,
)
last_scheduled = attr.ib(
- type=Optional[datetime.datetime], validator=[type_validator()], default=None,
+ type=Optional[datetime.datetime],
+ validator=[type_validator()],
+ default=None,
)
last_snapshot = attr.ib(
type=Optional[bytes], validator=type_validator(), default=None
)
next_visit_queue_position = attr.ib(
type=Optional[int], validator=type_validator(), default=None
)
next_position_offset = attr.ib(type=int, validator=type_validator(), default=4)
successive_visits = attr.ib(type=int, validator=type_validator(), default=1)
@last_successful.validator
def check_last_successful(self, attribute, value):
check_timestamptz(value)
@last_visit.validator
def check_last_visit(self, attribute, value):
check_timestamptz(value)
@attr.s(frozen=True, slots=True)
class SchedulerMetrics(BaseSchedulerModel):
"""Metrics for the scheduler, aggregated by (lister_id, visit_type)"""
lister_id = attr.ib(
type=UUID, validator=[type_validator()], metadata={"primary_key": True}
)
visit_type = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
last_update = attr.ib(
- type=Optional[datetime.datetime], validator=[type_validator()], default=None,
+ type=Optional[datetime.datetime],
+ validator=[type_validator()],
+ default=None,
)
origins_known = attr.ib(type=int, validator=[type_validator()], default=0)
"""Number of known (enabled or disabled) origins"""
origins_enabled = attr.ib(type=int, validator=[type_validator()], default=0)
"""Number of origins that were present in the latest listings"""
origins_never_visited = attr.ib(type=int, validator=[type_validator()], default=0)
"""Number of enabled origins that have never been visited
(according to the visit cache)"""
origins_with_pending_changes = attr.ib(
type=int, validator=[type_validator()], default=0
)
"""Number of enabled origins with known activity (recorded by a lister)
since our last visit"""
diff --git a/swh/scheduler/pytest_plugin.py b/swh/scheduler/pytest_plugin.py
index 870d9e2..0b5b1f4 100644
--- a/swh/scheduler/pytest_plugin.py
+++ b/swh/scheduler/pytest_plugin.py
@@ -1,111 +1,112 @@
# Copyright (C) 2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import timedelta
from functools import partial
from celery.contrib.testing import worker
from celery.contrib.testing.app import TestApp, setup_default_app
import pkg_resources
import pytest
from pytest_postgresql import factories
from swh.core.db.pytest_plugin import initialize_database_for_module, postgresql_fact
from swh.scheduler import get_scheduler
from swh.scheduler.backend import SchedulerBackend
# celery tasks for testing purpose; tasks themselves should be
# in swh/scheduler/tests/tasks.py
TASK_NAMES = ["ping", "multiping", "add", "error", "echo"]
scheduler_postgresql_proc = factories.postgresql_proc(
dbname="scheduler",
load=[
partial(
initialize_database_for_module,
modname="scheduler",
version=SchedulerBackend.current_version,
)
],
)
postgresql_scheduler = postgresql_fact("scheduler_postgresql_proc")
@pytest.fixture
def swh_scheduler_config(request, postgresql_scheduler):
return {
"db": postgresql_scheduler.dsn,
}
@pytest.fixture
def swh_scheduler(swh_scheduler_config):
scheduler = get_scheduler("local", **swh_scheduler_config)
for taskname in TASK_NAMES:
scheduler.create_task_type(
{
"type": "swh-test-{}".format(taskname),
"description": "The {} testing task".format(taskname),
"backend_name": "swh.scheduler.tests.tasks.{}".format(taskname),
"default_interval": timedelta(days=1),
"min_interval": timedelta(hours=6),
"max_interval": timedelta(days=12),
}
)
return scheduler
# this alias makes it easy to instantiate a db-backed Scheduler,
# e.g. for the RPC client/server test suite
swh_db_scheduler = swh_scheduler
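# A minimal usage sketch for the fixture above (assumption: the test module
# is collected with this plugin enabled; "swh-test-ping" is one of the task
# types registered by the fixture):
#
#   def test_ping_task_type(swh_scheduler):
#       task_type = swh_scheduler.get_task_type("swh-test-ping")
#       assert task_type["backend_name"] == "swh.scheduler.tests.tasks.ping"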
@pytest.fixture(scope="session")
def swh_scheduler_celery_app():
"""Set up a Celery app as swh.scheduler and swh worker tests would expect it"""
test_app = TestApp(
set_as_current=True,
enable_logging=True,
task_cls="swh.scheduler.task:SWHTask",
config={
"accept_content": ["application/x-msgpack", "application/json"],
"broker_url": "memory://guest@localhost//",
"task_serializer": "msgpack",
"result_serializer": "json",
},
)
with setup_default_app(test_app, use_trap=False):
from swh.scheduler.celery_backend import config
config.app = test_app
test_app.set_default()
test_app.set_current()
yield test_app
@pytest.fixture(scope="session")
def swh_scheduler_celery_includes():
"""List of task modules that should be loaded by the swh_scheduler_celery_worker on
-startup."""
+ startup."""
task_modules = ["swh.scheduler.tests.tasks"]
for entrypoint in pkg_resources.iter_entry_points("swh.workers"):
task_modules.extend(entrypoint.load()().get("task_modules", []))
return task_modules
@pytest.fixture(scope="session")
def swh_scheduler_celery_worker(
- swh_scheduler_celery_app, swh_scheduler_celery_includes,
+ swh_scheduler_celery_app,
+ swh_scheduler_celery_includes,
):
"""Spawn a worker"""
for module in swh_scheduler_celery_includes:
swh_scheduler_celery_app.loader.import_task_module(module)
with worker.start_worker(swh_scheduler_celery_app, pool="solo") as w:
yield w
diff --git a/swh/scheduler/simulator/origins.py b/swh/scheduler/simulator/origins.py
index 546bfc3..e888989 100644
--- a/swh/scheduler/simulator/origins.py
+++ b/swh/scheduler/simulator/origins.py
@@ -1,227 +1,227 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
"""This module implements a model of the frequency of updates of an origin
and how long it takes to load it.
For each origin, a commit frequency is chosen deterministically based on the
hash of its URL, and all origins are assumed to have been created at an
arbitrary epoch; the number of commits is the product of these two values.
The run time of a load task is approximated as proportional to the number of
commits since the previous visit of the origin (possibly 0)."""
from datetime import datetime, timedelta, timezone
import hashlib
import logging
from typing import Dict, Generator, Iterator, List, Optional, Tuple
import uuid
import attr
from simpy import Event
from swh.model.model import OriginVisitStatus
from swh.scheduler.model import ListedOrigin
from .common import Environment, Queue, Task, TaskEvent
logger = logging.getLogger(__name__)
_nb_generated_origins = 0
_visit_times: Dict[Tuple[str, str], datetime] = {}
"""Cache of the time of the last visit of (visit_type, origin_url),
to spare an SQL query (high latency)."""
def generate_listed_origin(
lister_id: uuid.UUID, now: Optional[datetime] = None
) -> ListedOrigin:
"""Returns a globally unique new origin. Seed the `last_update` value
according to the OriginModel and the passed timestamp.
Arguments:
lister_id: id of the lister that generated this origin
now: time of listing, to emulate last_update (defaults to :func:`datetime.now`)
"""
global _nb_generated_origins
_nb_generated_origins += 1
- assert _nb_generated_origins < 10 ** 6, "Too many origins!"
+ assert _nb_generated_origins < 10**6, "Too many origins!"
if now is None:
now = datetime.now(tz=timezone.utc)
url = f"https://example.com/{_nb_generated_origins:06d}.git"
visit_type = "test-git"
origin = OriginModel(visit_type, url)
return ListedOrigin(
lister_id=lister_id,
url=url,
visit_type=visit_type,
last_update=origin.get_last_update(now),
)
class OriginModel:
MIN_RUN_TIME = 0.5
"""Minimal run time for a visit (retrieved from production data)"""
MAX_RUN_TIME = 7200
"""Max run time for a visit"""
PER_COMMIT_RUN_TIME = 0.1
"""Run time per commit"""
EPOCH = datetime(2015, 9, 1, 0, 0, 0, tzinfo=timezone.utc)
"""The origin of all origins (at least according to Software Heritage)"""
def __init__(self, type: str, origin: str):
self.type = type
self.origin = origin
def seconds_between_commits(self):
"""Returns a random 'average time between two commits' of this origin,
used to estimate the run time of a load task, and how much the loading
architecture is lagging behind origin updates."""
n_bytes = 2
num_buckets = 2 ** (8 * n_bytes)
# Deterministic seed to generate "random" characteristics of this origin
bucket = int.from_bytes(
hashlib.md5(self.origin.encode()).digest()[0:n_bytes], "little"
)
# minimum: 1 second (bucket == 0)
# max: 10 years (bucket == num_buckets - 1)
ten_y = 10 * 365 * 24 * 3600
return ten_y ** (bucket / num_buckets)
# return 1 + (ten_y - 1) * (bucket / (num_buckets - 1))
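# A worked example of the bucketing above (illustrative values): with
# n_bytes = 2 there are 65536 buckets; an origin hashing to the middle
# bucket (32768) gets ten_y ** 0.5, i.e. about 17759 seconds (~5 hours)
# between commits, while bucket 0 yields ten_y ** 0 == 1 second.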
def get_last_update(self, now: datetime) -> datetime:
"""Get the last_update value for this origin.
We assume that the origin had its first commit at `EPOCH`, and that one
commit happened every `self.seconds_between_commits()`. This returns
the last commit date before or equal to `now`.
"""
_, time_since_last_commit = divmod(
(now - self.EPOCH).total_seconds(), self.seconds_between_commits()
)
return now - timedelta(seconds=time_since_last_commit)
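# e.g. if seconds_between_commits() == 86400 (one commit per day), the
# divmod above rounds `now` down to the most recent instant that is a
# whole number of days after EPOCH.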
def get_current_snapshot_id(self, now: datetime) -> bytes:
"""Get the current snapshot for this origin.
To generate a snapshot id, we calculate the number of commits since the
EPOCH, and hash it alongside the origin type and url.
"""
commits_since_epoch, _ = divmod(
(now - self.EPOCH).total_seconds(), self.seconds_between_commits()
)
return hashlib.sha1(
f"{self.type} {self.origin} {commits_since_epoch}".encode()
).digest()
def load_task_characteristics(
self, now: datetime
) -> Tuple[float, str, Optional[bytes]]:
"""Returns the (run_time, end_status, snapshot id) of the next
origin visit."""
current_snapshot = self.get_current_snapshot_id(now)
key = (self.type, self.origin)
last_visit = _visit_times.get(key, now - timedelta(days=365))
time_since_last_successful_run = now - last_visit
_visit_times[key] = now
seconds_between_commits = self.seconds_between_commits()
seconds_since_last_successful = time_since_last_successful_run.total_seconds()
n_commits = int(seconds_since_last_successful / seconds_between_commits)
logger.debug(
"%s characteristics %s origin=%s: Interval: %s, n_commits: %s",
now,
self.type,
self.origin,
timedelta(seconds=seconds_between_commits),
n_commits,
)
run_time = self.MIN_RUN_TIME + self.PER_COMMIT_RUN_TIME * n_commits
if run_time > self.MAX_RUN_TIME:
# Long visits usually fail
return (self.MAX_RUN_TIME, "partial", None)
else:
return (run_time, "full", current_snapshot)
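# Illustrative numbers for the formula above (a sketch, not production
# data): 1000 commits since the last visit give
# run_time = 0.5 + 0.1 * 1000 = 100.5 seconds and a "full" visit; beyond
# roughly 72000 pending commits the run time exceeds MAX_RUN_TIME, so the
# visit is capped at 7200 seconds and reported as "partial" with no snapshot.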
def lister_process(
env: Environment, lister_id: uuid.UUID
) -> Generator[Event, Event, None]:
"""Every hour, generate new origins and update the `last_update` field for
the ones this process generated in the past"""
NUM_NEW_ORIGINS = 100
origins: List[ListedOrigin] = []
while True:
updated_origins = []
for origin in origins:
model = OriginModel(origin.visit_type, origin.url)
updated_origins.append(
attr.evolve(origin, last_update=model.get_last_update(env.time))
)
origins = updated_origins
origins.extend(
generate_listed_origin(lister_id, now=env.time)
for _ in range(NUM_NEW_ORIGINS)
)
env.scheduler.record_listed_origins(origins)
yield env.timeout(3600)
def load_task_process(
env: Environment, task: Task, status_queue: Queue
) -> Iterator[Event]:
"""A loading task. This pushes OriginVisitStatus objects to the
status_queue to simulate the visible outcomes of the task.
Uses `OriginModel.load_task_characteristics` to determine its run time.
"""
status = OriginVisitStatus(
origin=task.origin,
visit=42,
type=task.visit_type,
status="created",
date=env.time,
snapshot=None,
)
logger.debug("%s task %s origin=%s: Start", env.time, task.visit_type, task.origin)
yield status_queue.put(TaskEvent(task=task, status=status))
origin_model = OriginModel(task.visit_type, task.origin)
(run_time, end_status, snapshot) = origin_model.load_task_characteristics(env.time)
yield env.timeout(run_time)
logger.debug("%s task %s origin=%s: End", env.time, task.visit_type, task.origin)
yield status_queue.put(
TaskEvent(
task=task,
status=attr.evolve(
status, status=end_status, date=env.time, snapshot=snapshot
),
)
)
env.report.record_visit(
(task.visit_type, task.origin), run_time, end_status, snapshot
)
diff --git a/swh/scheduler/simulator/task_scheduler.py b/swh/scheduler/simulator/task_scheduler.py
index a86a962..2f7d424 100644
--- a/swh/scheduler/simulator/task_scheduler.py
+++ b/swh/scheduler/simulator/task_scheduler.py
@@ -1,76 +1,81 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
"""Agents using the "old" task-based scheduler."""
import logging
from typing import Dict, Generator, Iterator
from simpy import Event
from .common import Environment, Queue, Task, TaskEvent
logger = logging.getLogger(__name__)
def scheduler_runner_process(
- env: Environment, task_queues: Dict[str, Queue], min_batch_size: int,
+ env: Environment,
+ task_queues: Dict[str, Queue],
+ min_batch_size: int,
) -> Iterator[Event]:
"""Scheduler runner. Grabs next visits from the database according to the
scheduling policy, and fills the task_queues accordingly."""
while True:
for visit_type, queue in task_queues.items():
remaining = queue.slots_remaining()
if remaining < min_batch_size:
continue
next_tasks = env.scheduler.grab_ready_tasks(
f"load-{visit_type}", num_tasks=remaining, timestamp=env.time
)
logger.debug(
- "%s runner: running %s %s tasks", env.time, visit_type, len(next_tasks),
+ "%s runner: running %s %s tasks",
+ env.time,
+ visit_type,
+ len(next_tasks),
)
sim_tasks = [
Task(visit_type=visit_type, origin=task["arguments"]["kwargs"]["url"])
for task in next_tasks
]
env.scheduler.mass_schedule_task_runs(
[
{
"task": task["id"],
"scheduled": env.time,
"backend_id": str(sim_task.backend_id),
}
for task, sim_task in zip(next_tasks, sim_tasks)
]
)
for sim_task in sim_tasks:
yield queue.put(sim_task)
yield env.timeout(10.0)
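# Note on the loop above: a queue is only refilled once it has at least
# min_batch_size free slots; e.g. with min_batch_size=10 and 7 slots
# remaining, that visit type is skipped until more of its tasks complete.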
def scheduler_listener_process(
env: Environment, status_queue: Queue
) -> Generator[Event, TaskEvent, None]:
"""Scheduler listener. In the real world this would listen to celery
events, but we listen to the status_queue and simulate celery events from
that."""
while True:
event = yield status_queue.get()
if event.status.status == "ongoing":
env.scheduler.start_task_run(event.task.backend_id, timestamp=env.time)
else:
if event.status.status == "full":
status = "eventful" if event.eventful else "uneventful"
else:
status = "failed"
env.scheduler.end_task_run(
str(event.task.backend_id), status=status, timestamp=env.time
)
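# Status mapping implemented above, for reference: simulated "ongoing" events
# call start_task_run; terminal "full" events end the run as "eventful" or
# "uneventful" depending on event.eventful; any other terminal status ends
# the run as "failed".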
diff --git a/swh/scheduler/task.py b/swh/scheduler/task.py
index 0db6204..5a65526 100644
--- a/swh/scheduler/task.py
+++ b/swh/scheduler/task.py
@@ -1,85 +1,91 @@
# Copyright (C) 2015-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import datetime, timezone
from celery import current_app
import celery.app.task
from celery.utils.log import get_task_logger
from swh.core.statsd import Statsd
def ts():
    # use a timezone-aware datetime so .timestamp() is not shifted by local time
    return int(datetime.now(tz=timezone.utc).timestamp())
class SWHTask(celery.app.task.Task):
"""a schedulable task (abstract class)
Current implementation is based on Celery. See
http://docs.celeryproject.org/en/latest/reference/celery.app.task.html for
how to use tasks once instantiated
"""
_statsd = None
_log = None
reject_on_worker_lost = None
"""Inherited from :class:`celery.app.task.Task`, but we need to override
its docstring because it uses a custom ReST role"""
@property
def statsd(self):
if self._statsd:
return self._statsd
worker_name = current_app.conf.get("worker_name")
if worker_name:
self._statsd = Statsd(
- constant_tags={"task": self.name, "worker": worker_name,}
+ constant_tags={
+ "task": self.name,
+ "worker": worker_name,
+ }
)
return self._statsd
else:
statsd = Statsd(
- constant_tags={"task": self.name, "worker": "unknown worker",}
+ constant_tags={
+ "task": self.name,
+ "worker": "unknown worker",
+ }
)
return statsd
def __call__(self, *args, **kwargs):
self.statsd.increment("swh_task_called_count")
self.statsd.gauge("swh_task_start_ts", ts())
with self.statsd.timed("swh_task_duration_seconds"):
result = super().__call__(*args, **kwargs)
try:
status = result["status"]
if status == "success":
status = "eventful" if result.get("eventful") else "uneventful"
except Exception:
status = "eventful" if result else "uneventful"
self.statsd.gauge("swh_task_end_ts", ts(), tags={"status": status})
return result
def on_failure(self, exc, task_id, args, kwargs, einfo):
self.statsd.increment("swh_task_failure_count")
def on_success(self, retval, task_id, args, kwargs):
self.statsd.increment("swh_task_success_count")
# this is a swh specific event. Used to attach the retval to the
# task_run
self.send_event("task-result", result=retval)
@property
def log(self):
if self._log is None:
self._log = get_task_logger(self.name)
return self._log
def run(self, *args, **kwargs):
self.log.debug("%s: args=%s, kwargs=%s", self.name, args, kwargs)
ret = super().run(*args, **kwargs)
self.log.debug("%s: OK => %s", self.name, ret)
return ret
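# A minimal sketch of a task built on SWHTask (assumptions: a Celery app
# configured with task_cls="swh.scheduler.task:SWHTask", as in the test
# fixtures; the task name is hypothetical):
#
#   from celery import shared_task
#
#   @shared_task(name="swh.example.ping", bind=True)
#   def ping(self):
#       self.log.debug("ping called")  # SWHTask.log, the per-task logger
#       return {"status": "success", "eventful": False}
#
# The returned dict drives the "status" tag on the swh_task_end_ts gauge:
# a "success" status is refined to "eventful" or "uneventful" based on
# result.get("eventful").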
diff --git a/swh/scheduler/tests/common.py b/swh/scheduler/tests/common.py
index 1dc8324..d25c4f6 100644
--- a/swh/scheduler/tests/common.py
+++ b/swh/scheduler/tests/common.py
@@ -1,125 +1,127 @@
# Copyright (C) 2017-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import copy
import datetime
from typing import Dict, List, Optional
TEMPLATES = {
"test-git": {
"type": "load-test-git",
- "arguments": {"args": [], "kwargs": {},},
+ "arguments": {
+ "args": [],
+ "kwargs": {},
+ },
"next_run": None,
},
"test-hg": {
"type": "load-test-hg",
- "arguments": {"args": [], "kwargs": {},},
+ "arguments": {
+ "args": [],
+ "kwargs": {},
+ },
"next_run": None,
"policy": "oneshot",
},
}
TASK_TYPES = {
"test-git": {
"type": "load-test-git",
"description": "Update a git repository",
"backend_name": "swh.loader.git.tasks.UpdateGitRepository",
"default_interval": datetime.timedelta(days=64),
"min_interval": datetime.timedelta(hours=12),
"max_interval": datetime.timedelta(days=64),
"backoff_factor": 2,
"max_queue_length": None,
"num_retries": 7,
"retry_delay": datetime.timedelta(hours=2),
},
"test-hg": {
"type": "load-test-hg",
"description": "Update a mercurial repository",
"backend_name": "swh.loader.mercurial.tasks.UpdateHgRepository",
"default_interval": datetime.timedelta(days=64),
"min_interval": datetime.timedelta(hours=12),
"max_interval": datetime.timedelta(days=64),
"backoff_factor": 2,
"max_queue_length": None,
"num_retries": 7,
"retry_delay": datetime.timedelta(hours=2),
},
}
def _task_from_template(
template: Dict,
next_run: datetime.datetime,
priority: Optional[str],
*args,
**kwargs,
) -> Dict:
ret = copy.deepcopy(template)
ret["next_run"] = next_run
if priority:
ret["priority"] = priority
if args:
ret["arguments"]["args"] = list(args)
if kwargs:
ret["arguments"]["kwargs"] = kwargs
return ret
def tasks_from_template(
template: Dict,
max_timestamp: datetime.datetime,
num: Optional[int] = None,
priority: Optional[str] = None,
num_priorities: Dict[Optional[str], int] = {},
) -> List[Dict]:
- """Build ``num`` tasks from template
-
- """
+ """Build ``num`` tasks from template"""
assert bool(num) != bool(num_priorities), "mutually exclusive"
if not num_priorities:
assert num is not None # to please mypy
num_priorities = {None: num}
tasks: List[Dict] = []
for (priority, num) in num_priorities.items():
for _ in range(num):
i = len(tasks)
tasks.append(
_task_from_template(
template,
max_timestamp - datetime.timedelta(microseconds=i),
priority,
"argument-%03d" % i,
**{"kwarg%03d" % i: "bogus-kwarg"},
)
)
return tasks
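# Usage sketch (values mirror the templates above):
#
#   tasks = tasks_from_template(
#       TEMPLATES["test-git"], max_timestamp=datetime.datetime.utcnow(), num=2
#   )
#   # -> two "load-test-git" task dicts whose next_run values are spaced one
#   #    microsecond apart, each with one positional and one keyword argument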
def tasks_with_priority_from_template(
template: Dict, max_timestamp: datetime.datetime, num: int, priority: str
) -> List[Dict]:
- """Build tasks with priority from template
-
- """
+ """Build tasks with priority from template"""
return [
_task_from_template(
template,
max_timestamp - datetime.timedelta(microseconds=i),
priority,
"argument-%03d" % i,
**{"kwarg%03d" % i: "bogus-kwarg"},
)
for i in range(num)
]
LISTERS = (
{"name": "github"},
{"name": "gitlab", "instance_name": "gitlab"},
{"name": "gitlab", "instance_name": "freedesktop"},
{"name": "npm"},
{"name": "pypi"},
)
diff --git a/swh/scheduler/tests/test_cli.py b/swh/scheduler/tests/test_cli.py
index 62b83fb..5152656 100644
--- a/swh/scheduler/tests/test_cli.py
+++ b/swh/scheduler/tests/test_cli.py
@@ -1,847 +1,946 @@
# Copyright (C) 2019-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
from itertools import islice
import logging
import random
import re
import tempfile
from unittest.mock import patch
from click.testing import CliRunner
import pytest
from swh.core.api.classes import stream_results
from swh.model.model import Origin
from swh.scheduler.cli import cli
from swh.scheduler.utils import create_task_dict, utcnow
CLI_CONFIG = """
scheduler:
cls: foo
args: {}
"""
def invoke(scheduler, catch_exceptions, args):
runner = CliRunner()
with patch(
"swh.scheduler.get_scheduler"
) as get_scheduler_mock, tempfile.NamedTemporaryFile(
"a", suffix=".yml"
) as config_fd:
config_fd.write(CLI_CONFIG)
config_fd.seek(0)
get_scheduler_mock.return_value = scheduler
- args = ["-C" + config_fd.name,] + args
+ args = [
+ "-C" + config_fd.name,
+ ] + args
result = runner.invoke(cli, args, obj={"log_level": logging.WARNING})
if not catch_exceptions and result.exception:
print(result.output)
raise result.exception
return result
def test_schedule_tasks(swh_scheduler):
csv_data = (
b'swh-test-ping;[["arg1", "arg2"]];{"key": "value"};'
+ utcnow().isoformat().encode()
+ b"\n"
+ b'swh-test-ping;[["arg3", "arg4"]];{"key": "value"};'
+ utcnow().isoformat().encode()
+ b"\n"
)
with tempfile.NamedTemporaryFile(suffix=".csv") as csv_fd:
csv_fd.write(csv_data)
csv_fd.seek(0)
result = invoke(
swh_scheduler, False, ["task", "schedule", "-d", ";", csv_fd.name]
)
expected = r"""
Created 2 tasks
Task 1
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: recurring
Args:
\['arg1', 'arg2'\]
Keyword args:
key: 'value'
Task 2
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: recurring
Args:
\['arg3', 'arg4'\]
Keyword args:
key: 'value'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_schedule_tasks_columns(swh_scheduler):
with tempfile.NamedTemporaryFile(suffix=".csv") as csv_fd:
csv_fd.write(b'swh-test-ping;oneshot;["arg1", "arg2"];{"key": "value"}\n')
csv_fd.seek(0)
result = invoke(
swh_scheduler,
False,
[
"task",
"schedule",
"-c",
"type",
"-c",
"policy",
"-c",
"args",
"-c",
"kwargs",
"-d",
";",
csv_fd.name,
],
)
expected = r"""
Created 1 tasks
Task 1
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Args:
'arg1'
'arg2'
Keyword args:
key: 'value'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_schedule_task(swh_scheduler):
result = invoke(
swh_scheduler,
False,
- ["task", "add", "swh-test-ping", "arg1", "arg2", "key=value",],
+ [
+ "task",
+ "add",
+ "swh-test-ping",
+ "arg1",
+ "arg2",
+ "key=value",
+ ],
)
expected = r"""
Created 1 tasks
Task 1
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: recurring
Args:
'arg1'
'arg2'
Keyword args:
key: 'value'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_pending_tasks_none(swh_scheduler):
- result = invoke(swh_scheduler, False, ["task", "list-pending", "swh-test-ping",])
+ result = invoke(
+ swh_scheduler,
+ False,
+ [
+ "task",
+ "list-pending",
+ "swh-test-ping",
+ ],
+ )
expected = r"""
Found 0 swh-test-ping tasks
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_pending_tasks(swh_scheduler):
task1 = create_task_dict("swh-test-ping", "oneshot", key="value1")
task2 = create_task_dict("swh-test-ping", "oneshot", key="value2")
task2["next_run"] += datetime.timedelta(days=1)
swh_scheduler.create_tasks([task1, task2])
- result = invoke(swh_scheduler, False, ["task", "list-pending", "swh-test-ping",])
+ result = invoke(
+ swh_scheduler,
+ False,
+ [
+ "task",
+ "list-pending",
+ "swh-test-ping",
+ ],
+ )
expected = r"""
Found 1 swh-test-ping tasks
Task 1
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Args:
Keyword args:
key: 'value1'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
swh_scheduler.grab_ready_tasks("swh-test-ping")
- result = invoke(swh_scheduler, False, ["task", "list-pending", "swh-test-ping",])
+ result = invoke(
+ swh_scheduler,
+ False,
+ [
+ "task",
+ "list-pending",
+ "swh-test-ping",
+ ],
+ )
expected = r"""
Found 0 swh-test-ping tasks
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_pending_tasks_filter(swh_scheduler):
task = create_task_dict("swh-test-multiping", "oneshot", key="value")
swh_scheduler.create_tasks([task])
- result = invoke(swh_scheduler, False, ["task", "list-pending", "swh-test-ping",])
+ result = invoke(
+ swh_scheduler,
+ False,
+ [
+ "task",
+ "list-pending",
+ "swh-test-ping",
+ ],
+ )
expected = r"""
Found 0 swh-test-ping tasks
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_pending_tasks_filter_2(swh_scheduler):
swh_scheduler.create_tasks(
[
create_task_dict("swh-test-multiping", "oneshot", key="value"),
create_task_dict("swh-test-ping", "oneshot", key="value2"),
]
)
- result = invoke(swh_scheduler, False, ["task", "list-pending", "swh-test-ping",])
+ result = invoke(
+ swh_scheduler,
+ False,
+ [
+ "task",
+ "list-pending",
+ "swh-test-ping",
+ ],
+ )
expected = r"""
Found 1 swh-test-ping tasks
Task 2
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Args:
Keyword args:
key: 'value2'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
# Fails because "task list-pending --limit 3" only returns 2 tasks, because
# of how compute_nb_tasks_from works.
@pytest.mark.xfail
def test_list_pending_tasks_limit(swh_scheduler):
swh_scheduler.create_tasks(
[
create_task_dict("swh-test-ping", "oneshot", key="value%d" % i)
for i in range(10)
]
)
result = invoke(
- swh_scheduler, False, ["task", "list-pending", "swh-test-ping", "--limit", "3",]
+ swh_scheduler,
+ False,
+ [
+ "task",
+ "list-pending",
+ "swh-test-ping",
+ "--limit",
+ "3",
+ ],
)
expected = r"""
Found 2 swh-test-ping tasks
Task 1
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Args:
Keyword args:
key: 'value0'
Task 2
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Args:
Keyword args:
key: 'value1'
Task 3
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Args:
Keyword args:
key: 'value2'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_pending_tasks_before(swh_scheduler):
task1 = create_task_dict("swh-test-ping", "oneshot", key="value")
task2 = create_task_dict("swh-test-ping", "oneshot", key="value2")
task1["next_run"] += datetime.timedelta(days=3)
task2["next_run"] += datetime.timedelta(days=1)
swh_scheduler.create_tasks([task1, task2])
result = invoke(
swh_scheduler,
False,
[
"task",
"list-pending",
"swh-test-ping",
"--before",
(datetime.date.today() + datetime.timedelta(days=2)).isoformat(),
],
)
expected = r"""
Found 1 swh-test-ping tasks
Task 2
Next run: tomorrow \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Args:
Keyword args:
key: 'value2'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_tasks(swh_scheduler):
task1 = create_task_dict("swh-test-ping", "oneshot", key="value1")
task2 = create_task_dict("swh-test-ping", "oneshot", key="value2")
task1["next_run"] += datetime.timedelta(days=3, hours=2)
swh_scheduler.create_tasks([task1, task2])
swh_scheduler.grab_ready_tasks("swh-test-ping")
- result = invoke(swh_scheduler, False, ["task", "list",])
+ result = invoke(
+ swh_scheduler,
+ False,
+ [
+ "task",
+ "list",
+ ],
+ )
expected = r"""
Found 2 tasks
Task 1
Next run: .+ \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value1'
Task 2
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value2'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_tasks_id(swh_scheduler):
task1 = create_task_dict("swh-test-ping", "oneshot", key="value1")
task2 = create_task_dict("swh-test-ping", "oneshot", key="value2")
task3 = create_task_dict("swh-test-ping", "oneshot", key="value3")
swh_scheduler.create_tasks([task1, task2, task3])
- result = invoke(swh_scheduler, False, ["task", "list", "--task-id", "2",])
+ result = invoke(
+ swh_scheduler,
+ False,
+ [
+ "task",
+ "list",
+ "--task-id",
+ "2",
+ ],
+ )
expected = r"""
Found 1 tasks
Task 2
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value2'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_tasks_id_2(swh_scheduler):
task1 = create_task_dict("swh-test-ping", "oneshot", key="value1")
task2 = create_task_dict("swh-test-ping", "oneshot", key="value2")
task3 = create_task_dict("swh-test-ping", "oneshot", key="value3")
swh_scheduler.create_tasks([task1, task2, task3])
result = invoke(
swh_scheduler, False, ["task", "list", "--task-id", "2", "--task-id", "3"]
)
expected = r"""
Found 2 tasks
Task 2
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value2'
Task 3
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value3'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_tasks_type(swh_scheduler):
task1 = create_task_dict("swh-test-ping", "oneshot", key="value1")
task2 = create_task_dict("swh-test-multiping", "oneshot", key="value2")
task3 = create_task_dict("swh-test-ping", "oneshot", key="value3")
swh_scheduler.create_tasks([task1, task2, task3])
result = invoke(
swh_scheduler, False, ["task", "list", "--task-type", "swh-test-ping"]
)
expected = r"""
Found 2 tasks
Task 1
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value1'
Task 3
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value3'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_tasks_limit(swh_scheduler):
task1 = create_task_dict("swh-test-ping", "oneshot", key="value1")
task2 = create_task_dict("swh-test-ping", "oneshot", key="value2")
task3 = create_task_dict("swh-test-ping", "oneshot", key="value3")
swh_scheduler.create_tasks([task1, task2, task3])
- result = invoke(swh_scheduler, False, ["task", "list", "--limit", "2",])
+ result = invoke(
+ swh_scheduler,
+ False,
+ [
+ "task",
+ "list",
+ "--limit",
+ "2",
+ ],
+ )
expected = r"""
Found 2 tasks
Task 1
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value1'
Task 2
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value2'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_tasks_before(swh_scheduler):
task1 = create_task_dict("swh-test-ping", "oneshot", key="value1")
task2 = create_task_dict("swh-test-ping", "oneshot", key="value2")
task1["next_run"] += datetime.timedelta(days=3, hours=2)
swh_scheduler.create_tasks([task1, task2])
swh_scheduler.grab_ready_tasks("swh-test-ping")
result = invoke(
swh_scheduler,
False,
[
"task",
"list",
"--before",
(datetime.date.today() + datetime.timedelta(days=2)).isoformat(),
],
)
expected = r"""
Found 1 tasks
Task 2
Next run: today \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value2'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def test_list_tasks_after(swh_scheduler):
task1 = create_task_dict("swh-test-ping", "oneshot", key="value1")
task2 = create_task_dict("swh-test-ping", "oneshot", key="value2")
task1["next_run"] += datetime.timedelta(days=3, hours=2)
swh_scheduler.create_tasks([task1, task2])
swh_scheduler.grab_ready_tasks("swh-test-ping")
result = invoke(
swh_scheduler,
False,
[
"task",
"list",
"--after",
(datetime.date.today() + datetime.timedelta(days=2)).isoformat(),
],
)
expected = r"""
Found 1 tasks
Task 1
Next run: .+ \(.*\)
Interval: 1 day, 0:00:00
Type: swh-test-ping
Policy: oneshot
Status: next_run_not_scheduled
Priority:\x20
Args:
Keyword args:
key: 'value1'
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), result.output
def _fill_storage_with_origins(storage, nb_origins):
origins = [Origin(url=f"http://example.com/{i}") for i in range(nb_origins)]
storage.origin_add(origins)
return origins
@patch("swh.scheduler.cli.utils.TASK_BATCH_SIZE", 3)
def test_task_schedule_origins_dry_run(swh_scheduler, storage):
"""Tests the scheduling when origin_batch_size*task_batch_size is a
divisor of nb_origins."""
_fill_storage_with_origins(storage, 90)
result = invoke(
swh_scheduler,
False,
- ["task", "schedule_origins", "--dry-run", "swh-test-ping",],
+ [
+ "task",
+ "schedule_origins",
+ "--dry-run",
+ "swh-test-ping",
+ ],
)
# Check the output
expected = r"""
Scheduled 3 tasks \(30 origins\).
Scheduled 6 tasks \(60 origins\).
Scheduled 9 tasks \(90 origins\).
Done.
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), repr(result.output)
# Check scheduled tasks
tasks = swh_scheduler.search_tasks()
assert len(tasks) == 0
def _assert_origin_tasks_contraints(tasks, max_tasks, max_task_size, expected_origins):
# check there are not too many tasks
assert len(tasks) <= max_tasks
# check tasks are not too large
assert all(len(task["arguments"]["args"][0]) <= max_task_size for task in tasks)
# check the tasks are exhaustive
assert sum([len(task["arguments"]["args"][0]) for task in tasks]) == len(
expected_origins
)
assert set.union(*(set(task["arguments"]["args"][0]) for task in tasks)) == {
origin.url for origin in expected_origins
}
@patch("swh.scheduler.cli.utils.TASK_BATCH_SIZE", 3)
def test_task_schedule_origins(swh_scheduler, storage):
"""Tests the scheduling when neither origin_batch_size or
task_batch_size is a divisor of nb_origins."""
origins = _fill_storage_with_origins(storage, 70)
result = invoke(
swh_scheduler,
False,
- ["task", "schedule_origins", "swh-test-ping", "--batch-size", "20",],
+ [
+ "task",
+ "schedule_origins",
+ "swh-test-ping",
+ "--batch-size",
+ "20",
+ ],
)
# Check the output
expected = r"""
Scheduled 3 tasks \(60 origins\).
Scheduled 4 tasks \(70 origins\).
Done.
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), repr(result.output)
# Check tasks
tasks = swh_scheduler.search_tasks()
_assert_origin_tasks_contraints(tasks, 4, 20, origins)
assert all(task["arguments"]["kwargs"] == {} for task in tasks)
def test_task_schedule_origins_kwargs(swh_scheduler, storage):
"""Tests support of extra keyword-arguments."""
origins = _fill_storage_with_origins(storage, 30)
result = invoke(
swh_scheduler,
False,
[
"task",
"schedule_origins",
"swh-test-ping",
"--batch-size",
"20",
'key1="value1"',
'key2="value2"',
],
)
# Check the output
expected = r"""
Scheduled 2 tasks \(30 origins\).
Done.
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), repr(result.output)
# Check tasks
tasks = swh_scheduler.search_tasks()
_assert_origin_tasks_contraints(tasks, 2, 20, origins)
assert all(
task["arguments"]["kwargs"] == {"key1": "value1", "key2": "value2"}
for task in tasks
)
def test_task_schedule_origins_with_limit(swh_scheduler, storage):
"""Tests support of extra keyword-arguments."""
_fill_storage_with_origins(storage, 50)
limit = 20
expected_origins = list(islice(stream_results(storage.origin_list), limit))
nb_origins = len(expected_origins)
assert nb_origins == limit
max_task_size = 5
nb_tasks, remainder = divmod(nb_origins, max_task_size)
assert remainder == 0 # made the numbers go round
result = invoke(
swh_scheduler,
False,
[
"task",
"schedule_origins",
"swh-test-ping",
"--batch-size",
max_task_size,
"--limit",
limit,
],
)
# Check the output
expected = rf"""
Scheduled {nb_tasks} tasks \({nb_origins} origins\).
Done.
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), repr(result.output)
tasks = swh_scheduler.search_tasks()
_assert_origin_tasks_contraints(tasks, max_task_size, nb_origins, expected_origins)
def test_task_schedule_origins_with_page_token(swh_scheduler, storage):
"""Tests support of extra keyword-arguments."""
nb_total_origins = 50
origins = _fill_storage_with_origins(storage, nb_total_origins)
# prepare page_token and origins result expectancy
page_result = storage.origin_list(limit=10)
assert len(page_result.results) == 10
page_token = page_result.next_page_token
assert page_token is not None
# remove the first 10 origins listed as we won't see those in tasks
expected_origins = [o for o in origins if o not in page_result.results]
nb_origins = len(expected_origins)
assert nb_origins == nb_total_origins - len(page_result.results)
max_task_size = 10
nb_tasks, remainder = divmod(nb_origins, max_task_size)
assert remainder == 0
result = invoke(
swh_scheduler,
False,
[
"task",
"schedule_origins",
"swh-test-ping",
"--batch-size",
max_task_size,
"--page-token",
page_token,
],
)
# Check the output
expected = rf"""
Scheduled {nb_tasks} tasks \({nb_origins} origins\).
Done.
""".lstrip()
assert result.exit_code == 0, result.output
assert re.fullmatch(expected, result.output, re.MULTILINE), repr(result.output)
# Check tasks
tasks = swh_scheduler.search_tasks()
_assert_origin_tasks_contraints(tasks, max_task_size, nb_origins, expected_origins)
def test_cli_task_runner_unknown_task_types(swh_scheduler, storage):
"""When passing at least one unknown task type, the runner should fail."""
task_types = swh_scheduler.get_task_types()
task_type_names = [t["type"] for t in task_types]
known_task_type = random.choice(task_type_names)
unknown_task_type = "unknown-task-type"
assert unknown_task_type not in task_type_names
with pytest.raises(ValueError, match="Unknown"):
invoke(
swh_scheduler,
False,
[
"start-runner",
"--task-type",
known_task_type,
"--task-type",
unknown_task_type,
],
)
@pytest.mark.parametrize("flag_priority", ["--with-priority", "--without-priority"])
def test_cli_task_runner_with_known_tasks(
swh_scheduler, storage, caplog, flag_priority
):
"""Trigger runner with known tasks runs smoothly."""
task_types = swh_scheduler.get_task_types()
task_type_names = [t["type"] for t in task_types]
task_type_name = random.choice(task_type_names)
task_type_name2 = random.choice(task_type_names)
# The runner will just iterate over the following known task types and do
# nothing. We are just checking that the runner does not explode here.
result = invoke(
swh_scheduler,
False,
[
"start-runner",
flag_priority,
"--task-type",
task_type_name,
"--task-type",
task_type_name2,
],
)
assert result.exit_code == 0, result.output
def test_cli_task_runner_no_task(swh_scheduler, storage):
"""Trigger runner with no parameter should run as before."""
# The runner will just iterate over the existing tasks from the scheduler and do
# noop. We are just checking the runner does not explode here.
- result = invoke(swh_scheduler, False, ["start-runner",],)
+ result = invoke(
+ swh_scheduler,
+ False,
+ [
+ "start-runner",
+ ],
+ )
assert result.exit_code == 0, result.output
diff --git a/swh/scheduler/tests/test_cli_celery_monitor.py b/swh/scheduler/tests/test_cli_celery_monitor.py
index 151eff1..d37dcc4 100644
--- a/swh/scheduler/tests/test_cli_celery_monitor.py
+++ b/swh/scheduler/tests/test_cli_celery_monitor.py
@@ -1,147 +1,149 @@
# Copyright (C) 2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import logging
from click.testing import CliRunner
import pytest
from swh.scheduler.cli import cli
def invoke(*args, catch_exceptions=False):
result = CliRunner(mix_stderr=False).invoke(
- cli, ["celery-monitor", *args], catch_exceptions=catch_exceptions,
+ cli,
+ ["celery-monitor", *args],
+ catch_exceptions=catch_exceptions,
)
return result
def test_celery_monitor():
"""Check that celery-monitor returns its help text"""
result = invoke()
assert "Commands:" in result.stdout
assert "Options:" in result.stdout
def test_celery_monitor_ping(
caplog, swh_scheduler_celery_app, swh_scheduler_celery_worker
):
caplog.set_level(logging.INFO, "swh.scheduler.cli.celery_monitor")
result = invoke("--pattern", swh_scheduler_celery_worker.hostname, "ping-workers")
assert result.exit_code == 0
assert len(caplog.records) == 1
(record,) = caplog.records
assert record.levelname == "INFO"
assert f"response from {swh_scheduler_celery_worker.hostname}" in record.message
@pytest.mark.parametrize(
"filter_args,filter_message,exit_code",
[
((), "Matching all workers", 0),
(
("--pattern", "celery@*.test-host"),
"Using glob pattern celery@*.test-host",
1,
),
(
("--pattern", "celery@test-type.test-host"),
"Using destinations celery@test-type.test-host",
1,
),
(
("--pattern", "celery@test-type.test-host,celery@test-type2.test-host"),
(
"Using destinations "
"celery@test-type.test-host, celery@test-type2.test-host"
),
1,
),
],
)
def test_celery_monitor_ping_filter(
caplog,
swh_scheduler_celery_app,
swh_scheduler_celery_worker,
filter_args,
filter_message,
exit_code,
):
caplog.set_level(logging.DEBUG, "swh.scheduler.cli.celery_monitor")
result = invoke("--timeout", "1.5", *filter_args, "ping-workers")
assert result.exit_code == exit_code, result.stdout
got_no_response_message = False
got_filter_message = False
for record in caplog.records:
# Check the proper filter has been generated
if record.levelname == "DEBUG":
if filter_message in record.message:
got_filter_message = True
# Check that no worker responded
if record.levelname == "INFO":
if "No response in" in record.message:
got_no_response_message = True
assert got_filter_message
if filter_args:
assert got_no_response_message
def test_celery_monitor_list_running(
caplog, swh_scheduler_celery_app, swh_scheduler_celery_worker
):
caplog.set_level(logging.DEBUG, "swh.scheduler.cli.celery_monitor")
result = invoke("--pattern", swh_scheduler_celery_worker.hostname, "list-running")
assert result.exit_code == 0, result.stdout
for record in caplog.records:
if record.levelname != "INFO":
continue
assert (
f"{swh_scheduler_celery_worker.hostname}: no active tasks" in record.message
)
@pytest.mark.parametrize("format", ["csv", "pretty"])
def test_celery_monitor_list_running_format(
caplog, swh_scheduler_celery_app, swh_scheduler_celery_worker, format
):
caplog.set_level(logging.DEBUG, "swh.scheduler.cli.celery_monitor")
result = invoke(
"--pattern",
swh_scheduler_celery_worker.hostname,
"list-running",
"--format",
format,
)
assert result.exit_code == 0, result.stdout
for record in caplog.records:
if record.levelname != "INFO":
continue
assert (
f"{swh_scheduler_celery_worker.hostname}: no active tasks" in record.message
)
if format == "csv":
lines = result.stdout.splitlines()
assert lines == ["worker,name,args,kwargs,duration,worker_pid"]
diff --git a/swh/scheduler/tests/test_cli_journal.py b/swh/scheduler/tests/test_cli_journal.py
index 666c5a6..3c7e723 100644
--- a/swh/scheduler/tests/test_cli_journal.py
+++ b/swh/scheduler/tests/test_cli_journal.py
@@ -1,115 +1,132 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import os
from typing import Dict, List
from click.testing import CliRunner, Result
from confluent_kafka import Producer
import pytest
import yaml
from swh.journal.serializers import value_to_kafka
from swh.scheduler import get_scheduler
from swh.scheduler.cli import cli
from swh.scheduler.tests.test_journal_client import VISIT_STATUSES_1
@pytest.fixture
def swh_scheduler_cfg(postgresql_scheduler, kafka_server):
"""Journal client configuration ready"""
return {
- "scheduler": {"cls": "local", "db": postgresql_scheduler.dsn,},
+ "scheduler": {
+ "cls": "local",
+ "db": postgresql_scheduler.dsn,
+ },
"journal": {
"brokers": [kafka_server],
"group_id": "test-consume-visit-status",
},
}
def _write_configuration_path(config: Dict, tmp_path: str) -> str:
config_path = os.path.join(str(tmp_path), "scheduler.yml")
with open(config_path, "w") as f:
f.write(yaml.dump(config))
return config_path
@pytest.fixture
def swh_scheduler_cfg_path(swh_scheduler_cfg, tmp_path):
"""Write scheduler configuration in temporary path and returns such path"""
return _write_configuration_path(swh_scheduler_cfg, tmp_path)
def invoke(args: List[str], config_path: str, catch_exceptions: bool = False) -> Result:
- """Invoke swh scheduler journal subcommands
-
- """
+ """Invoke swh scheduler journal subcommands"""
runner = CliRunner()
result = runner.invoke(cli, ["-C" + config_path] + args)
if not catch_exceptions and result.exception:
print(result.output)
raise result.exception
return result
def test_cli_journal_client_origin_visit_status_misconfiguration_no_scheduler(
swh_scheduler_cfg, tmp_path
):
config = swh_scheduler_cfg.copy()
config["scheduler"] = {"cls": "foo"}
config_path = _write_configuration_path(config, tmp_path)
with pytest.raises(ValueError, match="must be instantiated"):
invoke(
- ["journal-client", "--stop-after-objects", "1",], config_path,
+ [
+ "journal-client",
+ "--stop-after-objects",
+ "1",
+ ],
+ config_path,
)
def test_cli_journal_client_origin_visit_status_misconfiguration_missing_journal_conf(
swh_scheduler_cfg, tmp_path
):
config = swh_scheduler_cfg.copy()
config.pop("journal", None)
config_path = _write_configuration_path(config, tmp_path)
with pytest.raises(ValueError, match="Missing 'journal'"):
invoke(
- ["journal-client", "--stop-after-objects", "1",], config_path,
+ [
+ "journal-client",
+ "--stop-after-objects",
+ "1",
+ ],
+ config_path,
)
def test_cli_journal_client_origin_visit_status(
- swh_scheduler_cfg, swh_scheduler_cfg_path,
+ swh_scheduler_cfg,
+ swh_scheduler_cfg_path,
):
kafka_server = swh_scheduler_cfg["journal"]["brokers"][0]
swh_scheduler = get_scheduler(**swh_scheduler_cfg["scheduler"])
producer = Producer(
{
"bootstrap.servers": kafka_server,
"client.id": "test visit-stats producer",
"acks": "all",
}
)
visit_status = VISIT_STATUSES_1[0]
value = value_to_kafka(visit_status)
topic = "swh.journal.objects.origin_visit_status"
producer.produce(topic=topic, key=b"bogus-origin", value=value)
producer.flush()
result = invoke(
- ["journal-client", "--stop-after-objects", "1",], swh_scheduler_cfg_path,
+ [
+ "journal-client",
+ "--stop-after-objects",
+ "1",
+ ],
+ swh_scheduler_cfg_path,
)
# Check the output
expected_output = "Processed 1 message(s).\nDone.\n"
assert result.exit_code == 0, result.output
assert result.output == expected_output
actual_visit_stats = swh_scheduler.origin_visit_stats_get(
[(visit_status["origin"], visit_status["type"])]
)
assert actual_visit_stats
assert len(actual_visit_stats) == 1
diff --git a/swh/scheduler/tests/test_cli_origin.py b/swh/scheduler/tests/test_cli_origin.py
index edf1197..90644c5 100644
--- a/swh/scheduler/tests/test_cli_origin.py
+++ b/swh/scheduler/tests/test_cli_origin.py
@@ -1,159 +1,162 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from typing import Tuple
import pytest
from swh.scheduler.cli.origin import format_origins
from swh.scheduler.tests.common import TASK_TYPES
from swh.scheduler.tests.test_cli import invoke as basic_invoke
def invoke(scheduler, args: Tuple[str, ...] = (), catch_exceptions: bool = False):
return basic_invoke(
scheduler, args=["origin", *args], catch_exceptions=catch_exceptions
)
def test_cli_origin(swh_scheduler):
"""Check that swh scheduler origin returns its help text"""
result = invoke(swh_scheduler)
assert "Commands:" in result.stdout
def test_format_origins_basic(listed_origins):
listed_origins = listed_origins[:100]
basic_output = list(format_origins(listed_origins))
# 1 header line + all origins
assert len(basic_output) == len(listed_origins) + 1
no_header_output = list(format_origins(listed_origins, with_header=False))
assert basic_output[1:] == no_header_output
def test_format_origins_fields_unknown(listed_origins):
listed_origins = listed_origins[:10]
it = format_origins(listed_origins, fields=["unknown_field"])
with pytest.raises(ValueError, match="unknown_field"):
next(it)
def test_format_origins_fields(listed_origins):
listed_origins = listed_origins[:10]
fields = ["lister_id", "url", "visit_type"]
output = list(format_origins(listed_origins, fields=fields))
assert output[0] == ",".join(fields)
for i, origin in enumerate(listed_origins):
assert output[i + 1] == f"{origin.lister_id},{origin.url},{origin.visit_type}"
def test_grab_next(swh_scheduler, listed_origins_by_type):
NUM_RESULTS = 10
# Strict inequality to check that grab_next_visits doesn't return more
# results than requested
# XXX: should test all of 'listed_origins_by_type' here...
visit_type = next(iter(listed_origins_by_type))
assert len(listed_origins_by_type[visit_type]) > NUM_RESULTS
for origins in listed_origins_by_type.values():
swh_scheduler.record_listed_origins(origins)
result = invoke(swh_scheduler, args=("grab-next", visit_type, str(NUM_RESULTS)))
assert result.exit_code == 0
out_lines = result.stdout.splitlines()
assert len(out_lines) == NUM_RESULTS + 1
fields = out_lines[0].split(",")
returned_origins = [dict(zip(fields, line.split(","))) for line in out_lines[1:]]
# Check that we've received origins we had listed in the first place
assert set(origin["url"] for origin in returned_origins) <= set(
origin.url for origin in listed_origins_by_type[visit_type]
)
def test_schedule_next(swh_scheduler, listed_origins_by_type):
for task_type in TASK_TYPES.values():
swh_scheduler.create_task_type(task_type)
NUM_RESULTS = 10
# Strict inequality to check that grab_next_visits doesn't return more
# results than requested
visit_type = next(iter(listed_origins_by_type))
assert len(listed_origins_by_type[visit_type]) > NUM_RESULTS
for origins in listed_origins_by_type.values():
swh_scheduler.record_listed_origins(origins)
result = invoke(swh_scheduler, args=("schedule-next", visit_type, str(NUM_RESULTS)))
assert result.exit_code == 0
# pull all tasks out of the scheduler
tasks = swh_scheduler.search_tasks()
assert len(tasks) == NUM_RESULTS
scheduled_tasks = {
(task["type"], task["arguments"]["kwargs"]["url"]) for task in tasks
}
all_possible_tasks = {
(f"load-{origin.visit_type}", origin.url)
for origin in listed_origins_by_type[visit_type]
}
assert scheduled_tasks <= all_possible_tasks
def test_send_to_celery(
- mocker, swh_scheduler, swh_scheduler_celery_app, listed_origins_by_type,
+ mocker,
+ swh_scheduler,
+ swh_scheduler_celery_app,
+ listed_origins_by_type,
):
for task_type in TASK_TYPES.values():
swh_scheduler.create_task_type(task_type)
visit_type = next(iter(listed_origins_by_type))
for origins in listed_origins_by_type.values():
swh_scheduler.record_listed_origins(origins)
get_queue_length = mocker.patch(
"swh.scheduler.celery_backend.config.get_queue_length"
)
get_queue_length.return_value = None
send_task = mocker.patch.object(swh_scheduler_celery_app, "send_task")
send_task.return_value = None
result = invoke(swh_scheduler, args=("send-to-celery", visit_type))
assert result.exit_code == 0
scheduled_tasks = {
(call[0][0], call[1]["kwargs"]["url"]) for call in send_task.call_args_list
}
expected_tasks = {
(TASK_TYPES[origin.visit_type]["backend_name"], origin.url)
for origin in listed_origins_by_type[visit_type]
}
assert expected_tasks == scheduled_tasks
def test_update_metrics(swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
assert swh_scheduler.get_metrics() == []
result = invoke(swh_scheduler, args=("update-metrics",))
assert result.exit_code == 0
assert swh_scheduler.get_metrics() != []
diff --git a/swh/scheduler/tests/test_cli_task_type.py b/swh/scheduler/tests/test_cli_task_type.py
index 64917d2..1a8a016 100644
--- a/swh/scheduler/tests/test_cli_task_type.py
+++ b/swh/scheduler/tests/test_cli_task_type.py
@@ -1,139 +1,129 @@
# Copyright (C) 2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import traceback
from click.testing import CliRunner
import pkg_resources
import pytest
import yaml
from swh.scheduler import get_scheduler
from swh.scheduler.cli import cli
FAKE_MODULE_ENTRY_POINTS = {
"lister.gnu=swh.lister.gnu:register",
"lister.pypi=swh.lister.pypi:register",
}
@pytest.fixture
def cli_runner():
return CliRunner()
@pytest.fixture
def mock_pkg_resources(monkeypatch):
- """Monkey patch swh.scheduler's mock_pkg_resources.iter_entry_point call
-
- """
+ """Monkey patch swh.scheduler's mock_pkg_resources.iter_entry_point call"""
def fake_iter_entry_points(*args, **kwargs):
- """Substitute fake function to return a fixed set of entrypoints
-
- """
+ """Substitute fake function to return a fixed set of entrypoints"""
from pkg_resources import Distribution, EntryPoint
d = Distribution()
return [EntryPoint.parse(entry, dist=d) for entry in FAKE_MODULE_ENTRY_POINTS]
original_method = pkg_resources.iter_entry_points
monkeypatch.setattr(pkg_resources, "iter_entry_points", fake_iter_entry_points)
yield
# reset monkeypatch: is that needed?
monkeypatch.setattr(pkg_resources, "iter_entry_points", original_method)
@pytest.fixture
def local_sched_config(swh_scheduler_config):
- """Expose the local scheduler configuration
-
- """
+ """Expose the local scheduler configuration"""
return {"scheduler": {"cls": "local", **swh_scheduler_config}}
@pytest.fixture
def local_sched_configfile(local_sched_config, tmp_path):
- """Write in temporary location the local scheduler configuration
-
- """
+ """Write in temporary location the local scheduler configuration"""
configfile = tmp_path / "config.yml"
configfile.write_text(yaml.dump(local_sched_config))
return configfile.as_posix()
def test_register_ttypes_all(
cli_runner, mock_pkg_resources, local_sched_config, local_sched_configfile
):
"""Registering all task types"""
for command in [
["--config-file", local_sched_configfile, "task-type", "register"],
["--config-file", local_sched_configfile, "task-type", "register", "-p", "all"],
[
"--config-file",
local_sched_configfile,
"task-type",
"register",
"-p",
"lister.gnu",
"-p",
"lister.pypi",
],
]:
result = cli_runner.invoke(cli, command)
assert result.exit_code == 0, traceback.print_exception(*result.exc_info)
scheduler = get_scheduler(**local_sched_config["scheduler"])
all_tasks = [
"list-gnu-full",
"list-pypi",
]
for task in all_tasks:
task_type_desc = scheduler.get_task_type(task)
assert task_type_desc
assert task_type_desc["type"] == task
assert task_type_desc["backoff_factor"] == 1
def test_register_ttypes_filter(
mock_pkg_resources, cli_runner, local_sched_config, local_sched_configfile
):
- """Filtering on one worker should only register its associated task type
-
- """
+ """Filtering on one worker should only register its associated task type"""
result = cli_runner.invoke(
cli,
[
"--config-file",
local_sched_configfile,
"task-type",
"register",
"--plugins",
"lister.gnu",
],
)
assert result.exit_code == 0, traceback.print_exception(*result.exc_info)
scheduler = get_scheduler(**local_sched_config["scheduler"])
all_tasks = [
"list-gnu-full",
]
for task in all_tasks:
task_type_desc = scheduler.get_task_type(task)
assert task_type_desc
assert task_type_desc["type"] == task
assert task_type_desc["backoff_factor"] == 1
@pytest.mark.parametrize("cli_command", ["list", "register", "add"])
def test_cli_task_type_raise(cli_runner, cli_command):
"""Without a proper configuration, the cli raises"""
with pytest.raises(ValueError, match="Scheduler class"):
cli_runner.invoke(cli, ["task-type", cli_command], catch_exceptions=False)
diff --git a/swh/scheduler/tests/test_common.py b/swh/scheduler/tests/test_common.py
index 2439052..dfe2519 100644
--- a/swh/scheduler/tests/test_common.py
+++ b/swh/scheduler/tests/test_common.py
@@ -1,55 +1,59 @@
# Copyright (C) 2017-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
from .common import TEMPLATES, tasks_from_template
def test_tasks_from_template_no_priority():
nb_tasks = 3
template = TEMPLATES["test-git"]
next_run = datetime.datetime.utcnow()
tasks = tasks_from_template(template, next_run, nb_tasks)
assert len(tasks) == nb_tasks
for i, t in enumerate(tasks):
assert t["type"] == template["type"]
assert t["arguments"] is not None
assert t.get("policy") is None # not defined in template
assert len(t["arguments"]["args"]) == 1
assert len(t["arguments"]["kwargs"].keys()) == 1
assert t["next_run"] == next_run - datetime.timedelta(microseconds=i)
assert t.get("priority") is None
def test_tasks_from_template_priority():
template = TEMPLATES["test-hg"]
num_priorities = {
None: 3,
"high": 5,
"normal": 3,
"low": 2,
}
next_run = datetime.datetime.utcnow()
- tasks = tasks_from_template(template, next_run, num_priorities=num_priorities,)
+ tasks = tasks_from_template(
+ template,
+ next_run,
+ num_priorities=num_priorities,
+ )
assert len(tasks) == sum(num_priorities.values())
repartition_priority = {k: 0 for k in num_priorities}
for i, t in enumerate(tasks):
assert t["type"] == template["type"]
assert t["arguments"] is not None
assert t["policy"] == template["policy"]
assert len(t["arguments"]["args"]) == 1
assert len(t["arguments"]["kwargs"].keys()) == 1
assert t["next_run"] == next_run - datetime.timedelta(microseconds=i)
priority = t.get("priority")
assert priority in num_priorities
repartition_priority[priority] += 1
assert repartition_priority == num_priorities
diff --git a/swh/scheduler/tests/test_journal_client.py b/swh/scheduler/tests/test_journal_client.py
index 214979c..f8b447c 100644
--- a/swh/scheduler/tests/test_journal_client.py
+++ b/swh/scheduler/tests/test_journal_client.py
@@ -1,995 +1,1015 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
import functools
from itertools import permutations
from typing import List
from unittest.mock import Mock
import attr
import pytest
from swh.model.hashutil import hash_to_bytes
from swh.scheduler.journal_client import (
from_position_offset_to_days,
max_date,
next_visit_queue_position,
process_journal_objects,
)
from swh.scheduler.model import LastVisitStatus, ListedOrigin, OriginVisitStats
from swh.scheduler.utils import utcnow
def test_journal_client_origin_visit_status_from_journal_fail(swh_scheduler):
- process_fn = functools.partial(process_journal_objects, scheduler=swh_scheduler,)
+ process_fn = functools.partial(
+ process_journal_objects,
+ scheduler=swh_scheduler,
+ )
with pytest.raises(AssertionError, match="Got unexpected origin_visit"):
- process_fn({"origin_visit": [{"url": "http://foobar.baz"},]})
+ process_fn(
+ {
+ "origin_visit": [
+ {"url": "http://foobar.baz"},
+ ]
+ }
+ )
with pytest.raises(AssertionError, match="Expected origin_visit_status"):
process_fn({})
ONE_DAY = datetime.timedelta(days=1)
ONE_YEAR = datetime.timedelta(days=366)
DATE3 = utcnow()
DATE2 = DATE3 - ONE_DAY
DATE1 = DATE2 - ONE_DAY
assert DATE1 < DATE2 < DATE3
@pytest.mark.parametrize(
"dates,expected_max_date",
[
((DATE1,), DATE1),
((None, DATE2), DATE2),
((DATE1, None), DATE1),
((DATE1, DATE2), DATE2),
((DATE2, DATE1), DATE2),
((DATE1, DATE2, DATE3), DATE3),
((None, DATE2, DATE3), DATE3),
((None, None, DATE3), DATE3),
((DATE1, None, DATE3), DATE3),
],
)
def test_max_date(dates, expected_max_date):
assert max_date(*dates) == expected_max_date
def test_max_date_raise():
with pytest.raises(ValueError, match="valid datetime"):
max_date()
with pytest.raises(ValueError, match="valid datetime"):
max_date(None)
with pytest.raises(ValueError, match="valid datetime"):
max_date(None, None)
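# Taken together, the two tests above pin down max_date's contract; a minimal
# sketch consistent with them (an assumption, not necessarily the actual
# swh.scheduler.journal_client implementation) would be:
#
#     def max_date(*dates):
#         datetimes = [d for d in dates if d is not None]
#         if not datetimes:
#             raise ValueError("At least one valid datetime is required")
#         return max(datetimes)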
def test_journal_client_origin_visit_status_from_journal_ignored_status(swh_scheduler):
- """Only final statuses (full, partial) are important, the rest remain ignored.
-
- """
+ """Only final statuses (full, partial) are important, the rest remain ignored."""
# Trace method calls on the swh_scheduler
swh_scheduler = Mock(wraps=swh_scheduler)
visit_statuses = [
{
"origin": "foo",
"visit": 1,
"status": "created",
"date": utcnow(),
"type": "git",
"snapshot": None,
},
{
"origin": "bar",
"visit": 1,
"status": "ongoing",
"date": utcnow(),
"type": "svn",
"snapshot": None,
},
]
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
# All messages have been ignored: no stats have been upserted
swh_scheduler.origin_visit_stats_upsert.assert_not_called()
def test_journal_client_ignore_missing_type(swh_scheduler):
"""Ignore statuses with missing type key"""
# Trace method calls on the swh_scheduler
swh_scheduler = Mock(wraps=swh_scheduler)
date = utcnow()
snapshot = hash_to_bytes("dddcc0710eb6cf9efd5b920a8453e1e07157bddd")
visit_statuses = [
{
"origin": "foo",
"visit": 1,
"status": "full",
"date": date,
"snapshot": snapshot,
},
]
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
# The message has been ignored: no stats have been upserted
swh_scheduler.origin_visit_stats_upsert.assert_not_called()
def assert_visit_stats_ok(
actual_visit_stats: OriginVisitStats,
expected_visit_stats: OriginVisitStats,
ignore_fields: List[str] = ["next_visit_queue_position"],
):
"""Utility test function to ensure visits stats read from the backend are in the right
shape. The comparison on the next_visit_queue_position will be dealt with in
dedicated tests so it's not tested in tests that are calling this function.
"""
fields = attr.fields_dict(OriginVisitStats)
defaults = {field: fields[field].default for field in ignore_fields}
actual_visit_stats = attr.evolve(actual_visit_stats, **defaults)
assert actual_visit_stats == expected_visit_stats
def test_journal_client_origin_visit_status_from_journal_last_not_found(swh_scheduler):
visit_status = {
"origin": "foo",
"visit": 1,
"status": "not_found",
"date": DATE1,
"type": "git",
"snapshot": None,
}
process_journal_objects(
{"origin_visit_status": [visit_status]}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get([("foo", "git")])
assert_visit_stats_ok(
actual_origin_visit_stats[0],
OriginVisitStats(
url="foo",
visit_type="git",
last_visit=visit_status["date"],
last_visit_status=LastVisitStatus.not_found,
next_position_offset=4,
successive_visits=1,
),
)
visit_statuses = [
{
"origin": "foo",
"visit": 3,
"status": "not_found",
"date": DATE2,
"type": "git",
"snapshot": None,
},
{
"origin": "foo",
"visit": 4,
"status": "not_found",
"date": DATE3,
"type": "git",
"snapshot": None,
},
]
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get([("foo", "git")])
assert_visit_stats_ok(
actual_origin_visit_stats[0],
OriginVisitStats(
url="foo",
visit_type="git",
last_visit=DATE3,
last_visit_status=LastVisitStatus.not_found,
next_position_offset=6,
successive_visits=3,
),
)
def test_journal_client_origin_visit_status_from_journal_last_failed(swh_scheduler):
visit_statuses = [
{
"origin": "foo",
"visit": 1,
"status": "partial",
"date": utcnow(),
"type": "git",
"snapshot": None,
},
{
"origin": "bar",
"visit": 1,
"status": "full",
"date": DATE1,
"type": "git",
"snapshot": None,
},
{
"origin": "bar",
"visit": 2,
"status": "full",
"date": DATE2,
"type": "git",
"snapshot": None,
},
{
"origin": "bar",
"visit": 3,
"status": "full",
"date": DATE3,
"type": "git",
"snapshot": None,
},
]
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get([("bar", "git")])
assert_visit_stats_ok(
actual_origin_visit_stats[0],
OriginVisitStats(
url="bar",
visit_type="git",
last_visit=DATE3,
last_visit_status=LastVisitStatus.failed,
next_position_offset=6,
successive_visits=3,
),
)
def test_journal_client_origin_visit_status_from_journal_last_failed2(swh_scheduler):
visit_statuses = [
{
"origin": "bar",
"visit": 2,
"status": "failed",
"date": DATE1,
"type": "git",
"snapshot": hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
},
{
"origin": "bar",
"visit": 3,
"status": "failed",
"date": DATE2,
"type": "git",
"snapshot": None,
},
]
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get([("bar", "git")])
assert_visit_stats_ok(
actual_origin_visit_stats[0],
OriginVisitStats(
url="bar",
visit_type="git",
last_visit=DATE2,
last_visit_status=LastVisitStatus.failed,
next_position_offset=5,
successive_visits=2,
),
)
def test_journal_client_origin_visit_status_from_journal_last_successful(swh_scheduler):
visit_statuses = [
{
"origin": "bar",
"visit": 1,
"status": "partial",
"date": utcnow(),
"type": "git",
"snapshot": hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
},
{
"origin": "foo",
"visit": 1,
"status": "full",
"date": DATE1,
"type": "git",
"snapshot": hash_to_bytes("eeecc0710eb6cf9efd5b920a8453e1e07157bfff"),
},
{
"origin": "foo",
"visit": 2,
"status": "partial",
"date": DATE2,
"type": "git",
"snapshot": hash_to_bytes("aaacc0710eb6cf9efd5b920a8453e1e07157baaa"),
},
{
"origin": "foo",
"visit": 3,
"status": "full",
"date": DATE3,
"type": "git",
"snapshot": hash_to_bytes("dddcc0710eb6cf9efd5b920a8453e1e07157bddd"),
},
]
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get([("foo", "git")])
assert_visit_stats_ok(
actual_origin_visit_stats[0],
OriginVisitStats(
url="foo",
visit_type="git",
last_successful=DATE3,
last_visit=DATE3,
last_visit_status=LastVisitStatus.successful,
last_snapshot=hash_to_bytes("dddcc0710eb6cf9efd5b920a8453e1e07157bddd"),
next_position_offset=0,
successive_visits=3,
),
)
def test_journal_client_origin_visit_status_from_journal_last_uneventful(swh_scheduler):
visit_status = {
"origin": "foo",
"visit": 1,
"status": "full",
"date": DATE3 + ONE_DAY,
"type": "git",
"snapshot": hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
}
# Let's insert some visit stats with some previous visit information
swh_scheduler.origin_visit_stats_upsert(
[
OriginVisitStats(
url=visit_status["origin"],
visit_type=visit_status["type"],
last_successful=DATE2,
last_visit=DATE3,
last_visit_status=LastVisitStatus.failed,
last_snapshot=visit_status["snapshot"],
next_visit_queue_position=None,
next_position_offset=4,
successive_visits=1,
)
]
)
process_journal_objects(
{"origin_visit_status": [visit_status]}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get(
[(visit_status["origin"], visit_status["type"])]
)
assert_visit_stats_ok(
actual_origin_visit_stats[0],
OriginVisitStats(
url=visit_status["origin"],
visit_type=visit_status["type"],
last_visit=DATE3 + ONE_DAY,
last_successful=DATE3 + ONE_DAY,
last_visit_status=LastVisitStatus.successful,
last_snapshot=visit_status["snapshot"],
next_visit_queue_position=None,
next_position_offset=5,
successive_visits=1,
),
)
VISIT_STATUSES = [
{**ovs, "date": DATE1 + n * ONE_DAY}
for n, ovs in enumerate(
[
{
"origin": "foo",
"type": "git",
"visit": 1,
"status": "created",
"snapshot": None,
},
{
"origin": "foo",
"type": "git",
"visit": 1,
"status": "full",
"snapshot": hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
},
{
"origin": "foo",
"type": "git",
"visit": 2,
"status": "created",
"snapshot": None,
},
{
"origin": "foo",
"type": "git",
"visit": 2,
"status": "full",
"snapshot": hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
},
]
)
]
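# Each status above is stamped with a strictly increasing date (DATE1 + n days),
# so permuting the list simulates out-of-order delivery while keeping the
# intended chronological order unambiguous.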
@pytest.mark.parametrize(
"visit_statuses", permutations(VISIT_STATUSES, len(VISIT_STATUSES))
)
def test_journal_client_origin_visit_status_permutation0(visit_statuses, swh_scheduler):
- """Ensure out of order topic subscription ends up in the same final state
-
- """
+ """Ensure out of order topic subscription ends up in the same final state"""
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get([("foo", "git")])
visit_stats = actual_origin_visit_stats[0]
assert_visit_stats_ok(
visit_stats,
OriginVisitStats(
url="foo",
visit_type="git",
last_successful=DATE1 + 3 * ONE_DAY,
last_visit=DATE1 + 3 * ONE_DAY,
last_visit_status=LastVisitStatus.successful,
last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
),
ignore_fields=[
"next_visit_queue_position",
"next_position_offset",
"successive_visits",
],
)
    # We ignore out-of-order messages, so the next_position_offset isn't exact
    # and depends on the permutation. What matters is the consistency of the
    # final dates (last_visit and last_successful).
assert 4 <= visit_stats.next_position_offset <= 5
# same goes for successive_visits
assert 1 <= visit_stats.successive_visits <= 2
VISIT_STATUSES_1 = [
{**ovs, "date": DATE1 + n * ONE_DAY}
for n, ovs in enumerate(
[
{
"origin": "cavabarder",
"type": "hg",
"visit": 1,
"status": "partial",
"snapshot": hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
},
{
"origin": "cavabarder",
"type": "hg",
"visit": 2,
"status": "full",
"snapshot": hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
},
{
"origin": "cavabarder",
"type": "hg",
"visit": 3,
"status": "full",
"snapshot": hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
},
{
"origin": "cavabarder",
"type": "hg",
"visit": 4,
"status": "full",
"snapshot": hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
},
]
)
]
@pytest.mark.parametrize(
"visit_statuses", permutations(VISIT_STATUSES_1, len(VISIT_STATUSES_1))
)
def test_journal_client_origin_visit_status_permutation1(visit_statuses, swh_scheduler):
- """Ensure out of order topic subscription ends up in the same final state
-
- """
+ """Ensure out of order topic subscription ends up in the same final state"""
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
actual_visit_stats = swh_scheduler.origin_visit_stats_get([("cavabarder", "hg")])
visit_stats = actual_visit_stats[0]
assert_visit_stats_ok(
visit_stats,
OriginVisitStats(
url="cavabarder",
visit_type="hg",
last_successful=DATE1 + 3 * ONE_DAY,
last_visit=DATE1 + 3 * ONE_DAY,
last_visit_status=LastVisitStatus.successful,
last_snapshot=hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
),
ignore_fields=[
"next_visit_queue_position",
"next_position_offset",
"successive_visits",
],
)
    # We ignore out-of-order messages, so the next_position_offset isn't exact
    # and depends on the permutation. What matters is the consistency of the
    # final dates (last_visit and last_successful).
assert 2 <= visit_stats.next_position_offset <= 5
# same goes for successive_visits
assert 1 <= visit_stats.successive_visits <= 4
VISIT_STATUSES_2 = [
{**ovs, "date": DATE1 + n * ONE_DAY}
for n, ovs in enumerate(
[
{
"origin": "cavabarder",
"type": "hg",
"visit": 1,
"status": "full",
"snapshot": hash_to_bytes("0000000000000000000000000000000000000000"),
},
{
"origin": "cavabarder",
"type": "hg",
"visit": 2,
"status": "full",
"snapshot": hash_to_bytes("1111111111111111111111111111111111111111"),
},
{
"origin": "iciaussi",
"type": "hg",
"visit": 1,
"status": "full",
"snapshot": hash_to_bytes("2222222222222222222222222222222222222222"),
},
{
"origin": "iciaussi",
"type": "hg",
"visit": 2,
"status": "full",
"snapshot": hash_to_bytes("3333333333333333333333333333333333333333"),
},
{
"origin": "cavabarder",
"type": "git",
"visit": 1,
"status": "full",
"snapshot": hash_to_bytes("4444444444444444444444444444444444444444"),
},
{
"origin": "cavabarder",
"type": "git",
"visit": 2,
"status": "full",
"snapshot": hash_to_bytes("5555555555555555555555555555555555555555"),
},
{
"origin": "iciaussi",
"type": "git",
"visit": 1,
"status": "full",
"snapshot": hash_to_bytes("6666666666666666666666666666666666666666"),
},
{
"origin": "iciaussi",
"type": "git",
"visit": 2,
"status": "full",
"snapshot": hash_to_bytes("7777777777777777777777777777777777777777"),
},
]
)
]
def test_journal_client_origin_visit_status_after_grab_next_visits(
swh_scheduler, stored_lister
):
"""Ensure OriginVisitStat entries created in the db as a result of calling
grab_next_visits() do not mess the OriginVisitStats upsert mechanism.
"""
listed_origins = [
ListedOrigin(lister_id=stored_lister.id, url=url, visit_type=visit_type)
for (url, visit_type) in set((v["origin"], v["type"]) for v in VISIT_STATUSES_2)
]
swh_scheduler.record_listed_origins(listed_origins)
before = utcnow()
swh_scheduler.grab_next_visits(
visit_type="git", count=10, policy="oldest_scheduled_first"
)
after = utcnow()
assert swh_scheduler.origin_visit_stats_get([("cavabarder", "hg")]) == []
assert swh_scheduler.origin_visit_stats_get([("cavabarder", "git")])[0] is not None
process_journal_objects(
{"origin_visit_status": VISIT_STATUSES_2}, scheduler=swh_scheduler
)
for url in ("cavabarder", "iciaussi"):
ovs = swh_scheduler.origin_visit_stats_get([(url, "git")])[0]
assert before <= ovs.last_scheduled <= after
ovs = swh_scheduler.origin_visit_stats_get([(url, "hg")])[0]
assert ovs.last_scheduled is None
ovs = swh_scheduler.origin_visit_stats_get([("cavabarder", "git")])[0]
assert ovs.last_successful == DATE1 + 5 * ONE_DAY
assert ovs.last_visit == DATE1 + 5 * ONE_DAY
assert ovs.last_visit_status == LastVisitStatus.successful
assert ovs.last_snapshot == hash_to_bytes(
"5555555555555555555555555555555555555555"
)
def test_journal_client_origin_visit_status_duplicated_messages(swh_scheduler):
- """A duplicated message must be ignored
-
- """
+ """A duplicated message must be ignored"""
visit_status = {
"origin": "foo",
"visit": 1,
"status": "full",
"date": DATE1,
"type": "git",
"snapshot": hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
}
process_journal_objects(
{"origin_visit_status": [visit_status]}, scheduler=swh_scheduler
)
process_journal_objects(
{"origin_visit_status": [visit_status]}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get([("foo", "git")])
assert_visit_stats_ok(
actual_origin_visit_stats[0],
OriginVisitStats(
url="foo",
visit_type="git",
last_successful=DATE1,
last_visit=DATE1,
last_visit_status=LastVisitStatus.successful,
last_snapshot=hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
successive_visits=1,
),
)
def test_journal_client_origin_visit_status_several_upsert(swh_scheduler):
- """An old message updates old information
-
- """
+ """An old message updates old information"""
visit_status1 = {
"origin": "foo",
"visit": 1,
"status": "full",
"date": DATE1,
"type": "git",
"snapshot": hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
}
visit_status2 = {
"origin": "foo",
"visit": 1,
"status": "full",
"date": DATE2,
"type": "git",
"snapshot": hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
}
process_journal_objects(
{"origin_visit_status": [visit_status2]}, scheduler=swh_scheduler
)
process_journal_objects(
{"origin_visit_status": [visit_status1]}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get([("foo", "git")])
assert_visit_stats_ok(
actual_origin_visit_stats[0],
OriginVisitStats(
url="foo",
visit_type="git",
last_successful=DATE2,
last_visit=DATE2,
last_visit_status=LastVisitStatus.successful,
last_snapshot=hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
next_position_offset=4,
successive_visits=1,
),
)
VISIT_STATUSES_SAME_SNAPSHOT = [
{**ovs, "date": DATE1 + n * ONE_YEAR}
for n, ovs in enumerate(
[
{
"origin": "cavabarder",
"type": "hg",
"visit": 3,
"status": "full",
"snapshot": hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
},
{
"origin": "cavabarder",
"type": "hg",
"visit": 4,
"status": "full",
"snapshot": hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
},
{
"origin": "cavabarder",
"type": "hg",
"visit": 4,
"status": "full",
"snapshot": hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
},
]
)
]
@pytest.mark.parametrize(
"visit_statuses",
permutations(VISIT_STATUSES_SAME_SNAPSHOT, len(VISIT_STATUSES_SAME_SNAPSHOT)),
)
def test_journal_client_origin_visit_statuses_same_snapshot_permutation(
visit_statuses, swh_scheduler
):
- """Ensure out of order topic subscription ends up in the same final state
-
- """
+ """Ensure out of order topic subscription ends up in the same final state"""
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get(
[("cavabarder", "hg")]
)
visit_stats = actual_origin_visit_stats[0]
assert_visit_stats_ok(
visit_stats,
OriginVisitStats(
url="cavabarder",
visit_type="hg",
last_successful=DATE1 + 2 * ONE_YEAR,
last_visit=DATE1 + 2 * ONE_YEAR,
last_visit_status=LastVisitStatus.successful,
last_snapshot=hash_to_bytes("aaaaaabbbeb6cf9efd5b920a8453e1e07157b6cd"),
),
ignore_fields=[
"next_visit_queue_position",
"next_position_offset",
"successive_visits",
],
)
    # We ignore out-of-order messages, so the next_position_offset isn't exact
    # and depends on the permutation. What matters is the consistency of the
    # final dates (last_visit and last_successful).
assert 4 <= visit_stats.next_position_offset <= 6
# same goes for successive_visits
assert 1 <= visit_stats.successive_visits <= 3
@pytest.mark.parametrize(
"position_offset, interval",
[
(0, 1),
(1, 1),
(2, 2),
(3, 2),
(4, 2),
(5, 4),
(6, 16),
(7, 64),
(8, 256),
(9, 1024),
(10, 4096),
(11, 16384),
],
)
def test_journal_client_from_position_offset_to_days(position_offset, interval):
assert from_position_offset_to_days(position_offset) == interval
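# The parametrized table above is consistent with a piecewise exponential
# backoff; a minimal sketch matching those values (an assumption, not
# necessarily the actual swh.scheduler.journal_client code) would be:
#
#     def from_position_offset_to_days(position_offset: int) -> int:
#         assert position_offset >= 0
#         if position_offset < 2:
#             return 1
#         if position_offset < 5:
#             return 2
#         return 4 ** (position_offset - 4)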
def test_journal_client_from_position_offset_to_days_only_positive_input():
with pytest.raises(AssertionError):
from_position_offset_to_days(-1)
@pytest.mark.parametrize(
- "fudge_factor,next_position_offset", [(0.01, 1), (-0.01, 5), (0.1, 8), (-0.1, 10),]
+ "fudge_factor,next_position_offset",
+ [
+ (0.01, 1),
+ (-0.01, 5),
+ (0.1, 8),
+ (-0.1, 10),
+ ],
)
def test_next_visit_queue_position(mocker, fudge_factor, next_position_offset):
mock_random = mocker.patch("swh.scheduler.journal_client.random.uniform")
mock_random.return_value = fudge_factor
actual_position = next_visit_queue_position(
- {}, {"next_position_offset": next_position_offset, "visit_type": "svn",}
+ {},
+ {
+ "next_position_offset": next_position_offset,
+ "visit_type": "svn",
+ },
)
assert actual_position == int(
24
* 3600
* from_position_offset_to_days(next_position_offset)
* (1 + fudge_factor)
)
assert mock_random.called
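# The expected position asserted above amounts to: take the backoff interval
# in days for the given offset, convert it to seconds, and apply the (mocked)
# random fudge factor, i.e. roughly:
#
#     queue_position = int(
#         24 * 3600
#         * from_position_offset_to_days(next_position_offset)
#         * (1 + fudge_factor)
#     )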
@pytest.mark.parametrize(
- "fudge_factor,next_position_offset", [(0.02, 2), (-0.02, 3), (0, 7), (-0.09, 9),]
+ "fudge_factor,next_position_offset",
+ [
+ (0.02, 2),
+ (-0.02, 3),
+ (0, 7),
+ (-0.09, 9),
+ ],
)
def test_next_visit_queue_position_with_state(
mocker, fudge_factor, next_position_offset
):
mock_random = mocker.patch("swh.scheduler.journal_client.random.uniform")
mock_random.return_value = fudge_factor
actual_position = next_visit_queue_position(
{"git": 0},
- {"next_position_offset": next_position_offset, "visit_type": "git",},
+ {
+ "next_position_offset": next_position_offset,
+ "visit_type": "git",
+ },
)
assert actual_position == int(
24
* 3600
* from_position_offset_to_days(next_position_offset)
* (1 + fudge_factor)
)
assert mock_random.called
@pytest.mark.parametrize(
- "fudge_factor,next_position_offset", [(0.03, 3), (-0.03, 4), (0.08, 7), (-0.08, 9),]
+ "fudge_factor,next_position_offset",
+ [
+ (0.03, 3),
+ (-0.03, 4),
+ (0.08, 7),
+ (-0.08, 9),
+ ],
)
def test_next_visit_queue_position_with_next_visit_queue(
mocker, fudge_factor, next_position_offset
):
mock_random = mocker.patch("swh.scheduler.journal_client.random.uniform")
mock_random.return_value = fudge_factor
actual_position = next_visit_queue_position(
{},
{
"next_position_offset": next_position_offset,
"visit_type": "hg",
"next_visit_queue_position": 0,
},
)
assert actual_position == int(
24
* 3600
* from_position_offset_to_days(next_position_offset)
* (1 + fudge_factor)
)
assert mock_random.called
def test_disable_failing_origins(swh_scheduler):
- """Origin with too many failed attempts ends up being deactivated in the scheduler.
-
- """
+ """Origin with too many failed attempts ends up being deactivated in the scheduler."""
    # Actually store the origin in the scheduler so we can check that it ends
    # up disabled.
lister = swh_scheduler.get_or_create_lister(
name="something", instance_name="something"
)
origin = ListedOrigin(
url="bar", enabled=True, visit_type="svn", lister_id=lister.id
)
swh_scheduler.record_listed_origins([origin])
visit_statuses = [
{
"origin": "bar",
"visit": 2,
"status": "failed",
"date": DATE1,
"type": "svn",
"snapshot": None,
},
{
"origin": "bar",
"visit": 3,
"status": "failed",
"date": DATE2,
"type": "svn",
"snapshot": None,
},
{
"origin": "bar",
"visit": 3,
"status": "failed",
"date": DATE3,
"type": "svn",
"snapshot": None,
},
]
process_journal_objects(
{"origin_visit_status": visit_statuses}, scheduler=swh_scheduler
)
actual_origin_visit_stats = swh_scheduler.origin_visit_stats_get([("bar", "svn")])
assert_visit_stats_ok(
actual_origin_visit_stats[0],
OriginVisitStats(
url="bar",
visit_type="svn",
last_successful=None,
last_visit=DATE3,
last_visit_status=LastVisitStatus.failed,
next_position_offset=6,
successive_visits=3,
),
)
# Now check that the origin in question is disabled
actual_page = swh_scheduler.get_listed_origins(url="bar")
assert len(actual_page.results) == 1
assert actual_page.next_page_token is None
for origin in actual_page.results:
assert origin.enabled is False
assert origin.lister_id == lister.id
assert origin.url == "bar"
assert origin.visit_type == "svn"
diff --git a/swh/scheduler/tests/test_model.py b/swh/scheduler/tests/test_model.py
index a8b2d76..5780ebd 100644
--- a/swh/scheduler/tests/test_model.py
+++ b/swh/scheduler/tests/test_model.py
@@ -1,125 +1,127 @@
# Copyright (C) 2020-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
import uuid
import attr
from swh.scheduler import model
def test_select_columns():
@attr.s
class TestModel(model.BaseSchedulerModel):
id = attr.ib(type=str)
test1 = attr.ib(type=str)
a_first_attr = attr.ib(type=str)
@property
def test2(self):
"""This property should not show up in the extracted columns"""
return self.test1
assert TestModel.select_columns() == ("a_first_attr", "id", "test1")
def test_insert_columns():
@attr.s
class TestModel(model.BaseSchedulerModel):
id = attr.ib(type=str)
test1 = attr.ib(type=str)
@property
def test2(self):
"""This property should not show up in the extracted columns"""
return self.test1
assert TestModel.insert_columns_and_metavars() == (
("id", "test1"),
("%(id)s", "%(test1)s"),
)
def test_insert_columns_auto_now_add():
@attr.s
class TestModel(model.BaseSchedulerModel):
id = attr.ib(type=str)
test1 = attr.ib(type=str)
added = attr.ib(type=datetime.datetime, metadata={"auto_now_add": True})
assert TestModel.insert_columns_and_metavars() == (
("id", "test1"),
("%(id)s", "%(test1)s"),
)
def test_insert_columns_auto_now():
@attr.s
class TestModel(model.BaseSchedulerModel):
id = attr.ib(type=str)
test1 = attr.ib(type=str)
updated = attr.ib(type=datetime.datetime, metadata={"auto_now": True})
assert TestModel.insert_columns_and_metavars() == (
("id", "test1", "updated"),
("%(id)s", "%(test1)s", "now()"),
)
def test_insert_columns_primary_key():
@attr.s
class TestModel(model.BaseSchedulerModel):
id = attr.ib(type=str, metadata={"auto_primary_key": True})
test1 = attr.ib(type=str)
assert TestModel.insert_columns_and_metavars() == (("test1",), ("%(test1)s",))
def test_insert_primary_key():
@attr.s
class TestModel(model.BaseSchedulerModel):
id = attr.ib(type=str, metadata={"auto_primary_key": True})
test1 = attr.ib(type=str)
assert TestModel.primary_key_columns() == ("id",)
@attr.s
class TestModel2(model.BaseSchedulerModel):
col1 = attr.ib(type=str, metadata={"primary_key": True})
col2 = attr.ib(type=str, metadata={"primary_key": True})
test1 = attr.ib(type=str)
assert TestModel2.primary_key_columns() == ("col1", "col2")
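# The tests above exercise both metadata flavors: "auto_primary_key" marks a
# database-generated key, which is excluded from the insert columns, while
# "primary_key" marks explicitly provided (possibly composite) key columns.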
def test_listed_origin_as_task_dict():
origin = model.ListedOrigin(
- lister_id=uuid.uuid4(), url="http://example.com/", visit_type="git",
+ lister_id=uuid.uuid4(),
+ url="http://example.com/",
+ visit_type="git",
)
task = origin.as_task_dict()
assert task == {
"type": "load-git",
"arguments": {"args": [], "kwargs": {"url": "http://example.com/"}},
}
loader_args = {"foo": "bar", "baz": {"foo": "bar"}}
origin_w_args = model.ListedOrigin(
lister_id=uuid.uuid4(),
url="http://example.com/svn/",
visit_type="svn",
extra_loader_arguments=loader_args,
)
task_w_args = origin_w_args.as_task_dict()
assert task_w_args == {
"type": "load-svn",
"arguments": {
"args": [],
"kwargs": {"url": "http://example.com/svn/", **loader_args},
},
}
diff --git a/swh/scheduler/tests/test_recurrent_visits.py b/swh/scheduler/tests/test_recurrent_visits.py
index 4d83547..c7ff20f 100644
--- a/swh/scheduler/tests/test_recurrent_visits.py
+++ b/swh/scheduler/tests/test_recurrent_visits.py
@@ -1,211 +1,215 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import timedelta
import logging
from queue import Queue
from unittest.mock import MagicMock, patch
import pytest
from swh.scheduler.celery_backend.recurrent_visits import (
POLICY_ADDITIONAL_PARAMETERS,
VisitSchedulerThreads,
grab_next_visits_policy_weights,
send_visits_for_visit_type,
spawn_visit_scheduler_thread,
terminate_visit_scheduler_threads,
visit_scheduler_thread,
)
from .test_cli import invoke
TEST_MAX_QUEUE = 10000
MODULE_NAME = "swh.scheduler.celery_backend.recurrent_visits"
def _compute_backend_name(visit_type: str) -> str:
"Build a dummy reproducible backend name"
return f"swh.loader.{visit_type}.tasks"
@pytest.fixture
def swh_scheduler(swh_scheduler):
"""Override default fixture of the scheduler to install some more task types."""
for visit_type in ["test-git", "test-hg", "test-svn"]:
task_type = f"load-{visit_type}"
swh_scheduler.create_task_type(
{
"type": task_type,
"max_queue_length": TEST_MAX_QUEUE,
"description": "The {} testing task".format(task_type),
"backend_name": _compute_backend_name(visit_type),
"default_interval": timedelta(days=1),
"min_interval": timedelta(hours=6),
"max_interval": timedelta(days=12),
}
)
return swh_scheduler
def test_cli_schedule_recurrent_unknown_visit_type(swh_scheduler):
"""When passed an unknown visit type, the recurrent visit scheduler should refuse
to start."""
with pytest.raises(ValueError, match="Unknown"):
invoke(
swh_scheduler,
False,
[
"schedule-recurrent",
"--visit-type",
"unknown",
"--visit-type",
"test-git",
],
)
def test_cli_schedule_recurrent_noop(swh_scheduler, mocker):
"""When passing no visit types, the recurrent visit scheduler should start."""
spawn_visit_scheduler_thread = mocker.patch(
f"{MODULE_NAME}.spawn_visit_scheduler_thread"
)
spawn_visit_scheduler_thread.side_effect = SystemExit
    # The actual scheduling threads won't spawn; they'll terminate immediately. This
# only exercises the logic to pull task types out of the database
result = invoke(swh_scheduler, False, ["schedule-recurrent"])
assert result.exit_code == 0, result.output
def test_recurrent_visit_scheduling(
- swh_scheduler, caplog, listed_origins_by_type, mocker,
+ swh_scheduler,
+ caplog,
+ listed_origins_by_type,
+ mocker,
):
"""Scheduling known tasks is ok."""
caplog.set_level(logging.DEBUG, MODULE_NAME)
nb_origins = 1000
mock_celery_app = MagicMock()
mock_available_slots = mocker.patch(f"{MODULE_NAME}.get_available_slots")
mock_available_slots.return_value = nb_origins # Slots available in queue
# Make sure the scheduler is properly configured in terms of visit/task types
all_task_types = {
task_type_d["type"]: task_type_d
for task_type_d in swh_scheduler.get_task_types()
}
visit_types = list(listed_origins_by_type.keys())
assert len(visit_types) > 0
task_types = []
origins = []
for visit_type, _origins in listed_origins_by_type.items():
origins.extend(swh_scheduler.record_listed_origins(_origins))
task_type_name = f"load-{visit_type}"
assert task_type_name in all_task_types.keys()
task_type = all_task_types[task_type_name]
task_type["visit_type"] = visit_type
        # we'll limit the orchestrator to the origin types we know
task_types.append(task_type)
for visit_type in ["test-git", "test-svn"]:
task_type = f"load-{visit_type}"
send_visits_for_visit_type(
swh_scheduler, mock_celery_app, visit_type, all_task_types[task_type]
)
    assert mock_available_slots.called, "The available slots function should be called"
records = [record.message for record in caplog.records]
# Mapping over the dict ratio/policies entries can change overall order so let's
# check the set of records
expected_records = set()
for task_type in task_types:
visit_type = task_type["visit_type"]
queue_name = task_type["backend_name"]
msg = (
f"{nb_origins} available slots for visit type {visit_type} "
f"in queue {queue_name}"
)
expected_records.add(msg)
for expected_record in expected_records:
assert expected_record in set(records)
@patch.dict(
POLICY_ADDITIONAL_PARAMETERS, {"test-git": POLICY_ADDITIONAL_PARAMETERS["git"]}
)
@pytest.mark.parametrize(
"visit_type, tablesamples",
[("test-hg", {}), ("test-git", POLICY_ADDITIONAL_PARAMETERS["git"])],
)
def test_recurrent_visit_additional_parameters(
swh_scheduler, mocker, visit_type, tablesamples
):
"""Testing additional policy parameters"""
mock_grab_next_visits = mocker.patch.object(swh_scheduler, "grab_next_visits")
mock_grab_next_visits.return_value = []
grab_next_visits_policy_weights(swh_scheduler, visit_type, 10)
for call in mock_grab_next_visits.call_args_list:
assert call[1].get("tablesample") == tablesamples.get(
call[1]["policy"], {}
).get("tablesample")
@pytest.fixture
def scheduler_config(swh_scheduler_config):
return {"scheduler": {"cls": "local", **swh_scheduler_config}, "celery": {}}
def test_visit_scheduler_thread_unknown_task(
- swh_scheduler, scheduler_config,
+ swh_scheduler,
+ scheduler_config,
):
"""Starting a thread with unknown task type reports the error"""
unknown_visit_type = "unknown"
command_queue = Queue()
exc_queue = Queue()
visit_scheduler_thread(
scheduler_config, unknown_visit_type, command_queue, exc_queue
)
assert command_queue.empty() is True
assert exc_queue.empty() is False
assert len(exc_queue.queue) == 1
result = exc_queue.queue.pop()
assert result[0] == unknown_visit_type
assert isinstance(result[1], ValueError)
def test_spawn_visit_scheduler_thread_noop(scheduler_config, visit_types, mocker):
"""Spawning and terminating threads runs smoothly"""
threads: VisitSchedulerThreads = {}
exc_queue = Queue()
mock_build_app = mocker.patch("swh.scheduler.celery_backend.config.build_app")
mock_build_app.return_value = MagicMock()
assert len(threads) == 0
for visit_type in visit_types:
spawn_visit_scheduler_thread(threads, exc_queue, scheduler_config, visit_type)
    # This actually only checks that the spawning and terminating logic is sound
assert len(threads) == len(visit_types)
actual_threads = terminate_visit_scheduler_threads(threads)
assert not len(actual_threads)
assert mock_build_app.called
diff --git a/swh/scheduler/tests/test_scheduler.py b/swh/scheduler/tests/test_scheduler.py
index 748ccc6..dd5b475 100644
--- a/swh/scheduler/tests/test_scheduler.py
+++ b/swh/scheduler/tests/test_scheduler.py
@@ -1,1505 +1,1529 @@
# Copyright (C) 2017-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from collections import defaultdict
import copy
import datetime
from datetime import timedelta
import inspect
import random
from typing import Any, Dict, List, Optional, Tuple
import uuid
import attr
from psycopg2.extras import execute_values
import pytest
from swh.model.hashutil import hash_to_bytes
from swh.scheduler.exc import SchedulerException, StaleData, UnknownPolicy
from swh.scheduler.interface import ListedOriginPageToken, SchedulerInterface
from swh.scheduler.model import (
LastVisitStatus,
ListedOrigin,
OriginVisitStats,
SchedulerMetrics,
)
from swh.scheduler.utils import utcnow
from .common import (
LISTERS,
TASK_TYPES,
TEMPLATES,
tasks_from_template,
tasks_with_priority_from_template,
)
ONEDAY = timedelta(days=1)
NUM_PRIORITY_TASKS = {None: 100, "high": 60, "normal": 30, "low": 20}
def subdict(d, keys=None, excl=()):
if keys is None:
keys = [k for k in d.keys()]
return {k: d[k] for k in keys if k not in excl}
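# e.g. subdict({"a": 1, "b": 2}, excl=("b",)) == {"a": 1}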
def metrics_sort_key(m: SchedulerMetrics) -> Tuple[uuid.UUID, str]:
return (m.lister_id, m.visit_type)
def assert_metrics_equal(left, right):
assert sorted(left, key=metrics_sort_key) == sorted(right, key=metrics_sort_key)
class TestScheduler:
def test_interface(self, swh_scheduler):
"""Checks all methods of SchedulerInterface are implemented by this
backend, and that they have the same signature."""
# Create an instance of the protocol (which cannot be instantiated
# directly, so this creates a subclass, then instantiates it)
interface = type("_", (SchedulerInterface,), {})()
assert "create_task_type" in dir(interface)
missing_methods = []
for meth_name in dir(interface):
if meth_name.startswith("_"):
continue
interface_meth = getattr(interface, meth_name)
try:
concrete_meth = getattr(swh_scheduler, meth_name)
except AttributeError:
if not getattr(interface_meth, "deprecated_endpoint", False):
# The backend is missing a (non-deprecated) endpoint
missing_methods.append(meth_name)
continue
expected_signature = inspect.signature(interface_meth)
actual_signature = inspect.signature(concrete_meth)
assert expected_signature == actual_signature, meth_name
assert missing_methods == []
def test_add_task_type(self, swh_scheduler):
tt = TASK_TYPES["test-git"]
swh_scheduler.create_task_type(tt)
assert tt == swh_scheduler.get_task_type(tt["type"])
tt2 = TASK_TYPES["test-hg"]
swh_scheduler.create_task_type(tt2)
assert tt == swh_scheduler.get_task_type(tt["type"])
assert tt2 == swh_scheduler.get_task_type(tt2["type"])
def test_create_task_type_idempotence(self, swh_scheduler):
tt = TASK_TYPES["test-git"]
swh_scheduler.create_task_type(tt)
swh_scheduler.create_task_type(tt)
assert tt == swh_scheduler.get_task_type(tt["type"])
def test_get_task_types(self, swh_scheduler):
tt, tt2 = TASK_TYPES["test-git"], TASK_TYPES["test-hg"]
swh_scheduler.create_task_type(tt)
swh_scheduler.create_task_type(tt2)
actual_task_types = swh_scheduler.get_task_types()
assert tt in actual_task_types
assert tt2 in actual_task_types
def test_create_tasks(self, swh_scheduler):
self._create_task_types(swh_scheduler)
num_git = 100
tasks_1 = tasks_from_template(TEMPLATES["test-git"], utcnow(), num_git)
tasks_2 = tasks_from_template(
TEMPLATES["test-hg"], utcnow(), num_priorities=NUM_PRIORITY_TASKS
)
tasks = tasks_1 + tasks_2
# tasks are returned only once with their ids
ret1 = swh_scheduler.create_tasks(tasks + tasks)
set_ret1 = set([t["id"] for t in ret1])
        # creating the same set results in the same ids
ret = swh_scheduler.create_tasks(tasks)
set_ret = set([t["id"] for t in ret])
# Idempotence results
assert set_ret == set_ret1
assert len(ret) == len(ret1)
ids = set()
actual_priorities = defaultdict(int)
for task, orig_task in zip(ret, tasks):
task = copy.deepcopy(task)
task_type = TASK_TYPES[orig_task["type"].split("-", 1)[-1]]
assert task["id"] not in ids
assert task["status"] == "next_run_not_scheduled"
assert task["current_interval"] == task_type["default_interval"]
assert task["policy"] == orig_task.get("policy", "recurring")
priority = task.get("priority")
actual_priorities[priority] += 1
assert task["retries_left"] == (task_type["num_retries"] or 0)
ids.add(task["id"])
del task["id"]
del task["status"]
del task["current_interval"]
del task["retries_left"]
if "policy" not in orig_task:
del task["policy"]
if "priority" not in orig_task:
del task["priority"]
assert task == orig_task
expected_priorities = NUM_PRIORITY_TASKS.copy()
expected_priorities[None] += num_git
assert dict(actual_priorities) == expected_priorities
def test_peek_ready_tasks_no_priority(self, swh_scheduler):
self._create_task_types(swh_scheduler)
t = utcnow()
task_type = TEMPLATES["test-git"]["type"]
tasks = tasks_from_template(TEMPLATES["test-git"], t, 100)
random.shuffle(tasks)
swh_scheduler.create_tasks(tasks)
ready_tasks = swh_scheduler.peek_ready_tasks(task_type)
assert len(ready_tasks) == len(tasks)
for i in range(len(ready_tasks) - 1):
assert ready_tasks[i]["next_run"] <= ready_tasks[i + 1]["next_run"]
# Only get the first few ready tasks
limit = random.randrange(5, 5 + len(tasks) // 2)
ready_tasks_limited = swh_scheduler.peek_ready_tasks(task_type, num_tasks=limit)
assert len(ready_tasks_limited) == limit
assert ready_tasks_limited == ready_tasks[:limit]
# Limit by timestamp
max_ts = tasks[limit - 1]["next_run"]
ready_tasks_timestamped = swh_scheduler.peek_ready_tasks(
task_type, timestamp=max_ts
)
for ready_task in ready_tasks_timestamped:
assert ready_task["next_run"] <= max_ts
# Make sure we get proper behavior for the first ready tasks
assert ready_tasks[: len(ready_tasks_timestamped)] == ready_tasks_timestamped
# Limit by both
ready_tasks_both = swh_scheduler.peek_ready_tasks(
task_type, timestamp=max_ts, num_tasks=limit // 3
)
assert len(ready_tasks_both) <= limit // 3
for ready_task in ready_tasks_both:
assert ready_task["next_run"] <= max_ts
assert ready_task in ready_tasks[: limit // 3]
def test_peek_ready_tasks_returns_only_no_priority_tasks(self, swh_scheduler):
"""Peek ready tasks only return standard tasks (no priority)"""
self._create_task_types(swh_scheduler)
t = utcnow()
task_type = TEMPLATES["test-git"]["type"]
# Create tasks with and without priorities
tasks = tasks_from_template(
- TEMPLATES["test-git"], t, num_priorities=NUM_PRIORITY_TASKS,
+ TEMPLATES["test-git"],
+ t,
+ num_priorities=NUM_PRIORITY_TASKS,
)
count_priority = 0
for task in tasks:
count_priority += 0 if task.get("priority") is None else 1
assert count_priority > 0, "Some created tasks should have some priority"
random.shuffle(tasks)
swh_scheduler.create_tasks(tasks)
# take all available no priority tasks
ready_tasks = swh_scheduler.peek_ready_tasks(task_type)
assert len(ready_tasks) == len(tasks) - count_priority
        # No ready task should have any priority
for task in ready_tasks:
assert task.get("priority") is None
def test_grab_ready_tasks(self, swh_scheduler):
self._create_task_types(swh_scheduler)
t = utcnow()
task_type = TEMPLATES["test-git"]["type"]
# Create tasks with and without priorities
tasks = tasks_from_template(
TEMPLATES["test-git"], t, num_priorities=NUM_PRIORITY_TASKS
)
random.shuffle(tasks)
swh_scheduler.create_tasks(tasks)
first_ready_tasks = swh_scheduler.peek_ready_tasks(task_type, num_tasks=50)
grabbed_tasks = swh_scheduler.grab_ready_tasks(task_type, num_tasks=50)
first_ready_tasks.sort(key=lambda task: task["arguments"]["args"][0])
grabbed_tasks.sort(key=lambda task: task["arguments"]["args"][0])
for peeked, grabbed in zip(first_ready_tasks, grabbed_tasks):
assert peeked["status"] == "next_run_not_scheduled"
del peeked["status"]
assert grabbed["status"] == "next_run_scheduled"
del grabbed["status"]
assert peeked == grabbed
priority = grabbed["priority"]
assert priority == peeked["priority"]
assert priority is None
def test_grab_ready_priority_tasks(self, swh_scheduler):
"""check the grab and peek priority tasks endpoint behave as expected"""
self._create_task_types(swh_scheduler)
t = utcnow()
task_type = TEMPLATES["test-git"]["type"]
num_tasks = 100
# Create tasks with and without priorities
tasks0 = tasks_with_priority_from_template(
- TEMPLATES["test-git"], t, num_tasks, "high",
+ TEMPLATES["test-git"],
+ t,
+ num_tasks,
+ "high",
)
tasks1 = tasks_with_priority_from_template(
- TEMPLATES["test-hg"], t, num_tasks, "low",
+ TEMPLATES["test-hg"],
+ t,
+ num_tasks,
+ "low",
)
tasks2 = tasks_with_priority_from_template(
- TEMPLATES["test-hg"], t, num_tasks, "normal",
+ TEMPLATES["test-hg"],
+ t,
+ num_tasks,
+ "normal",
)
tasks = tasks0 + tasks1 + tasks2
random.shuffle(tasks)
swh_scheduler.create_tasks(tasks)
ready_tasks = swh_scheduler.peek_ready_priority_tasks(task_type, num_tasks=50)
grabbed_tasks = swh_scheduler.grab_ready_priority_tasks(task_type, num_tasks=50)
ready_tasks.sort(key=lambda task: task["arguments"]["args"][0])
grabbed_tasks.sort(key=lambda task: task["arguments"]["args"][0])
for peeked, grabbed in zip(ready_tasks, grabbed_tasks):
assert peeked["status"] == "next_run_not_scheduled"
del peeked["status"]
assert grabbed["status"] == "next_run_scheduled"
del grabbed["status"]
assert peeked == grabbed
assert peeked["priority"] == grabbed["priority"]
assert peeked["priority"] is not None
def test_get_tasks(self, swh_scheduler):
self._create_task_types(swh_scheduler)
t = utcnow()
tasks = tasks_from_template(TEMPLATES["test-git"], t, 100)
tasks = swh_scheduler.create_tasks(tasks)
random.shuffle(tasks)
while len(tasks) > 1:
length = random.randrange(1, len(tasks))
cur_tasks = sorted(tasks[:length], key=lambda x: x["id"])
tasks[:length] = []
ret = swh_scheduler.get_tasks(task["id"] for task in cur_tasks)
# result is not guaranteed to be sorted
ret.sort(key=lambda x: x["id"])
assert ret == cur_tasks
def test_search_tasks(self, swh_scheduler):
def make_real_dicts(lst):
"""RealDictRow is not a real dict."""
return [dict(d.items()) for d in lst]
self._create_task_types(swh_scheduler)
t = utcnow()
tasks = tasks_from_template(TEMPLATES["test-git"], t, 100)
tasks = swh_scheduler.create_tasks(tasks)
assert make_real_dicts(swh_scheduler.search_tasks()) == make_real_dicts(tasks)
def assert_filtered_task_ok(
self, task: Dict[str, Any], after: datetime.datetime, before: datetime.datetime
) -> None:
"""Ensure filtered tasks have the right expected properties
- (within the range, recurring disabled, etc..)
+        (within the range, recurring disabled, etc.)
"""
started = task["started"]
date = started if started is not None else task["scheduled"]
assert after <= date and date <= before
if task["task_policy"] == "oneshot":
assert task["task_status"] in ["completed", "disabled"]
if task["task_policy"] == "recurring":
assert task["task_status"] in ["disabled"]
def test_filter_task_to_archive(self, swh_scheduler):
- """Filtering only list disabled recurring or completed oneshot tasks
-
- """
+ """Filtering only list disabled recurring or completed oneshot tasks"""
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["test-git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["test-hg"], _time, 12)
total_tasks = len(recurring) + len(oneshots)
# simulate scheduling tasks
pending_tasks = swh_scheduler.create_tasks(recurring + oneshots)
backend_tasks = [
{
"task": task["id"],
"backend_id": str(uuid.uuid4()),
"scheduled": utcnow(),
}
for task in pending_tasks
]
swh_scheduler.mass_schedule_task_runs(backend_tasks)
        # we simulate the tasks being done
_tasks = []
for task in backend_tasks:
t = swh_scheduler.end_task_run(task["backend_id"], status="eventful")
_tasks.append(t)
# Randomly update task's status per policy
status_per_policy = {"recurring": 0, "oneshot": 0}
status_choice = {
# policy: [tuple (1-for-filtering, 'associated-status')]
"recurring": [
(1, "disabled"),
(0, "completed"),
(0, "next_run_not_scheduled"),
],
"oneshot": [
(0, "next_run_not_scheduled"),
(1, "disabled"),
(1, "completed"),
],
}
tasks_to_update = defaultdict(list)
_task_ids = defaultdict(list)
        # randomly 'disable' recurring tasks or 'complete' oneshot tasks
for task in pending_tasks:
policy = task["policy"]
_task_ids[policy].append(task["id"])
status = random.choice(status_choice[policy])
if status[0] != 1:
continue
# elected for filtering
status_per_policy[policy] += status[0]
tasks_to_update[policy].append(task["id"])
swh_scheduler.disable_tasks(tasks_to_update["recurring"])
        # hack: change the status to something other than completed/disabled
swh_scheduler.set_status_tasks(
_task_ids["oneshot"], status="next_run_not_scheduled"
)
# complete the tasks to update
swh_scheduler.set_status_tasks(tasks_to_update["oneshot"], status="completed")
total_tasks_filtered = (
status_per_policy["recurring"] + status_per_policy["oneshot"]
)
# no pagination scenario
# retrieve tasks to archive
after = _time - ONEDAY
after_ts = after.strftime("%Y-%m-%d")
before = utcnow() + ONEDAY
before_ts = before.strftime("%Y-%m-%d")
tasks_result = swh_scheduler.filter_task_to_archive(
after_ts=after_ts, before_ts=before_ts, limit=total_tasks
)
tasks_to_archive = tasks_result["tasks"]
assert len(tasks_to_archive) == total_tasks_filtered
assert tasks_result.get("next_page_token") is None
actual_filtered_per_status = {"recurring": 0, "oneshot": 0}
for task in tasks_to_archive:
self.assert_filtered_task_ok(task, after, before)
actual_filtered_per_status[task["task_policy"]] += 1
assert actual_filtered_per_status == status_per_policy
# pagination scenario
nb_tasks = 3
tasks_result = swh_scheduler.filter_task_to_archive(
after_ts=after_ts, before_ts=before_ts, limit=nb_tasks
)
tasks_to_archive2 = tasks_result["tasks"]
assert len(tasks_to_archive2) == nb_tasks
next_page_token = tasks_result["next_page_token"]
assert next_page_token is not None
all_tasks = tasks_to_archive2
while next_page_token is not None: # Retrieve paginated results
tasks_result = swh_scheduler.filter_task_to_archive(
after_ts=after_ts,
before_ts=before_ts,
limit=nb_tasks,
page_token=next_page_token,
)
tasks_to_archive2 = tasks_result["tasks"]
assert len(tasks_to_archive2) <= nb_tasks
all_tasks.extend(tasks_to_archive2)
next_page_token = tasks_result.get("next_page_token")
actual_filtered_per_status = {"recurring": 0, "oneshot": 0}
for task in all_tasks:
self.assert_filtered_task_ok(task, after, before)
actual_filtered_per_status[task["task_policy"]] += 1
assert actual_filtered_per_status == status_per_policy
def test_delete_archived_tasks(self, swh_scheduler):
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["test-git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["test-hg"], _time, 12)
total_tasks = len(recurring) + len(oneshots)
pending_tasks = swh_scheduler.create_tasks(recurring + oneshots)
backend_tasks = [
{
"task": task["id"],
"backend_id": str(uuid.uuid4()),
"scheduled": utcnow(),
}
for task in pending_tasks
]
swh_scheduler.mass_schedule_task_runs(backend_tasks)
_tasks = []
        percent = random.randint(0, 100)  # random threshold for electing tasks for removal
for task in backend_tasks:
t = swh_scheduler.end_task_run(task["backend_id"], status="eventful")
c = random.randint(0, 100)
if c <= percent:
_tasks.append({"task_id": t["task"], "task_run_id": t["id"]})
swh_scheduler.delete_archived_tasks(_tasks)
all_tasks = [task["id"] for task in swh_scheduler.search_tasks()]
tasks_count = len(all_tasks)
tasks_run_count = len(swh_scheduler.get_task_runs(all_tasks))
assert tasks_count == total_tasks - len(_tasks)
assert tasks_run_count == total_tasks - len(_tasks)
def test_get_task_runs_no_task(self, swh_scheduler):
"""No task exist in the scheduler's db, get_task_runs() should always return an
empty list.
"""
assert not swh_scheduler.get_task_runs(task_ids=())
assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3))
assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3), limit=10)
def test_get_task_runs_no_task_executed(self, swh_scheduler):
"""No task has been executed yet, get_task_runs() should always return an empty
list.
"""
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["test-git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["test-hg"], _time, 12)
swh_scheduler.create_tasks(recurring + oneshots)
assert not swh_scheduler.get_task_runs(task_ids=())
assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3))
assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3), limit=10)
def test_get_task_runs_with_scheduled(self, swh_scheduler):
"""Some tasks have been scheduled but not executed yet, get_task_runs() should
not return an empty list. limit should behave as expected.
"""
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["test-git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["test-hg"], _time, 12)
total_tasks = len(recurring) + len(oneshots)
pending_tasks = swh_scheduler.create_tasks(recurring + oneshots)
backend_tasks = [
{
"task": task["id"],
"backend_id": str(uuid.uuid4()),
"scheduled": utcnow(),
}
for task in pending_tasks
]
swh_scheduler.mass_schedule_task_runs(backend_tasks)
assert not swh_scheduler.get_task_runs(task_ids=[total_tasks + 1])
btask = backend_tasks[0]
runs = swh_scheduler.get_task_runs(task_ids=[btask["task"]])
assert len(runs) == 1
run = runs[0]
assert subdict(run, excl=("id",)) == {
"task": btask["task"],
"backend_id": btask["backend_id"],
"scheduled": btask["scheduled"],
"started": None,
"ended": None,
"metadata": None,
"status": "scheduled",
}
runs = swh_scheduler.get_task_runs(
task_ids=[bt["task"] for bt in backend_tasks], limit=2
)
assert len(runs) == 2
runs = swh_scheduler.get_task_runs(
task_ids=[bt["task"] for bt in backend_tasks]
)
assert len(runs) == total_tasks
keys = ("task", "backend_id", "scheduled")
assert (
sorted([subdict(x, keys) for x in runs], key=lambda x: x["task"])
== backend_tasks
)
def test_get_task_runs_with_executed(self, swh_scheduler):
"""Some tasks have been executed, get_task_runs() should
not return an empty list. limit should behave as expected.
"""
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["test-git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["test-hg"], _time, 12)
pending_tasks = swh_scheduler.create_tasks(recurring + oneshots)
backend_tasks = [
{
"task": task["id"],
"backend_id": str(uuid.uuid4()),
"scheduled": utcnow(),
}
for task in pending_tasks
]
swh_scheduler.mass_schedule_task_runs(backend_tasks)
btask = backend_tasks[0]
ts = utcnow()
swh_scheduler.start_task_run(
btask["backend_id"], metadata={"something": "stupid"}, timestamp=ts
)
runs = swh_scheduler.get_task_runs(task_ids=[btask["task"]])
assert len(runs) == 1
assert subdict(runs[0], excl=("id")) == {
"task": btask["task"],
"backend_id": btask["backend_id"],
"scheduled": btask["scheduled"],
"started": ts,
"ended": None,
"metadata": {"something": "stupid"},
"status": "started",
}
ts2 = utcnow()
swh_scheduler.end_task_run(
btask["backend_id"],
metadata={"other": "stuff"},
timestamp=ts2,
status="eventful",
)
runs = swh_scheduler.get_task_runs(task_ids=[btask["task"]])
assert len(runs) == 1
assert subdict(runs[0], excl=("id")) == {
"task": btask["task"],
"backend_id": btask["backend_id"],
"scheduled": btask["scheduled"],
"started": ts,
"ended": ts2,
"metadata": {"something": "stupid", "other": "stuff"},
"status": "eventful",
}
def test_get_or_create_lister(self, swh_scheduler):
db_listers = []
for lister_args in LISTERS:
db_listers.append(swh_scheduler.get_or_create_lister(**lister_args))
for lister, lister_args in zip(db_listers, LISTERS):
assert lister.name == lister_args["name"]
assert lister.instance_name == lister_args.get("instance_name", "")
lister_get_again = swh_scheduler.get_or_create_lister(
lister.name, lister.instance_name
)
assert lister == lister_get_again
def test_get_lister(self, swh_scheduler):
for lister_args in LISTERS:
assert swh_scheduler.get_lister(**lister_args) is None
db_listers = []
for lister_args in LISTERS:
db_listers.append(swh_scheduler.get_or_create_lister(**lister_args))
for lister, lister_args in zip(db_listers, LISTERS):
lister_get_again = swh_scheduler.get_lister(
lister.name, lister.instance_name
)
assert lister == lister_get_again
def test_get_listers(self, swh_scheduler):
assert swh_scheduler.get_listers() == []
db_listers = []
for lister_args in LISTERS:
db_listers.append(swh_scheduler.get_or_create_lister(**lister_args))
assert swh_scheduler.get_listers() == db_listers
def test_update_lister(self, swh_scheduler, stored_lister):
lister = attr.evolve(stored_lister, current_state={"updated": "now"})
updated_lister = swh_scheduler.update_lister(lister)
assert updated_lister.updated > lister.updated
assert updated_lister == attr.evolve(lister, updated=updated_lister.updated)
def test_update_lister_stale(self, swh_scheduler, stored_lister):
swh_scheduler.update_lister(stored_lister)
with pytest.raises(StaleData) as exc:
swh_scheduler.update_lister(stored_lister)
assert "state not updated" in exc.value.args[0]
def test_record_listed_origins(self, swh_scheduler, listed_origins):
ret = swh_scheduler.record_listed_origins(listed_origins)
assert set(returned.url for returned in ret) == set(
origin.url for origin in listed_origins
)
assert all(origin.first_seen == origin.last_seen for origin in ret)
def test_record_listed_origins_with_duplicate(self, swh_scheduler, listed_origins):
# the duplicates must be in the same page to trigger the "on conflict" error
listed_origins.insert(0, listed_origins[0])
ret = swh_scheduler.record_listed_origins(listed_origins)
# without the duplicate
assert len(ret) == len(listed_origins) - 1
def test_record_listed_origins_upsert(self, swh_scheduler, listed_origins):
# First, insert `cutoff` origins
cutoff = 100
assert cutoff < len(listed_origins)
ret = swh_scheduler.record_listed_origins(listed_origins[:cutoff])
assert len(ret) == cutoff
# Then, insert all origins, including the `cutoff` first.
ret = swh_scheduler.record_listed_origins(listed_origins)
assert len(ret) == len(listed_origins)
# Two different "first seen" values
assert len(set(origin.first_seen for origin in ret)) == 2
# But a single "last seen" value
assert len(set(origin.last_seen for origin in ret)) == 1
def test_get_listed_origins_exact(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
for i, origin in enumerate(listed_origins):
ret = swh_scheduler.get_listed_origins(
lister_id=origin.lister_id, url=origin.url
)
assert ret.next_page_token is None
assert len(ret.results) == 1
assert ret.results[0].lister_id == origin.lister_id
assert ret.results[0].url == origin.url
@pytest.mark.parametrize("num_origins,limit", [(20, 6), (5, 42), (20, 20)])
def test_get_listed_origins_limit(
self, swh_scheduler, listed_origins, num_origins, limit
) -> None:
added_origins = sorted(
listed_origins[:num_origins], key=lambda o: (o.lister_id, o.url)
)
swh_scheduler.record_listed_origins(added_origins)
returned_origins: List[ListedOrigin] = []
call_count = 0
next_page_token: Optional[ListedOriginPageToken] = None
while True:
call_count += 1
ret = swh_scheduler.get_listed_origins(
lister_id=listed_origins[0].lister_id,
limit=limit,
page_token=next_page_token,
)
returned_origins.extend(ret.results)
next_page_token = ret.next_page_token
if next_page_token is None:
break
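# One call per full page of `limit` results, plus a final call returning the
# last partial (possibly empty) page together with a None page token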
assert call_count == (num_origins // limit) + 1
assert len(returned_origins) == num_origins
assert [(origin.lister_id, origin.url) for origin in returned_origins] == [
(origin.lister_id, origin.url) for origin in added_origins
]
def test_get_listed_origins_all(self, swh_scheduler, listed_origins) -> None:
swh_scheduler.record_listed_origins(listed_origins)
ret = swh_scheduler.get_listed_origins(limit=len(listed_origins) + 1)
assert ret.next_page_token is None
assert len(ret.results) == len(listed_origins)
def _grab_next_visits_setup(self, swh_scheduler, listed_origins_by_type):
"""Basic origins setup for scheduling policy tests"""
visit_type = next(iter(listed_origins_by_type))
origins = listed_origins_by_type[visit_type][:100]
assert len(origins) > 0
recorded_origins = swh_scheduler.record_listed_origins(origins)
return visit_type, recorded_origins
def _check_grab_next_visit_basic(
self, swh_scheduler, visit_type, policy, expected, **kwargs
):
"""Calls grab_next_visits with the passed policy, and check that:
- all the origins returned are the expected ones (in the same order)
- no extra origins are returned
- the last_scheduled field has been set properly.
Pass the extra keyword arguments to the calls to grab_next_visits.
Returns a timestamp greater than all `last_scheduled` values for the grabbed
visits.
"""
assert len(expected) != 0
before = utcnow()
ret = swh_scheduler.grab_next_visits(
visit_type=visit_type,
# Request one more than expected to check that no extra origin is returned
count=len(expected) + 1,
policy=policy,
**kwargs,
)
after = utcnow()
assert ret == expected
visit_stats_list = swh_scheduler.origin_visit_stats_get(
[(origin.url, origin.visit_type) for origin in expected]
)
assert len(visit_stats_list) == len(expected)
for visit_stats in visit_stats_list:
# Check that last_scheduled got updated
assert before <= visit_stats.last_scheduled <= after
# They should not be scheduled again
ret = swh_scheduler.grab_next_visits(
visit_type=visit_type, count=len(expected) + 1, policy=policy, **kwargs
)
assert ret == [], "grab_next_visits returned already-scheduled origins"
return after
def _check_grab_next_visit(
self, swh_scheduler, visit_type, policy, expected, **kwargs
):
"""Run the same check as _check_grab_next_visit_basic, but also checks the
origin visits have been marked as scheduled, and are only re-scheduled a
week later
"""
after = self._check_grab_next_visit_basic(
swh_scheduler, visit_type, policy, expected, **kwargs
)
# But a week later, they should be rescheduled
ret = swh_scheduler.grab_next_visits(
visit_type=visit_type,
count=len(expected) + 1,
policy=policy,
timestamp=after + timedelta(days=7),
)
# We need to sort them because their 'last_scheduled' field is updated to
# exactly the same value, so the order is not deterministic
assert sorted(ret) == sorted(
expected
), "grab_next_visits didn't reschedule visits after a week"
def _prepare_oldest_scheduled_first_origins(
self, swh_scheduler, listed_origins_by_type
):
visit_type, origins = self._grab_next_visits_setup(
swh_scheduler, listed_origins_by_type
)
# Give all origins but one a last_scheduled date
base_date = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
visit_stats = [
OriginVisitStats(
url=origin.url,
visit_type=origin.visit_type,
last_snapshot=None,
last_successful=None,
last_visit=None,
last_scheduled=base_date - timedelta(seconds=i),
)
for i, origin in enumerate(origins[1:])
]
swh_scheduler.origin_visit_stats_upsert(visit_stats)
# We expect to retrieve the origin with a NULL last_scheduled
# as well as those with the oldest values (i.e. the last ones), in order.
expected = [origins[0]] + origins[1:][::-1]
return visit_type, origins, expected
def test_grab_next_visits_oldest_scheduled_first(
- self, swh_scheduler, listed_origins_by_type,
+ self,
+ swh_scheduler,
+ listed_origins_by_type,
):
visit_type, origins, expected = self._prepare_oldest_scheduled_first_origins(
swh_scheduler, listed_origins_by_type
)
self._check_grab_next_visit(
swh_scheduler,
visit_type=visit_type,
policy="oldest_scheduled_first",
expected=expected,
)
@pytest.mark.parametrize("which_cooldown", ("scheduled", "failed", "not_found"))
@pytest.mark.parametrize("cooldown", (7, 15))
def test_grab_next_visits_cooldowns(
- self, swh_scheduler, listed_origins_by_type, which_cooldown, cooldown,
+ self,
+ swh_scheduler,
+ listed_origins_by_type,
+ which_cooldown,
+ cooldown,
):
visit_type, origins, expected = self._prepare_oldest_scheduled_first_origins(
swh_scheduler, listed_origins_by_type
)
after = self._check_grab_next_visit_basic(
swh_scheduler,
visit_type=visit_type,
policy="oldest_scheduled_first",
expected=expected,
)
# Mark all the visits as scheduled, failed or not_found at the `after` timestamp
ovs_args = {
"last_visit": None,
"last_visit_status": None,
"last_scheduled": None,
}
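# The scheduled cooldown is checked against last_scheduled, while the failed
# and not_found cooldowns are checked against last_visit/last_visit_status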
if which_cooldown == "scheduled":
ovs_args["last_scheduled"] = after
else:
ovs_args["last_visit"] = after
ovs_args["last_visit_status"] = LastVisitStatus(which_cooldown)
visit_stats = [
OriginVisitStats(
url=origin.url,
visit_type=origin.visit_type,
last_snapshot=None,
last_successful=None,
**ovs_args,
)
for origin in origins
]
swh_scheduler.origin_visit_stats_upsert(visit_stats)
cooldown_td = timedelta(days=cooldown)
cooldown_args = {
"scheduled_cooldown": None,
"failed_cooldown": None,
"not_found_cooldown": None,
}
cooldown_args[f"{which_cooldown}_cooldown"] = cooldown_td
ret = swh_scheduler.grab_next_visits(
visit_type=visit_type,
count=len(expected) + 1,
policy="oldest_scheduled_first",
timestamp=after + cooldown_td - timedelta(seconds=1),
**cooldown_args,
)
assert ret == [], f"{which_cooldown}_cooldown ignored"
ret = swh_scheduler.grab_next_visits(
visit_type=visit_type,
count=len(expected) + 1,
policy="oldest_scheduled_first",
timestamp=after + cooldown_td + timedelta(seconds=1),
**cooldown_args,
)
assert sorted(ret) == sorted(
expected
), "grab_next_visits didn't reschedule visits after the configured cooldown"
def test_grab_next_visits_tablesample(
- self, swh_scheduler, listed_origins_by_type,
+ self,
+ swh_scheduler,
+ listed_origins_by_type,
):
visit_type, origins, expected = self._prepare_oldest_scheduled_first_origins(
swh_scheduler, listed_origins_by_type
)
ret = swh_scheduler.grab_next_visits(
visit_type=visit_type,
policy="oldest_scheduled_first",
tablesample=50,
count=len(expected),
)
# Just a smoke test; it is not obvious how to test tablesample more reliably
assert ret is not None
def test_grab_next_visits_never_visited_oldest_update_first(
- self, swh_scheduler, listed_origins_by_type,
+ self,
+ swh_scheduler,
+ listed_origins_by_type,
):
visit_type, origins = self._grab_next_visits_setup(
swh_scheduler, listed_origins_by_type
)
# Update known origins with a `last_update` field that we control
base_date = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
updated_origins = [
attr.evolve(origin, last_update=base_date - timedelta(seconds=i))
for i, origin in enumerate(origins)
]
updated_origins = swh_scheduler.record_listed_origins(updated_origins)
# We expect to retrieve origins with the oldest update date, that is
# origins at the end of our updated_origins list.
expected_origins = sorted(updated_origins, key=lambda o: o.last_update)
self._check_grab_next_visit(
swh_scheduler,
visit_type=visit_type,
policy="never_visited_oldest_update_first",
expected=expected_origins,
)
def test_grab_next_visits_already_visited_order_by_lag(
- self, swh_scheduler, listed_origins_by_type,
+ self,
+ swh_scheduler,
+ listed_origins_by_type,
):
visit_type, origins = self._grab_next_visits_setup(
swh_scheduler, listed_origins_by_type
)
# Update known origins with a `last_update` field that we control
base_date = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
updated_origins = [
attr.evolve(origin, last_update=base_date - timedelta(seconds=i))
for i, origin in enumerate(origins)
]
updated_origins = swh_scheduler.record_listed_origins(updated_origins)
# Update the visit stats with a known visit at a controlled date for
# half the origins. Pick the date in the middle of the
# updated_origins' `last_update` range
visit_date = updated_origins[len(updated_origins) // 2].last_update
visited_origins = updated_origins[::2]
visit_stats = [
OriginVisitStats(
url=origin.url,
visit_type=origin.visit_type,
last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
last_successful=visit_date,
last_visit=visit_date,
)
for origin in visited_origins
]
swh_scheduler.origin_visit_stats_upsert(visit_stats)
# We expect to retrieve visited origins with the largest lag, but only
# those which haven't been visited since their last update
expected_origins = sorted(
[origin for origin in visited_origins if origin.last_update > visit_date],
key=lambda o: visit_date - o.last_update,
)
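# (origins with last_update <= visit_date were visited after their last
# change, so they have nothing new to visit)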
self._check_grab_next_visit(
swh_scheduler,
visit_type=visit_type,
policy="already_visited_order_by_lag",
expected=expected_origins,
)
def test_grab_next_visits_underflow(self, swh_scheduler, listed_origins_by_type):
"""Check that grab_next_visits works when there not enough origins in
the database"""
visit_type = next(iter(listed_origins_by_type))
# Only add 5 origins to the database
origins = listed_origins_by_type[visit_type][:5]
assert origins
swh_scheduler.record_listed_origins(origins)
ret = swh_scheduler.grab_next_visits(
visit_type, len(origins) + 2, policy="oldest_scheduled_first"
)
assert len(ret) == 5
def test_grab_next_visits_no_last_update_nor_visit_stats(
self, swh_scheduler, listed_origins_by_type
):
- """grab_next_visits should retrieve tasks without last update (nor visit stats)
-
- """
+ """grab_next_visits should retrieve tasks without last update (nor visit stats)"""
visit_type = next(iter(listed_origins_by_type))
origins = []
for origin in listed_origins_by_type[visit_type]:
origins.append(
attr.evolve(origin, last_update=None)
) # drop last_update so these origins match the policy under test
assert len(origins) > 0
swh_scheduler.record_listed_origins(origins)
# Initially, we have no global queue position
current_state = swh_scheduler.visit_scheduler_queue_position_get()
assert current_state == {}
# nor any visit statuses
actual_visit_stats = swh_scheduler.origin_visit_stats_get(
(o.url, o.visit_type) for o in origins
)
assert len(actual_visit_stats) == 0
# Grab some new visits
next_visits = swh_scheduler.grab_next_visits(
- visit_type, count=len(origins), policy="origins_without_last_update",
+ visit_type,
+ count=len(origins),
+ policy="origins_without_last_update",
)
# all origins are grabbed, since none of them has a last_update
assert len(next_visits) == len(origins)
# Now the global queue position has been updated
current_state = swh_scheduler.visit_scheduler_queue_position_get()
assert current_state[visit_type] is not None
actual_visit_stats = swh_scheduler.origin_visit_stats_get(
(o.url, o.visit_type) for o in next_visits
)
# Visit stats were also created
assert len(actual_visit_stats) == len(origins)
def test_grab_next_visits_no_last_update_with_visit_stats(
self, swh_scheduler, listed_origins_by_type
):
"""grab_next_visits should retrieve tasks without last update"""
visit_type = next(iter(listed_origins_by_type))
origins = []
for origin in listed_origins_by_type[visit_type]:
origins.append(
attr.evolve(origin, last_update=None)
) # drop last_update so these origins match the policy under test
assert len(origins) > 0
swh_scheduler.record_listed_origins(origins)
# Initially, we have no global queue position
current_state = swh_scheduler.visit_scheduler_queue_position_get()
assert current_state == {}
# Simulate that some of those origins have associated visit stats (some with
# an existing queue position and some without any)
visit_stats = (
[
OriginVisitStats(
url=origin.url,
visit_type=origin.visit_type,
last_successful=utcnow(),
last_visit=utcnow(),
next_visit_queue_position=int(24 * 3600 * random.uniform(-10, 1)),
)
for origin in origins[:100]
]
+ [
OriginVisitStats(
url=origin.url,
visit_type=origin.visit_type,
last_successful=utcnow(),
last_visit=utcnow(),
next_visit_queue_position=int(
24 * 3600 * random.uniform(1, 10)
), # definitely > 0
)
for origin in origins[100:150]
]
+ [
OriginVisitStats(
url=origin.url,
visit_type=origin.visit_type,
last_successful=utcnow(),
last_visit=utcnow(),
)
for origin in origins[150:]
]
)
swh_scheduler.origin_visit_stats_upsert(visit_stats)
# Grab next visits
actual_visits = swh_scheduler.grab_next_visits(
- visit_type, count=len(origins), policy="origins_without_last_update",
+ visit_type,
+ count=len(origins),
+ policy="origins_without_last_update",
)
assert len(actual_visits) == len(origins)
actual_visit_stats = swh_scheduler.origin_visit_stats_get(
(o.url, o.visit_type) for o in actual_visits
)
assert len(actual_visit_stats) == len(origins)
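# The global queue position should now be the highest non-None
# next_visit_queue_position among the upserted visit stats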
current_state = swh_scheduler.visit_scheduler_queue_position_get()
assert current_state == {
visit_type: max(
s.next_visit_queue_position
for s in actual_visit_stats
if s.next_visit_queue_position is not None
)
}
def test_grab_next_visits_unknown_policy(self, swh_scheduler):
unknown_policy = "non_existing_policy"
NUM_RESULTS = 5
with pytest.raises(UnknownPolicy, match=unknown_policy):
swh_scheduler.grab_next_visits("type", NUM_RESULTS, policy=unknown_policy)
def test_grab_next_visit_duplicates(self, swh_scheduler, listed_origins):
"""Checks grab_next_visits does not crash when there are rows with
duplicated (origin_url, visit_type) in the database
"""
lister2 = swh_scheduler.get_or_create_lister(**LISTERS[1])
assert lister2.id != listed_origins[0].lister_id
# Create two origins with the same url and visit_type, but different listers
# (and also differing last_update values so they are returned in a
# deterministic order)
origin1 = attr.evolve(
listed_origins[0], first_seen=utcnow(), last_seen=utcnow()
)
origin2 = attr.evolve(
origin1,
lister_id=lister2.id,
last_update=origin1.last_update + datetime.timedelta(seconds=10),
)
origins = [origin1, origin2]
recorded_origins = swh_scheduler.record_listed_origins(origins)
expected_origins = sorted(recorded_origins, key=lambda o: o.last_update)
self._check_grab_next_visit(
swh_scheduler,
visit_type=origin1.visit_type,
policy="never_visited_oldest_update_first",
expected=expected_origins,
)
def _create_task_types(self, scheduler):
for tt in TASK_TYPES.values():
scheduler.create_task_type(tt)
def test_origin_visit_stats_get_empty(self, swh_scheduler) -> None:
assert swh_scheduler.origin_visit_stats_get([]) == []
def test_origin_visit_stats_get_pagination(self, swh_scheduler) -> None:
page_size = inspect.signature(execute_values).parameters["page_size"].default
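# read psycopg2's default execute_values page size, so that page_size + 1
# rows are guaranteed to span more than one page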
visit_stats = [
OriginVisitStats(
url=f"https://example.com/origin-{i:03d}",
visit_type="git",
last_successful=utcnow(),
last_visit=utcnow(),
)
for i in range(
page_size + 1
) # Ensure overflow of the psycopg2.extras.execute_values page_size
]
swh_scheduler.origin_visit_stats_upsert(visit_stats)
assert set(
swh_scheduler.origin_visit_stats_get(
[(ovs.url, ovs.visit_type) for ovs in visit_stats]
)
) == set(visit_stats)
def test_origin_visit_stats_upsert(self, swh_scheduler) -> None:
eventful_date = utcnow()
url = "https://github.com/test"
visit_stats = OriginVisitStats(
url=url,
visit_type="git",
last_successful=eventful_date,
last_visit=eventful_date,
)
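# Upserting the same stats twice must be idempotent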
swh_scheduler.origin_visit_stats_upsert([visit_stats])
swh_scheduler.origin_visit_stats_upsert([visit_stats])
assert swh_scheduler.origin_visit_stats_get([(url, "git")]) == [visit_stats]
assert swh_scheduler.origin_visit_stats_get([(url, "svn")]) == []
new_visit_date = utcnow()
visit_stats = OriginVisitStats(
- url=url, visit_type="git", last_successful=None, last_visit=new_visit_date,
+ url=url,
+ visit_type="git",
+ last_successful=None,
+ last_visit=new_visit_date,
)
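# An uneventful visit updates last_visit but must keep the previous
# last_successful value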
swh_scheduler.origin_visit_stats_upsert([visit_stats])
uneventful_visits = swh_scheduler.origin_visit_stats_get([(url, "git")])
expected_visit_stats = OriginVisitStats(
url=url,
visit_type="git",
last_successful=eventful_date,
last_visit=new_visit_date,
)
assert uneventful_visits == [expected_visit_stats]
def test_origin_visit_stats_upsert_with_snapshot(self, swh_scheduler) -> None:
eventful_date = utcnow()
url = "https://github.com/666/test"
visit_stats = OriginVisitStats(
url=url,
visit_type="git",
last_successful=eventful_date,
last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
)
swh_scheduler.origin_visit_stats_upsert([visit_stats])
assert swh_scheduler.origin_visit_stats_get([(url, "git")]) == [visit_stats]
assert swh_scheduler.origin_visit_stats_get([(url, "svn")]) == []
def test_origin_visit_stats_upsert_batch(self, swh_scheduler) -> None:
"""Batch upsert is ok"""
visit_stats = [
OriginVisitStats(
url="foo",
visit_type="git",
last_successful=utcnow(),
last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
),
OriginVisitStats(
url="bar",
visit_type="git",
last_visit=utcnow(),
last_snapshot=hash_to_bytes("fffcc0710eb6cf9efd5b920a8453e1e07157bfff"),
),
]
swh_scheduler.origin_visit_stats_upsert(visit_stats)
for visit_stat in swh_scheduler.origin_visit_stats_get(
[(vs.url, vs.visit_type) for vs in visit_stats]
):
assert visit_stat is not None
def test_origin_visit_stats_upsert_cardinality_failing(self, swh_scheduler) -> None:
- """Batch upsert does not support altering multiple times the same origin-visit-status
-
- """
+ """Batch upsert does not support altering multiple times the same origin-visit-status"""
with pytest.raises(SchedulerException, match="CardinalityViolation"):
swh_scheduler.origin_visit_stats_upsert(
[
OriginVisitStats(
url="foo",
visit_type="git",
last_successful=None,
last_visit=utcnow(),
),
OriginVisitStats(
url="foo",
visit_type="git",
last_successful=utcnow(),
last_visit=None,
),
]
)
def test_visit_scheduler_queue_position(
self, swh_scheduler, listed_origins
) -> None:
result = swh_scheduler.visit_scheduler_queue_position_get()
assert result == {}
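# Set a queue position for one origin per distinct visit type and expect
# them all back from the getter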
expected_result = {}
visit_types = set()
for origin in listed_origins:
visit_type = origin.visit_type
if visit_type in visit_types:
continue
visit_types.add(visit_type)
position = 42
swh_scheduler.visit_scheduler_queue_position_set(visit_type, position)
expected_result[visit_type] = position
result = swh_scheduler.visit_scheduler_queue_position_get()
assert result == expected_result
def test_metrics_origins_known(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
ret = swh_scheduler.update_metrics()
assert sum(metric.origins_known for metric in ret) == len(listed_origins)
def test_metrics_origins_enabled(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
disabled_origin = attr.evolve(listed_origins[0], enabled=False)
swh_scheduler.record_listed_origins([disabled_origin])
ret = swh_scheduler.update_metrics(lister_id=disabled_origin.lister_id)
for metric in ret:
if metric.visit_type == disabled_origin.visit_type:
# We disabled one of these origins
assert metric.origins_known - metric.origins_enabled == 1
else:
# But these are still all enabled
assert metric.origins_known == metric.origins_enabled
def test_metrics_origins_never_visited(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
# Pretend that we've recorded a visit on one origin
visited_origin = listed_origins[0]
swh_scheduler.origin_visit_stats_upsert(
[
OriginVisitStats(
url=visited_origin.url,
visit_type=visited_origin.visit_type,
last_successful=utcnow(),
last_snapshot=hash_to_bytes(
"d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"
),
),
]
)
ret = swh_scheduler.update_metrics(lister_id=visited_origin.lister_id)
for metric in ret:
if metric.visit_type == visited_origin.visit_type:
# We visited one of these origins
assert metric.origins_known - metric.origins_never_visited == 1
else:
# But none of these have been visited
assert metric.origins_known == metric.origins_never_visited
def test_metrics_origins_with_pending_changes(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
# Pretend that we've recorded a visit on one origin, in the past with
# respect to the "last update" time for the origin
visited_origin = listed_origins[0]
assert visited_origin.last_update is not None
swh_scheduler.origin_visit_stats_upsert(
[
OriginVisitStats(
url=visited_origin.url,
visit_type=visited_origin.visit_type,
last_successful=visited_origin.last_update - timedelta(days=1),
last_snapshot=hash_to_bytes(
"d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"
),
),
]
)
ret = swh_scheduler.update_metrics(lister_id=visited_origin.lister_id)
for metric in ret:
if metric.visit_type == visited_origin.visit_type:
# We visited one of these origins, in the past
assert metric.origins_with_pending_changes == 1
else:
# But none of these have been visited
assert metric.origins_with_pending_changes == 0
def test_update_metrics_explicit_lister(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
fake_uuid = uuid.uuid4()
assert all(fake_uuid != origin.lister_id for origin in listed_origins)
ret = swh_scheduler.update_metrics(lister_id=fake_uuid)
assert len(ret) == 0
def test_update_metrics_explicit_timestamp(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
ts = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
ret = swh_scheduler.update_metrics(timestamp=ts)
assert all(metric.last_update == ts for metric in ret)
def test_update_metrics_twice(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
ts = utcnow()
ret = swh_scheduler.update_metrics(timestamp=ts)
assert all(metric.last_update == ts for metric in ret)
second_ts = ts + timedelta(seconds=1)
ret = swh_scheduler.update_metrics(timestamp=second_ts)
assert all(metric.last_update == second_ts for metric in ret)
def test_get_metrics(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
updated = swh_scheduler.update_metrics()
retrieved = swh_scheduler.get_metrics()
assert_metrics_equal(updated, retrieved)
def test_get_metrics_by_lister(self, swh_scheduler, listed_origins):
lister_id = listed_origins[0].lister_id
assert lister_id is not None
swh_scheduler.record_listed_origins(listed_origins)
updated = swh_scheduler.update_metrics()
retrieved = swh_scheduler.get_metrics(lister_id=lister_id)
assert len(retrieved) > 0
assert_metrics_equal(
[metric for metric in updated if metric.lister_id == lister_id], retrieved
)
def test_get_metrics_by_visit_type(self, swh_scheduler, listed_origins):
visit_type = listed_origins[0].visit_type
assert visit_type is not None
swh_scheduler.record_listed_origins(listed_origins)
updated = swh_scheduler.update_metrics()
retrieved = swh_scheduler.get_metrics(visit_type=visit_type)
assert len(retrieved) > 0
assert_metrics_equal(
[metric for metric in updated if metric.visit_type == visit_type], retrieved
)
diff --git a/swh/scheduler/tests/test_server.py b/swh/scheduler/tests/test_server.py
index b5e1166..a678dd8 100644
--- a/swh/scheduler/tests/test_server.py
+++ b/swh/scheduler/tests/test_server.py
@@ -1,90 +1,100 @@
# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import pytest
import yaml
from swh.scheduler.api.server import load_and_check_config
def prepare_config_file(tmpdir, content, name="config.yml"):
"""Prepare configuration file in `$tmpdir/name` with content `content`.
Args:
tmpdir (LocalPath): root directory
content (str/dict): Content of the file either as string or as a dict.
If a dict, converts the dict into a yaml string.
name (str): configuration filename
Returns:
path (str) of the configuration file prepared.
"""
config_path = tmpdir / name
if isinstance(content, dict): # convert if needed
content = yaml.dump(content)
config_path.write_text(content, encoding="utf-8")
# pytest on python3.5 does not support LocalPath manipulation, so
# convert path to string
return str(config_path)
@pytest.mark.parametrize("scheduler_class", [None, ""])
def test_load_and_check_config_no_configuration(scheduler_class):
"""Inexistent configuration files raises"""
with pytest.raises(EnvironmentError, match="Configuration file must be defined"):
load_and_check_config(scheduler_class)
def test_load_and_check_config_inexistent_file():
"""A non-existent config filepath should raise"""
config_path = "/some/inexistent/config.yml"
expected_error = f"Configuration file {config_path} does not exist"
with pytest.raises(FileNotFoundError, match=expected_error):
load_and_check_config(config_path)
def test_load_and_check_config_wrong_configuration(tmpdir):
"""Wrong configuration raises"""
config_path = prepare_config_file(tmpdir, "something: useless")
with pytest.raises(KeyError, match="Missing '%scheduler' configuration"):
load_and_check_config(config_path)
def test_load_and_check_config_remote_config_local_type_raise(tmpdir):
"""Configuration without 'local' storage is rejected"""
config = {"scheduler": {"cls": "remote"}}
config_path = prepare_config_file(tmpdir, config)
expected_error = (
"The scheduler backend can only be started with a 'local' configuration"
)
with pytest.raises(ValueError, match=expected_error):
load_and_check_config(config_path, type="local")
def test_load_and_check_config_local_incomplete_configuration(tmpdir):
"""Incomplete 'local' configuration should raise"""
- config = {"scheduler": {"cls": "local", "something": "needed-for-test",}}
+ config = {
+ "scheduler": {
+ "cls": "local",
+ "something": "needed-for-test",
+ }
+ }
config_path = prepare_config_file(tmpdir, config)
expected_error = "Invalid configuration; missing 'db' config entry"
with pytest.raises(KeyError, match=expected_error):
load_and_check_config(config_path)
def test_load_and_check_config_local_config_fine(tmpdir):
"""Local configuration is fine"""
- config = {"scheduler": {"cls": "local", "db": "db",}}
+ config = {
+ "scheduler": {
+ "cls": "local",
+ "db": "db",
+ }
+ }
config_path = prepare_config_file(tmpdir, config)
cfg = load_and_check_config(config_path, type="local")
assert cfg == config
def test_load_and_check_config_remote_config_fine(tmpdir):
"""Remote configuration is fine"""
config = {"scheduler": {"cls": "remote"}}
config_path = prepare_config_file(tmpdir, config)
cfg = load_and_check_config(config_path, type="any")
assert cfg == config
diff --git a/swh/scheduler/tests/test_utils.py b/swh/scheduler/tests/test_utils.py
index 85aecc6..cba3036 100644
--- a/swh/scheduler/tests/test_utils.py
+++ b/swh/scheduler/tests/test_utils.py
@@ -1,79 +1,82 @@
# Copyright (C) 2017-2018 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import timezone
import unittest
from unittest.mock import patch
from swh.scheduler import utils
class UtilsTest(unittest.TestCase):
@patch("swh.scheduler.utils.datetime")
def test_create_oneshot_task_dict_simple(self, mock_datetime):
mock_datetime.now.return_value = "some-date"
actual_task = utils.create_oneshot_task_dict("some-task-type")
expected_task = {
"policy": "oneshot",
"type": "some-task-type",
"next_run": "some-date",
- "arguments": {"args": [], "kwargs": {},},
+ "arguments": {
+ "args": [],
+ "kwargs": {},
+ },
}
self.assertEqual(actual_task, expected_task)
mock_datetime.now.assert_called_once_with(tz=timezone.utc)
@patch("swh.scheduler.utils.datetime")
def test_create_oneshot_task_dict_other_call(self, mock_datetime):
mock_datetime.now.return_value = "some-other-date"
actual_task = utils.create_oneshot_task_dict(
"some-task-type", "arg0", "arg1", priority="high", other_stuff="normal"
)
expected_task = {
"policy": "oneshot",
"type": "some-task-type",
"next_run": "some-other-date",
"arguments": {
"args": ("arg0", "arg1"),
"kwargs": {"other_stuff": "normal"},
},
"priority": "high",
}
self.assertEqual(actual_task, expected_task)
mock_datetime.now.assert_called_once_with(tz=timezone.utc)
@patch("swh.scheduler.utils.datetime")
def test_create_task_dict(self, mock_datetime):
mock_datetime.now.return_value = "date"
actual_task = utils.create_task_dict(
"task-type",
"recurring",
"arg0",
"arg1",
priority="low",
other_stuff="normal",
retries_left=3,
)
expected_task = {
"policy": "recurring",
"type": "task-type",
"next_run": "date",
"arguments": {
"args": ("arg0", "arg1"),
"kwargs": {"other_stuff": "normal"},
},
"priority": "low",
"retries_left": 3,
}
self.assertEqual(actual_task, expected_task)
mock_datetime.now.assert_called_once_with(tz=timezone.utc)