diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py
index 21dd6ea..e18a209 100644
--- a/swh/scheduler/backend.py
+++ b/swh/scheduler/backend.py
@@ -1,990 +1,984 @@
# Copyright (C) 2015-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
import json
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import UUID
import attr
from psycopg2.errors import CardinalityViolation
import psycopg2.extras
import psycopg2.pool
from swh.core.db import BaseDb
from swh.core.db.common import db_transaction
from swh.scheduler.utils import utcnow
from .exc import SchedulerException, StaleData, UnknownPolicy
-from .model import (
- ListedOrigin,
- ListedOriginPageToken,
- Lister,
- OriginVisitStats,
- PaginatedListedOriginList,
- SchedulerMetrics,
-)
+from .interface import ListedOriginPageToken, PaginatedListedOriginList
+from .model import ListedOrigin, Lister, OriginVisitStats, SchedulerMetrics
logger = logging.getLogger(__name__)
psycopg2.extensions.register_adapter(dict, psycopg2.extras.Json)
psycopg2.extras.register_uuid()
def format_query(query, keys):
"""Format a query with the given keys"""
query_keys = ", ".join(keys)
placeholders = ", ".join(["%s"] * len(keys))
return query.format(keys=query_keys, placeholders=placeholders)
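# For illustration, given the helper above:
#   format_query("select {keys} from task_type where type=%s", ["type", "description"])
# returns "select type, description from task_type where type=%s", keeping the
# %s placeholder for psycopg2 parameter binding.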
class SchedulerBackend:
"""Backend for the Software Heritage scheduling database.
"""
def __init__(self, db, min_pool_conns=1, max_pool_conns=10):
"""
Args:
db: either a libpq connection string, or a psycopg2 connection
"""
if isinstance(db, psycopg2.extensions.connection):
self._pool = None
self._db = BaseDb(db)
else:
self._pool = psycopg2.pool.ThreadedConnectionPool(
min_pool_conns,
max_pool_conns,
db,
cursor_factory=psycopg2.extras.RealDictCursor,
)
self._db = None
def get_db(self):
if self._db:
return self._db
return BaseDb.from_pool(self._pool)
def put_db(self, db):
if db is not self._db:
db.put_conn()
task_type_keys = [
"type",
"description",
"backend_name",
"default_interval",
"min_interval",
"max_interval",
"backoff_factor",
"max_queue_length",
"num_retries",
"retry_delay",
]
@db_transaction()
def create_task_type(self, task_type, db=None, cur=None):
"""Create a new task type ready for scheduling.
Args:
task_type (dict): a dictionary with the following keys:
- type (str): an identifier for the task type
- description (str): a human-readable description of what the
task does
- backend_name (str): the name of the task in the
job-scheduling backend
- default_interval (datetime.timedelta): the default interval
between two task runs
- min_interval (datetime.timedelta): the minimum interval
between two task runs
- max_interval (datetime.timedelta): the maximum interval
between two task runs
- backoff_factor (float): the factor by which the interval
changes at each run
- max_queue_length (int): the maximum length of the task queue
for this task type
"""
keys = [key for key in self.task_type_keys if key in task_type]
query = format_query(
"""insert into task_type ({keys}) values ({placeholders})
on conflict do nothing""",
keys,
)
cur.execute(query, [task_type[key] for key in keys])
@db_transaction()
def get_task_type(self, task_type_name, db=None, cur=None):
"""Retrieve the task type with id task_type_name"""
query = format_query(
"select {keys} from task_type where type=%s", self.task_type_keys,
)
cur.execute(query, (task_type_name,))
return cur.fetchone()
@db_transaction()
def get_task_types(self, db=None, cur=None):
"""Retrieve all registered task types"""
query = format_query("select {keys} from task_type", self.task_type_keys,)
cur.execute(query)
return cur.fetchall()
@db_transaction()
def get_lister(
self, name: str, instance_name: Optional[str] = None, db=None, cur=None
) -> Optional[Lister]:
"""Retrieve information about the given instance of the lister from the
database.
"""
if instance_name is None:
instance_name = ""
select_cols = ", ".join(Lister.select_columns())
query = f"""
select {select_cols} from listers
where (name, instance_name) = (%s, %s)
"""
cur.execute(query, (name, instance_name))
ret = cur.fetchone()
if not ret:
return None
return Lister(**ret)
@db_transaction()
def get_or_create_lister(
self, name: str, instance_name: Optional[str] = None, db=None, cur=None
) -> Lister:
"""Retrieve information about the given instance of the lister from the
database, or create the entry if it did not exist.
"""
if instance_name is None:
instance_name = ""
select_cols = ", ".join(Lister.select_columns())
insert_cols, insert_meta = (
", ".join(tup) for tup in Lister.insert_columns_and_metavars()
)
query = f"""
with added as (
insert into listers ({insert_cols}) values ({insert_meta})
on conflict do nothing
returning {select_cols}
)
select {select_cols} from added
union all
select {select_cols} from listers
where (name, instance_name) = (%(name)s, %(instance_name)s);
"""
cur.execute(query, attr.asdict(Lister(name=name, instance_name=instance_name)))
return Lister(**cur.fetchone())
@db_transaction()
def update_lister(self, lister: Lister, db=None, cur=None) -> Lister:
"""Update the state for the given lister instance in the database.
Returns:
a new Lister object, with all fields updated from the database
Raises:
StaleData if the `updated` timestamp for the lister instance in the
database doesn't match the one passed by the user.
"""
select_cols = ", ".join(Lister.select_columns())
set_vars = ", ".join(
f"{col} = {meta}"
for col, meta in zip(*Lister.insert_columns_and_metavars())
)
query = f"""update listers
set {set_vars}
where id=%(id)s and updated=%(updated)s
returning {select_cols}"""
cur.execute(query, attr.asdict(lister))
updated = cur.fetchone()
if not updated:
raise StaleData("Stale data; Lister state not updated")
return Lister(**updated)
@db_transaction()
def record_listed_origins(
self, listed_origins: Iterable[ListedOrigin], db=None, cur=None
) -> List[ListedOrigin]:
"""Record a set of origins that a lister has listed.
This performs an "upsert": origins with the same (lister_id, url,
visit_type) values are updated with new values for
extra_loader_arguments, last_update and last_seen.
"""
pk_cols = ListedOrigin.primary_key_columns()
select_cols = ListedOrigin.select_columns()
insert_cols, insert_meta = ListedOrigin.insert_columns_and_metavars()
upsert_cols = [col for col in insert_cols if col not in pk_cols]
upsert_set = ", ".join(f"{col} = EXCLUDED.{col}" for col in upsert_cols)
query = f"""INSERT into listed_origins ({", ".join(insert_cols)})
VALUES %s
ON CONFLICT ({", ".join(pk_cols)}) DO UPDATE
SET {upsert_set}
RETURNING {", ".join(select_cols)}
"""
ret = psycopg2.extras.execute_values(
cur=cur,
sql=query,
argslist=(attr.asdict(origin) for origin in listed_origins),
template=f"({', '.join(insert_meta)})",
page_size=1000,
fetch=True,
)
return [ListedOrigin(**d) for d in ret]
@db_transaction()
def get_listed_origins(
self,
lister_id: Optional[UUID] = None,
url: Optional[str] = None,
limit: int = 1000,
page_token: Optional[ListedOriginPageToken] = None,
db=None,
cur=None,
) -> PaginatedListedOriginList:
"""Get information on the listed origins matching either the `url` or
`lister_id`, or both arguments.
"""
query_filters: List[str] = []
query_params: List[Union[int, str, UUID, Tuple[UUID, str]]] = []
if lister_id:
query_filters.append("lister_id = %s")
query_params.append(lister_id)
if url is not None:
query_filters.append("url = %s")
query_params.append(url)
if page_token is not None:
query_filters.append("(lister_id, url) > %s")
# the typeshed annotation for tuple() is too strict.
query_params.append(tuple(page_token)) # type: ignore
query_params.append(limit)
select_cols = ", ".join(ListedOrigin.select_columns())
if query_filters:
where_clause = "where %s" % (" and ".join(query_filters))
else:
where_clause = ""
query = f"""SELECT {select_cols}
from listed_origins
{where_clause}
ORDER BY lister_id, url
LIMIT %s"""
cur.execute(query, tuple(query_params))
origins = [ListedOrigin(**d) for d in cur]
if len(origins) == limit:
- page_token = (origins[-1].lister_id, origins[-1].url)
+ page_token = (str(origins[-1].lister_id), origins[-1].url)
else:
page_token = None
return PaginatedListedOriginList(origins, page_token)
@db_transaction()
def grab_next_visits(
self, visit_type: str, count: int, policy: str, db=None, cur=None,
) -> List[ListedOrigin]:
"""Get at most the `count` next origins that need to be visited with
the `visit_type` loader according to the given scheduling `policy`.
This will mark the origins as "being visited" in the listed_origins
table, to avoid scheduling multiple visits to the same origin.
"""
origin_select_cols = ", ".join(ListedOrigin.select_columns())
# TODO: filter on last_scheduled "too recent" to avoid always
# re-scheduling the same tasks.
where_clauses = [
"enabled", # "NOT enabled" = the lister said the origin no longer exists
"visit_type = %s",
]
if policy == "oldest_scheduled_first":
order_by = "origin_visit_stats.last_scheduled NULLS FIRST"
elif policy == "never_visited_oldest_update_first":
# never visited origins have a NULL last_snapshot
where_clauses.append("origin_visit_stats.last_snapshot IS NULL")
# order by increasing last_update (oldest first)
where_clauses.append("listed_origins.last_update IS NOT NULL")
order_by = "listed_origins.last_update"
elif policy == "already_visited_order_by_lag":
# TODO: store "visit lag" in a materialized view?
# visited origins have a NOT NULL last_snapshot
where_clauses.append("origin_visit_stats.last_snapshot IS NOT NULL")
# ignore origins we have visited after the known last update
where_clauses.append("listed_origins.last_update IS NOT NULL")
where_clauses.append(
"""
listed_origins.last_update
> GREATEST(
origin_visit_stats.last_eventful,
origin_visit_stats.last_uneventful
)
"""
)
# order by decreasing visit lag
order_by = """\
listed_origins.last_update
- GREATEST(
origin_visit_stats.last_eventful,
origin_visit_stats.last_uneventful
)
DESC
"""
else:
raise UnknownPolicy(f"Unknown scheduling policy {policy}")
select_query = f"""
SELECT
{origin_select_cols}
FROM
listed_origins
LEFT JOIN
origin_visit_stats USING (url, visit_type)
WHERE
{" AND ".join(where_clauses)}
ORDER BY
{order_by}
LIMIT %s
"""
query = f"""
WITH selected_origins AS (
{select_query}
),
update_stats AS (
INSERT INTO
origin_visit_stats (
url, visit_type, last_scheduled
)
SELECT
url, visit_type, now()
FROM
selected_origins
ON CONFLICT (url, visit_type) DO UPDATE
SET last_scheduled = GREATEST(
origin_visit_stats.last_scheduled,
EXCLUDED.last_scheduled
)
)
SELECT
*
FROM
selected_origins
"""
cur.execute(query, (visit_type, count))
return [ListedOrigin(**d) for d in cur]
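# Illustrative usage sketch, assuming a configured backend instance named
# `scheduler` (hypothetical); the three scheduling policies handled above could
# be exercised as follows:
#
#   for policy in ("oldest_scheduled_first",
#                  "never_visited_oldest_update_first",
#                  "already_visited_order_by_lag"):
#       origins = scheduler.grab_next_visits(visit_type="git", count=10, policy=policy)
#       # each call returns ListedOrigin objects and records last_scheduled for them
#
# Any other policy string raises UnknownPolicy.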
task_create_keys = [
"type",
"arguments",
"next_run",
"policy",
"status",
"retries_left",
"priority",
]
task_keys = task_create_keys + ["id", "current_interval"]
@db_transaction()
def create_tasks(self, tasks, policy="recurring", db=None, cur=None):
"""Create new tasks.
Args:
tasks (list): each task is a dictionary with the following keys:
- type (str): the task type
- arguments (dict): the arguments for the task runner, keys:
- args (list of str): arguments
- kwargs (dict str -> str): keyword arguments
- next_run (datetime.datetime): the next scheduled run for the
task
Returns:
a list of created tasks.
"""
cur.execute("select swh_scheduler_mktemp_task()")
db.copy_to(
tasks,
"tmp_task",
self.task_create_keys,
default_values={"policy": policy, "status": "next_run_not_scheduled"},
cur=cur,
)
query = format_query(
"select {keys} from swh_scheduler_create_tasks_from_temp()", self.task_keys,
)
cur.execute(query)
return cur.fetchall()
@db_transaction()
def set_status_tasks(
self, task_ids, status="disabled", next_run=None, db=None, cur=None
):
"""Set the tasks' status whose ids are listed.
If given, also set the next_run date.
"""
if not task_ids:
return
query = ["UPDATE task SET status = %s"]
args = [status]
if next_run:
query.append(", next_run = %s")
args.append(next_run)
query.append(" WHERE id IN %s")
args.append(tuple(task_ids))
cur.execute("".join(query), args)
@db_transaction()
def disable_tasks(self, task_ids, db=None, cur=None):
"""Disable the tasks whose ids are listed."""
return self.set_status_tasks(task_ids, db=db, cur=cur)
@db_transaction()
def search_tasks(
self,
task_id=None,
task_type=None,
status=None,
priority=None,
policy=None,
before=None,
after=None,
limit=None,
db=None,
cur=None,
):
"""Search tasks from selected criterions"""
where = []
args = []
if task_id:
if isinstance(task_id, (str, int)):
where.append("id = %s")
else:
where.append("id in %s")
task_id = tuple(task_id)
args.append(task_id)
if task_type:
if isinstance(task_type, str):
where.append("type = %s")
else:
where.append("type in %s")
task_type = tuple(task_type)
args.append(task_type)
if status:
if isinstance(status, str):
where.append("status = %s")
else:
where.append("status in %s")
status = tuple(status)
args.append(status)
if priority:
if isinstance(priority, str):
where.append("priority = %s")
else:
priority = tuple(priority)
where.append("priority in %s")
args.append(priority)
if policy:
where.append("policy = %s")
args.append(policy)
if before:
where.append("next_run <= %s")
args.append(before)
if after:
where.append("next_run >= %s")
args.append(after)
query = "select * from task"
if where:
query += " where " + " and ".join(where)
if limit:
query += " limit %s :: bigint"
args.append(limit)
cur.execute(query, args)
return cur.fetchall()
@db_transaction()
def get_tasks(self, task_ids, db=None, cur=None):
"""Retrieve the info of tasks whose ids are listed."""
query = format_query("select {keys} from task where id in %s", self.task_keys)
cur.execute(query, (tuple(task_ids),))
return cur.fetchall()
@db_transaction()
def peek_ready_tasks(
self,
task_type,
timestamp=None,
num_tasks=None,
num_tasks_priority=None,
db=None,
cur=None,
):
"""Fetch the list of ready tasks
Args:
task_type (str): filter tasks by their type
timestamp (datetime.datetime): peek tasks that need to be executed
before that timestamp
num_tasks (int): only peek at num_tasks tasks (with no priority)
num_tasks_priority (int): only peek at num_tasks_priority
tasks (with priority)
Returns:
a list of tasks
"""
if timestamp is None:
timestamp = utcnow()
cur.execute(
"""select * from swh_scheduler_peek_ready_tasks(
%s, %s, %s :: bigint, %s :: bigint)""",
(task_type, timestamp, num_tasks, num_tasks_priority),
)
logger.debug("PEEK %s => %s" % (task_type, cur.rowcount))
return cur.fetchall()
@db_transaction()
def grab_ready_tasks(
self,
task_type,
timestamp=None,
num_tasks=None,
num_tasks_priority=None,
db=None,
cur=None,
):
"""Fetch the list of ready tasks, and mark them as scheduled
Args:
task_type (str): filter tasks by their type
timestamp (datetime.datetime): grab tasks that need to be executed
before that timestamp
num_tasks (int): only grab num_tasks tasks (with no priority)
num_tasks_priority (int): only grab num_tasks_priority oneshot tasks
(with priority)
Returns:
a list of tasks
"""
if timestamp is None:
timestamp = utcnow()
cur.execute(
"""select * from swh_scheduler_grab_ready_tasks(
%s, %s, %s :: bigint, %s :: bigint)""",
(task_type, timestamp, num_tasks, num_tasks_priority),
)
logger.debug("GRAB %s => %s" % (task_type, cur.rowcount))
return cur.fetchall()
task_run_create_keys = ["task", "backend_id", "scheduled", "metadata"]
@db_transaction()
def schedule_task_run(
self, task_id, backend_id, metadata=None, timestamp=None, db=None, cur=None
):
"""Mark a given task as scheduled, adding a task_run entry in the database.
Args:
task_id (int): the identifier for the task being scheduled
backend_id (str): the identifier of the job in the backend
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
a fresh task_run entry
"""
if metadata is None:
metadata = {}
if timestamp is None:
timestamp = utcnow()
cur.execute(
"select * from swh_scheduler_schedule_task_run(%s, %s, %s, %s)",
(task_id, backend_id, metadata, timestamp),
)
return cur.fetchone()
@db_transaction()
def mass_schedule_task_runs(self, task_runs, db=None, cur=None):
"""Schedule a bunch of task runs.
Args:
task_runs (list): a list of dicts with keys:
- task (int): the identifier for the task being scheduled
- backend_id (str): the identifier of the job in the backend
- metadata (dict): metadata to add to the task_run entry
- scheduled (datetime.datetime): the instant the event occurred
Returns:
None
"""
cur.execute("select swh_scheduler_mktemp_task_run()")
db.copy_to(task_runs, "tmp_task_run", self.task_run_create_keys, cur=cur)
cur.execute("select swh_scheduler_schedule_task_run_from_temp()")
@db_transaction()
def start_task_run(
self, backend_id, metadata=None, timestamp=None, db=None, cur=None
):
"""Mark a given task as started, updating the corresponding task_run
entry in the database.
Args:
backend_id (str): the identifier of the job in the backend
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
the updated task_run entry
"""
if metadata is None:
metadata = {}
if timestamp is None:
timestamp = utcnow()
cur.execute(
"select * from swh_scheduler_start_task_run(%s, %s, %s)",
(backend_id, metadata, timestamp),
)
return cur.fetchone()
@db_transaction()
def end_task_run(
self,
backend_id,
status,
metadata=None,
timestamp=None,
result=None,
db=None,
cur=None,
):
"""Mark a given task as ended, updating the corresponding task_run entry in the
database.
Args:
backend_id (str): the identifier of the job in the backend
status (str): how the task ended; one of: 'eventful', 'uneventful',
'failed'
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
the updated task_run entry
"""
if metadata is None:
metadata = {}
if timestamp is None:
timestamp = utcnow()
cur.execute(
"select * from swh_scheduler_end_task_run(%s, %s, %s, %s)",
(backend_id, status, metadata, timestamp),
)
return cur.fetchone()
@db_transaction()
def filter_task_to_archive(
self,
after_ts: str,
before_ts: str,
limit: int = 10,
page_token: Optional[str] = None,
db=None,
cur=None,
) -> Dict[str, Any]:
"""Compute the tasks to archive within the datetime interval
[after_ts, before_ts[. The method returns a paginated result.
Returns:
dict with the following keys:
- **next_page_token**: opaque token to be used as
`page_token` to retrieve the next page of results. If absent,
there are no more pages to gather.
- **tasks**: list of task dictionaries with the following keys:
**id** (str): origin task id
**started** (Optional[datetime]): started date
**scheduled** (datetime): scheduled date
**arguments** (json dict): task's arguments
...
"""
assert not page_token or isinstance(page_token, str)
last_id = -1 if page_token is None else int(page_token)
tasks = []
cur.execute(
"select * from swh_scheduler_task_to_archive(%s, %s, %s, %s)",
(after_ts, before_ts, last_id, limit + 1),
)
for row in cur:
task = dict(row)
# nested type index does not accept bare values
# transform it as a dict to comply with this
task["arguments"]["args"] = {
i: v for i, v in enumerate(task["arguments"]["args"])
}
kwargs = task["arguments"]["kwargs"]
task["arguments"]["kwargs"] = json.dumps(kwargs)
tasks.append(task)
if len(tasks) >= limit + 1:  # more data remains, add pagination information
result = {
"tasks": tasks[:limit],
"next_page_token": str(tasks[-1]["task_id"]),
}
else:
result = {"tasks": tasks}
return result
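# Illustrative sketch of the pagination contract above: loop until
# next_page_token disappears. `scheduler`, `archive_page` and the
# timestamp/limit values below are hypothetical:
#
#   page_token = None
#   while True:
#       result = scheduler.filter_task_to_archive(
#           after_ts="2021-01-01", before_ts="2021-02-01",
#           limit=100, page_token=page_token,
#       )
#       archive_page(result["tasks"])              # process the current page
#       page_token = result.get("next_page_token")
#       if page_token is None:                     # no more pages to gather
#           break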
@db_transaction()
def delete_archived_tasks(self, task_ids, db=None, cur=None):
"""Delete archived tasks as much as possible. Only the task_ids whose
complete associated task_run have been cleaned up will be.
"""
_task_ids = _task_run_ids = []
for task_id in task_ids:
_task_ids.append(task_id["task_id"])
_task_run_ids.append(task_id["task_run_id"])
cur.execute(
"select * from swh_scheduler_delete_archived_tasks(%s, %s)",
(_task_ids, _task_run_ids),
)
task_run_keys = [
"id",
"task",
"backend_id",
"scheduled",
"started",
"ended",
"metadata",
"status",
]
@db_transaction()
def get_task_runs(self, task_ids, limit=None, db=None, cur=None):
"""Search task run for a task id"""
where = []
args = []
if task_ids:
if isinstance(task_ids, (str, int)):
where.append("task = %s")
else:
where.append("task in %s")
task_ids = tuple(task_ids)
args.append(task_ids)
else:
return ()
query = "select * from task_run where " + " and ".join(where)
if limit:
query += " limit %s :: bigint"
args.append(limit)
cur.execute(query, args)
return cur.fetchall()
@db_transaction()
def get_priority_ratios(self, db=None, cur=None):
cur.execute("select id, ratio from priority_ratio")
return {row["id"]: row["ratio"] for row in cur.fetchall()}
@db_transaction()
def origin_visit_stats_upsert(
self, origin_visit_stats: Iterable[OriginVisitStats], db=None, cur=None
) -> None:
pk_cols = OriginVisitStats.primary_key_columns()
insert_cols, insert_meta = OriginVisitStats.insert_columns_and_metavars()
query = f"""
INSERT into origin_visit_stats AS ovi ({", ".join(insert_cols)})
VALUES %s
ON CONFLICT ({", ".join(pk_cols)}) DO UPDATE
SET last_eventful = (
select max(eventful.date) from (values
(excluded.last_eventful),
(ovi.last_eventful)
) as eventful(date)
),
last_uneventful = (
select max(uneventful.date) from (values
(excluded.last_uneventful),
(ovi.last_uneventful)
) as uneventful(date)
),
last_failed = (
select max(failed.date) from (values
(excluded.last_failed),
(ovi.last_failed)
) as failed(date)
),
last_notfound = (
select max(notfound.date) from (values
(excluded.last_notfound),
(ovi.last_notfound)
) as notfound(date)
),
last_snapshot = (select
case
when ovi.last_eventful < excluded.last_eventful then excluded.last_snapshot
else coalesce(ovi.last_snapshot, excluded.last_snapshot)
end
)
""" # noqa
try:
psycopg2.extras.execute_values(
cur=cur,
sql=query,
argslist=(
attr.asdict(visit_stats) for visit_stats in origin_visit_stats
),
template=f"({', '.join(insert_meta)})",
page_size=1000,
fetch=False,
)
except CardinalityViolation as e:
raise SchedulerException(repr(e))
@db_transaction()
def origin_visit_stats_get(
self, ids: Iterable[Tuple[str, str]], db=None, cur=None
) -> List[OriginVisitStats]:
if not ids:
return []
primary_keys = tuple((origin, visit_type) for (origin, visit_type) in ids)
query = format_query(
"""
SELECT {keys}
FROM (VALUES %s) as stats(url, visit_type)
INNER JOIN origin_visit_stats USING (url, visit_type)
""",
OriginVisitStats.select_columns(),
)
psycopg2.extras.execute_values(cur=cur, sql=query, argslist=primary_keys)
return [OriginVisitStats(**row) for row in cur.fetchall()]
@db_transaction()
def update_metrics(
self,
lister_id: Optional[UUID] = None,
timestamp: Optional[datetime.datetime] = None,
db=None,
cur=None,
) -> List[SchedulerMetrics]:
"""Update the performance metrics of this scheduler instance.
Returns the updated metrics.
Args:
lister_id: if passed, update the metrics only for this lister instance
timestamp: if passed, the date at which we're updating the metrics,
defaults to the database NOW()
"""
query = format_query(
"SELECT {keys} FROM update_metrics(%s, %s)",
SchedulerMetrics.select_columns(),
)
cur.execute(query, (lister_id, timestamp))
return [SchedulerMetrics(**row) for row in cur.fetchall()]
@db_transaction()
def get_metrics(
self,
lister_id: Optional[UUID] = None,
visit_type: Optional[str] = None,
db=None,
cur=None,
) -> List[SchedulerMetrics]:
"""Retrieve the performance metrics of this scheduler instance.
Args:
lister_id: filter the metrics for this lister instance only
visit_type: filter the metrics for this visit type only
"""
where_filters = []
where_args = []
if lister_id:
where_filters.append("lister_id = %s")
where_args.append(str(lister_id))
if visit_type:
where_filters.append("visit_type = %s")
where_args.append(visit_type)
where_clause = ""
if where_filters:
where_clause = f"where {' and '.join(where_filters)}"
query = format_query(
"SELECT {keys} FROM scheduler_metrics %s" % where_clause,
SchedulerMetrics.select_columns(),
)
cur.execute(query, tuple(where_args))
return [SchedulerMetrics(**row) for row in cur.fetchall()]
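A minimal usage sketch for the backend changes above, assuming a reachable scheduler database; the connection string and lister name are made up, and the `results`/`next_page_token` attributes come from the PagedResult base class that the new PaginatedListedOriginList builds on. It shows the string-based page token being fed back for continuation:

    from swh.scheduler.backend import SchedulerBackend

    scheduler = SchedulerBackend(db="dbname=softwareheritage-scheduler-dev")  # hypothetical DSN
    lister = scheduler.get_or_create_lister(name="gitlab", instance_name="gitlab.example.org")

    page_token = None
    while True:
        page = scheduler.get_listed_origins(lister_id=lister.id, limit=1000, page_token=page_token)
        for origin in page.results:
            print(origin.url)
        page_token = page.next_page_token  # now a (str, str) tuple rather than (UUID, str)
        if page_token is None:
            break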
diff --git a/swh/scheduler/interface.py b/swh/scheduler/interface.py
index 289d95d..94c1841 100644
--- a/swh/scheduler/interface.py
+++ b/swh/scheduler/interface.py
@@ -1,388 +1,400 @@
# Copyright (C) 2015-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import UUID
from typing_extensions import Protocol, runtime_checkable
from swh.core.api import remote_api_endpoint
-from swh.scheduler.model import (
- ListedOrigin,
- ListedOriginPageToken,
- Lister,
- OriginVisitStats,
- PaginatedListedOriginList,
- SchedulerMetrics,
-)
+from swh.core.api.classes import PagedResult
+from swh.scheduler.model import ListedOrigin, Lister, OriginVisitStats, SchedulerMetrics
+
+ListedOriginPageToken = Tuple[str, str]
+
+
+class PaginatedListedOriginList(PagedResult[ListedOrigin, ListedOriginPageToken]):
+ """A list of listed origins, with a continuation token"""
+
+ def __init__(
+ self,
+ results: List[ListedOrigin],
+ next_page_token: Union[None, ListedOriginPageToken, List[str]],
+ ):
+ parsed_next_page_token: Optional[Tuple[str, str]] = None
+ if next_page_token is not None:
+ if len(next_page_token) != 2:
+ raise TypeError("Expected Tuple[str, str] or list of size 2.")
+ parsed_next_page_token = tuple(next_page_token) # type: ignore
+ super().__init__(results, parsed_next_page_token)
@runtime_checkable
class SchedulerInterface(Protocol):
@remote_api_endpoint("task_type/create")
def create_task_type(self, task_type):
"""Create a new task type ready for scheduling.
Args:
task_type (dict): a dictionary with the following keys:
- type (str): an identifier for the task type
- description (str): a human-readable description of what the
task does
- backend_name (str): the name of the task in the
job-scheduling backend
- default_interval (datetime.timedelta): the default interval
between two task runs
- min_interval (datetime.timedelta): the minimum interval
between two task runs
- max_interval (datetime.timedelta): the maximum interval
between two task runs
- backoff_factor (float): the factor by which the interval
changes at each run
- max_queue_length (int): the maximum length of the task queue
for this task type
"""
...
@remote_api_endpoint("task_type/get")
def get_task_type(self, task_type_name):
"""Retrieve the task type with id task_type_name"""
...
@remote_api_endpoint("task_type/get_all")
def get_task_types(self):
"""Retrieve all registered task types"""
...
@remote_api_endpoint("task/create")
def create_tasks(self, tasks, policy="recurring"):
"""Create new tasks.
Args:
tasks (list): each task is a dictionary with the following keys:
- type (str): the task type
- arguments (dict): the arguments for the task runner, keys:
- args (list of str): arguments
- kwargs (dict str -> str): keyword arguments
- next_run (datetime.datetime): the next scheduled run for the
task
Returns:
a list of created tasks.
"""
...
@remote_api_endpoint("task/set_status")
def set_status_tasks(self, task_ids, status="disabled", next_run=None):
"""Set the tasks' status whose ids are listed.
If given, also set the next_run date.
"""
...
@remote_api_endpoint("task/disable")
def disable_tasks(self, task_ids):
"""Disable the tasks whose ids are listed."""
...
@remote_api_endpoint("task/search")
def search_tasks(
self,
task_id=None,
task_type=None,
status=None,
priority=None,
policy=None,
before=None,
after=None,
limit=None,
):
"""Search tasks from selected criterions"""
...
@remote_api_endpoint("task/get")
def get_tasks(self, task_ids):
"""Retrieve the info of tasks whose ids are listed."""
...
@remote_api_endpoint("task/peek_ready")
def peek_ready_tasks(
self, task_type, timestamp=None, num_tasks=None, num_tasks_priority=None,
):
"""Fetch the list of ready tasks
Args:
task_type (str): filter tasks by their type
timestamp (datetime.datetime): peek tasks that need to be executed
before that timestamp
num_tasks (int): only peek at num_tasks tasks (with no priority)
num_tasks_priority (int): only peek at num_tasks_priority
tasks (with priority)
Returns:
a list of tasks
"""
...
@remote_api_endpoint("task/grab_ready")
def grab_ready_tasks(
self, task_type, timestamp=None, num_tasks=None, num_tasks_priority=None,
):
"""Fetch the list of ready tasks, and mark them as scheduled
Args:
task_type (str): filter tasks by their type
timestamp (datetime.datetime): grab tasks that need to be executed
before that timestamp
num_tasks (int): only grab num_tasks tasks (with no priority)
num_tasks_priority (int): only grab num_tasks_priority oneshot tasks
(with priority)
Returns:
a list of tasks
"""
...
@remote_api_endpoint("task_run/schedule_one")
def schedule_task_run(self, task_id, backend_id, metadata=None, timestamp=None):
"""Mark a given task as scheduled, adding a task_run entry in the database.
Args:
task_id (int): the identifier for the task being scheduled
backend_id (str): the identifier of the job in the backend
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
a fresh task_run entry
"""
...
@remote_api_endpoint("task_run/schedule")
def mass_schedule_task_runs(self, task_runs):
"""Schedule a bunch of task runs.
Args:
task_runs (list): a list of dicts with keys:
- task (int): the identifier for the task being scheduled
- backend_id (str): the identifier of the job in the backend
- metadata (dict): metadata to add to the task_run entry
- scheduled (datetime.datetime): the instant the event occurred
Returns:
None
"""
...
@remote_api_endpoint("task_run/start")
def start_task_run(self, backend_id, metadata=None, timestamp=None):
"""Mark a given task as started, updating the corresponding task_run
entry in the database.
Args:
backend_id (str): the identifier of the job in the backend
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
the updated task_run entry
"""
...
@remote_api_endpoint("task_run/end")
def end_task_run(
self, backend_id, status, metadata=None, timestamp=None, result=None,
):
"""Mark a given task as ended, updating the corresponding task_run entry in the
database.
Args:
backend_id (str): the identifier of the job in the backend
status (str): how the task ended; one of: 'eventful', 'uneventful',
'failed'
metadata (dict): metadata to add to the task_run entry
timestamp (datetime.datetime): the instant the event occurred
Returns:
the updated task_run entry
"""
...
@remote_api_endpoint("task/filter_for_archive")
def filter_task_to_archive(
self,
after_ts: str,
before_ts: str,
limit: int = 10,
page_token: Optional[str] = None,
) -> Dict[str, Any]:
"""Compute the tasks to archive within the datetime interval
[after_ts, before_ts[. The method returns a paginated result.
Returns:
dict with the following keys:
- **next_page_token**: opaque token to be used as
`page_token` to retrieve the next page of results. If absent,
there are no more pages to gather.
- **tasks**: list of task dictionaries with the following keys:
**id** (str): origin task id
**started** (Optional[datetime]): started date
**scheduled** (datetime): scheduled date
**arguments** (json dict): task's arguments
...
"""
...
@remote_api_endpoint("task/delete_archived")
def delete_archived_tasks(self, task_ids):
"""Delete archived tasks as much as possible. Only the task_ids whose
complete associated task_run have been cleaned up will be.
"""
...
@remote_api_endpoint("task_run/get")
def get_task_runs(self, task_ids, limit=None):
"""Search task run for a task id"""
...
@remote_api_endpoint("lister/get")
def get_lister(
self, name: str, instance_name: Optional[str] = None
) -> Optional[Lister]:
"""Retrieve information about the given instance of the lister from the
database.
"""
...
@remote_api_endpoint("lister/get_or_create")
def get_or_create_lister(
self, name: str, instance_name: Optional[str] = None
) -> Lister:
"""Retrieve information about the given instance of the lister from the
database, or create the entry if it did not exist.
"""
...
@remote_api_endpoint("lister/update")
def update_lister(self, lister: Lister) -> Lister:
"""Update the state for the given lister instance in the database.
Returns:
a new Lister object, with all fields updated from the database
Raises:
StaleData if the `updated` timestamp for the lister instance in the
database doesn't match the one passed by the user.
"""
...
@remote_api_endpoint("origins/record")
def record_listed_origins(
self, listed_origins: Iterable[ListedOrigin]
) -> List[ListedOrigin]:
"""Record a set of origins that a lister has listed.
This performs an "upsert": origins with the same (lister_id, url,
visit_type) values are updated with new values for
extra_loader_arguments, last_update and last_seen.
"""
...
@remote_api_endpoint("origins/get")
def get_listed_origins(
self,
lister_id: Optional[UUID] = None,
url: Optional[str] = None,
limit: int = 1000,
page_token: Optional[ListedOriginPageToken] = None,
) -> PaginatedListedOriginList:
"""Get information on the listed origins matching either the `url` or
`lister_id`, or both arguments.
Use the `limit` and `page_token` arguments for continuation. The next
page token, if any, is returned in the PaginatedListedOriginList object.
"""
...
@remote_api_endpoint("origins/grab_next")
def grab_next_visits(
self, visit_type: str, count: int, policy: str
) -> List[ListedOrigin]:
"""Get at most the `count` next origins that need to be visited with
the `visit_type` loader according to the given scheduling `policy`.
This will mark the origins as "being visited" in the listed_origins
table, to avoid scheduling multiple visits to the same origin.
"""
...
@remote_api_endpoint("priority_ratios/get")
def get_priority_ratios(self):
...
@remote_api_endpoint("visit_stats/upsert")
def origin_visit_stats_upsert(
self, origin_visit_stats: Iterable[OriginVisitStats]
) -> None:
"""Create a new origin visit stats
"""
...
@remote_api_endpoint("visit_stats/get")
def origin_visit_stats_get(
self, ids: Iterable[Tuple[str, str]]
) -> List[OriginVisitStats]:
"""Retrieve the stats for an origin with a given visit type
If some visit_stats are not found, they are filtered out of the result, so the
output list may be shorter than the input list.
"""
...
@remote_api_endpoint("scheduler_metrics/update")
def update_metrics(
self,
lister_id: Optional[UUID] = None,
timestamp: Optional[datetime.datetime] = None,
) -> List[SchedulerMetrics]:
"""Update the performance metrics of this scheduler instance.
Returns the updated metrics.
Args:
lister_id: if passed, update the metrics only for this lister instance
timestamp: if passed, the date at which we're updating the metrics,
defaults to the database NOW()
"""
...
@remote_api_endpoint("scheduler_metrics/get")
def get_metrics(
self, lister_id: Optional[UUID] = None, visit_type: Optional[str] = None
) -> List[SchedulerMetrics]:
"""Retrieve the performance metrics of this scheduler instance.
Args:
lister_id: filter the metrics for this lister instance only
visit_type: filter the metrics for this visit type only
"""
...
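A short sketch of the relocated pagination types defined at the top of interface.py; the UUID string and URL below are made-up values. The constructor accepts either a tuple or a 2-element list, which matters when the token round-trips through JSON in the RPC layer:

    from swh.scheduler.interface import ListedOriginPageToken, PaginatedListedOriginList

    token: ListedOriginPageToken = (
        "4051aeff-1e53-4228-8a8e-f5cdd0e26b34",  # lister id, serialized as a string
        "https://example.org/repo.git",
    )

    # a 2-element list (as produced by JSON decoding) is normalized back to a tuple
    page = PaginatedListedOriginList(results=[], next_page_token=list(token))
    assert page.next_page_token == token

    # anything that is not of size 2 is rejected
    try:
        PaginatedListedOriginList(results=[], next_page_token=["only-one-element"])
    except TypeError:
        pass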
diff --git a/swh/scheduler/model.py b/swh/scheduler/model.py
index 1889b23..c14781f 100644
--- a/swh/scheduler/model.py
+++ b/swh/scheduler/model.py
@@ -1,285 +1,254 @@
# Copyright (C) 2020-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
from uuid import UUID
import attr
import attr.converters
from attrs_strict import type_validator
def check_timestamptz(value) -> None:
"""Checks the date has a timezone."""
if value is not None and value.tzinfo is None:
raise ValueError("date must be a timezone-aware datetime.")
@attr.s
class BaseSchedulerModel:
"""Base class for database-backed objects.
These database-backed objects are defined through attrs-based attributes
that match the columns of the database 1:1. This is a (very) lightweight
ORM.
These attrs-based attributes have metadata specific to the functionality
expected from these fields in the database:
- `primary_key`: the column is a primary key; it should be filtered out
when doing an `update` of the object
- `auto_primary_key`: the column is a primary key, which is automatically handled
by the database. No value is inserted for it. This must be matched with a
database-side default value.
- `auto_now_add`: the column is a timestamp that is set to the current time when
the object is inserted, and never updated afterwards. This must be matched with
a database-side default value.
- `auto_now`: the column is a timestamp that is set to the current time when
the object is inserted or updated.
"""
_pk_cols: Optional[Tuple[str, ...]] = None
_select_cols: Optional[Tuple[str, ...]] = None
_insert_cols_and_metavars: Optional[Tuple[Tuple[str, ...], Tuple[str, ...]]] = None
@classmethod
def primary_key_columns(cls) -> Tuple[str, ...]:
"""Get the primary key columns for this object type"""
if cls._pk_cols is None:
columns: List[str] = []
for field in attr.fields(cls):
if any(
field.metadata.get(flag)
for flag in ("auto_primary_key", "primary_key")
):
columns.append(field.name)
cls._pk_cols = tuple(sorted(columns))
return cls._pk_cols
@classmethod
def select_columns(cls) -> Tuple[str, ...]:
"""Get all the database columns needed for a `select` on this object type"""
if cls._select_cols is None:
columns: List[str] = []
for field in attr.fields(cls):
columns.append(field.name)
cls._select_cols = tuple(sorted(columns))
return cls._select_cols
@classmethod
def insert_columns_and_metavars(cls) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
"""Get the database columns and metavars needed for an `insert` or `update` on
this object type.
This implements support for the `auto_*` field metadata attributes.
"""
if cls._insert_cols_and_metavars is None:
zipped_cols_and_metavars: List[Tuple[str, str]] = []
for field in attr.fields(cls):
if any(
field.metadata.get(flag)
for flag in ("auto_now_add", "auto_primary_key")
):
continue
elif field.metadata.get("auto_now"):
zipped_cols_and_metavars.append((field.name, "now()"))
else:
zipped_cols_and_metavars.append((field.name, f"%({field.name})s"))
zipped_cols_and_metavars.sort()
cols, metavars = zip(*zipped_cols_and_metavars)
cls._insert_cols_and_metavars = cols, metavars
return cls._insert_cols_and_metavars
@attr.s
class Lister(BaseSchedulerModel):
name = attr.ib(type=str, validator=[type_validator()])
instance_name = attr.ib(type=str, validator=[type_validator()])
# Populated by database
id = attr.ib(
type=Optional[UUID],
validator=type_validator(),
default=None,
metadata={"auto_primary_key": True},
)
current_state = attr.ib(
type=Dict[str, Any], validator=[type_validator()], factory=dict
)
created = attr.ib(
type=Optional[datetime.datetime],
validator=[type_validator()],
default=None,
metadata={"auto_now_add": True},
)
updated = attr.ib(
type=Optional[datetime.datetime],
validator=[type_validator()],
default=None,
metadata={"auto_now": True},
)
@attr.s
class ListedOrigin(BaseSchedulerModel):
"""Basic information about a listed origin, output by a lister"""
lister_id = attr.ib(
type=UUID, validator=[type_validator()], metadata={"primary_key": True}
)
url = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
visit_type = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
extra_loader_arguments = attr.ib(
type=Dict[str, str], validator=[type_validator()], factory=dict
)
last_update = attr.ib(
type=Optional[datetime.datetime], validator=[type_validator()], default=None,
)
enabled = attr.ib(type=bool, validator=[type_validator()], default=True)
first_seen = attr.ib(
type=Optional[datetime.datetime],
validator=[type_validator()],
default=None,
metadata={"auto_now_add": True},
)
last_seen = attr.ib(
type=Optional[datetime.datetime],
validator=[type_validator()],
default=None,
metadata={"auto_now": True},
)
def as_task_dict(self):
return {
"type": f"load-{self.visit_type}",
"arguments": {
"args": [],
"kwargs": {"url": self.url, **self.extra_loader_arguments},
},
}
-ListedOriginPageToken = Tuple[UUID, str]
-
-
-def convert_listed_origin_page_token(
- input: Union[None, ListedOriginPageToken, List[Union[UUID, str]]]
-) -> Optional[ListedOriginPageToken]:
- if input is None:
- return None
-
- if isinstance(input, tuple):
- return input
-
- x, y = input
- assert isinstance(x, UUID)
- assert isinstance(y, str)
- return (x, y)
-
-
-@attr.s
-class PaginatedListedOriginList(BaseSchedulerModel):
- """A list of listed origins, with a continuation token"""
-
- origins = attr.ib(type=List[ListedOrigin], validator=[type_validator()])
- next_page_token = attr.ib(
- type=Optional[ListedOriginPageToken],
- validator=[type_validator()],
- converter=convert_listed_origin_page_token,
- default=None,
- )
-
-
@attr.s(frozen=True, slots=True)
class OriginVisitStats(BaseSchedulerModel):
"""Represents an aggregated origin visits view.
"""
url = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
visit_type = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
last_eventful = attr.ib(
type=Optional[datetime.datetime], validator=type_validator()
)
last_uneventful = attr.ib(
type=Optional[datetime.datetime], validator=type_validator()
)
last_failed = attr.ib(type=Optional[datetime.datetime], validator=type_validator())
last_notfound = attr.ib(
type=Optional[datetime.datetime], validator=type_validator()
)
last_scheduled = attr.ib(
type=Optional[datetime.datetime], validator=[type_validator()], default=None,
)
last_snapshot = attr.ib(
type=Optional[bytes], validator=type_validator(), default=None
)
@last_eventful.validator
def check_last_eventful(self, attribute, value):
check_timestamptz(value)
@last_uneventful.validator
def check_last_uneventful(self, attribute, value):
check_timestamptz(value)
@last_failed.validator
def check_last_failed(self, attribute, value):
check_timestamptz(value)
@last_notfound.validator
def check_last_notfound(self, attribute, value):
check_timestamptz(value)
@attr.s(frozen=True, slots=True)
class SchedulerMetrics(BaseSchedulerModel):
"""Metrics for the scheduler, aggregated by (lister_id, visit_type)"""
lister_id = attr.ib(
type=UUID, validator=[type_validator()], metadata={"primary_key": True}
)
visit_type = attr.ib(
type=str, validator=[type_validator()], metadata={"primary_key": True}
)
last_update = attr.ib(
type=Optional[datetime.datetime], validator=[type_validator()], default=None,
)
origins_known = attr.ib(type=int, validator=[type_validator()], default=0)
"""Number of known (enabled or disabled) origins"""
origins_enabled = attr.ib(type=int, validator=[type_validator()], default=0)
"""Number of origins that were present in the latest listings"""
origins_never_visited = attr.ib(type=int, validator=[type_validator()], default=0)
"""Number of enabled origins that have never been visited
(according to the visit cache)"""
origins_with_pending_changes = attr.ib(
type=int, validator=[type_validator()], default=0
)
"""Number of enabled origins with known activity (recorded by a lister)
since our last visit"""
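A brief sketch of how the attrs-based model classes above expose their column metadata; the lister id below is a random made-up value:

    import uuid

    from swh.scheduler.model import ListedOrigin

    origin = ListedOrigin(
        lister_id=uuid.uuid4(),
        url="https://example.org/repo.git",
        visit_type="git",
    )

    # primary key and insert columns are derived from the attrs field metadata
    assert ListedOrigin.primary_key_columns() == ("lister_id", "url", "visit_type")
    insert_cols, metavars = ListedOrigin.insert_columns_and_metavars()
    assert "first_seen" not in insert_cols  # auto_now_add: skipped on insert
    assert dict(zip(insert_cols, metavars))["last_seen"] == "now()"  # auto_now: set by the database

    # as_task_dict() builds the task representation for this origin
    print(origin.as_task_dict())
    # {'type': 'load-git', 'arguments': {'args': [], 'kwargs': {'url': 'https://example.org/repo.git'}}}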
diff --git a/swh/scheduler/tests/test_scheduler.py b/swh/scheduler/tests/test_scheduler.py
index a46f9a2..1ecd7fc 100644
--- a/swh/scheduler/tests/test_scheduler.py
+++ b/swh/scheduler/tests/test_scheduler.py
@@ -1,1281 +1,1276 @@
# Copyright (C) 2017-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from collections import defaultdict
import copy
import datetime
import inspect
import random
from typing import Any, Dict, List, Optional, Tuple
import uuid
import attr
import pytest
from swh.model.hashutil import hash_to_bytes
from swh.scheduler.exc import SchedulerException, StaleData, UnknownPolicy
-from swh.scheduler.interface import SchedulerInterface
-from swh.scheduler.model import (
- ListedOrigin,
- ListedOriginPageToken,
- OriginVisitStats,
- SchedulerMetrics,
-)
+from swh.scheduler.interface import ListedOriginPageToken, SchedulerInterface
+from swh.scheduler.model import ListedOrigin, OriginVisitStats, SchedulerMetrics
from swh.scheduler.utils import utcnow
from .common import LISTERS, TASK_TYPES, TEMPLATES, tasks_from_template
ONEDAY = datetime.timedelta(days=1)
def subdict(d, keys=None, excl=()):
if keys is None:
keys = [k for k in d.keys()]
return {k: d[k] for k in keys if k not in excl}
def metrics_sort_key(m: SchedulerMetrics) -> Tuple[uuid.UUID, str]:
return (m.lister_id, m.visit_type)
def assert_metrics_equal(left, right):
assert sorted(left, key=metrics_sort_key) == sorted(right, key=metrics_sort_key)
class TestScheduler:
def test_interface(self, swh_scheduler):
"""Checks all methods of SchedulerInterface are implemented by this
backend, and that they have the same signature."""
# Create an instance of the protocol (which cannot be instantiated
# directly, so this creates a subclass, then instantiates it)
interface = type("_", (SchedulerInterface,), {})()
assert "create_task_type" in dir(interface)
missing_methods = []
for meth_name in dir(interface):
if meth_name.startswith("_"):
continue
interface_meth = getattr(interface, meth_name)
try:
concrete_meth = getattr(swh_scheduler, meth_name)
except AttributeError:
if not getattr(interface_meth, "deprecated_endpoint", False):
# The backend is missing a (non-deprecated) endpoint
missing_methods.append(meth_name)
continue
expected_signature = inspect.signature(interface_meth)
actual_signature = inspect.signature(concrete_meth)
assert expected_signature == actual_signature, meth_name
assert missing_methods == []
def test_get_priority_ratios(self, swh_scheduler):
assert swh_scheduler.get_priority_ratios() == {
"high": 0.5,
"normal": 0.3,
"low": 0.2,
}
def test_add_task_type(self, swh_scheduler):
tt = TASK_TYPES["git"]
swh_scheduler.create_task_type(tt)
assert tt == swh_scheduler.get_task_type(tt["type"])
tt2 = TASK_TYPES["hg"]
swh_scheduler.create_task_type(tt2)
assert tt == swh_scheduler.get_task_type(tt["type"])
assert tt2 == swh_scheduler.get_task_type(tt2["type"])
def test_create_task_type_idempotence(self, swh_scheduler):
tt = TASK_TYPES["git"]
swh_scheduler.create_task_type(tt)
swh_scheduler.create_task_type(tt)
assert tt == swh_scheduler.get_task_type(tt["type"])
def test_get_task_types(self, swh_scheduler):
tt, tt2 = TASK_TYPES["git"], TASK_TYPES["hg"]
swh_scheduler.create_task_type(tt)
swh_scheduler.create_task_type(tt2)
actual_task_types = swh_scheduler.get_task_types()
assert tt in actual_task_types
assert tt2 in actual_task_types
def test_create_tasks(self, swh_scheduler):
priority_ratio = self._priority_ratio(swh_scheduler)
self._create_task_types(swh_scheduler)
num_tasks_priority = 100
tasks_1 = tasks_from_template(TEMPLATES["git"], utcnow(), 100)
tasks_2 = tasks_from_template(
TEMPLATES["hg"],
utcnow(),
100,
num_tasks_priority,
priorities=priority_ratio,
)
tasks = tasks_1 + tasks_2
# tasks are returned only once with their ids
ret1 = swh_scheduler.create_tasks(tasks + tasks_1 + tasks_2)
set_ret1 = set([t["id"] for t in ret1])
# creating the same set results in the same ids
ret = swh_scheduler.create_tasks(tasks)
set_ret = set([t["id"] for t in ret])
# Idempotence results
assert set_ret == set_ret1
assert len(ret) == len(ret1)
ids = set()
actual_priorities = defaultdict(int)
for task, orig_task in zip(ret, tasks):
task = copy.deepcopy(task)
task_type = TASK_TYPES[orig_task["type"].split("-")[-1]]
assert task["id"] not in ids
assert task["status"] == "next_run_not_scheduled"
assert task["current_interval"] == task_type["default_interval"]
assert task["policy"] == orig_task.get("policy", "recurring")
priority = task.get("priority")
if priority:
actual_priorities[priority] += 1
assert task["retries_left"] == (task_type["num_retries"] or 0)
ids.add(task["id"])
del task["id"]
del task["status"]
del task["current_interval"]
del task["retries_left"]
if "policy" not in orig_task:
del task["policy"]
if "priority" not in orig_task:
del task["priority"]
assert task == orig_task
assert dict(actual_priorities) == {
priority: int(ratio * num_tasks_priority)
for priority, ratio in priority_ratio.items()
}
def test_peek_ready_tasks_no_priority(self, swh_scheduler):
self._create_task_types(swh_scheduler)
t = utcnow()
task_type = TEMPLATES["git"]["type"]
tasks = tasks_from_template(TEMPLATES["git"], t, 100)
random.shuffle(tasks)
swh_scheduler.create_tasks(tasks)
ready_tasks = swh_scheduler.peek_ready_tasks(task_type)
assert len(ready_tasks) == len(tasks)
for i in range(len(ready_tasks) - 1):
assert ready_tasks[i]["next_run"] <= ready_tasks[i + 1]["next_run"]
# Only get the first few ready tasks
limit = random.randrange(5, 5 + len(tasks) // 2)
ready_tasks_limited = swh_scheduler.peek_ready_tasks(task_type, num_tasks=limit)
assert len(ready_tasks_limited) == limit
assert ready_tasks_limited == ready_tasks[:limit]
# Limit by timestamp
max_ts = tasks[limit - 1]["next_run"]
ready_tasks_timestamped = swh_scheduler.peek_ready_tasks(
task_type, timestamp=max_ts
)
for ready_task in ready_tasks_timestamped:
assert ready_task["next_run"] <= max_ts
# Make sure we get proper behavior for the first ready tasks
assert ready_tasks[: len(ready_tasks_timestamped)] == ready_tasks_timestamped
# Limit by both
ready_tasks_both = swh_scheduler.peek_ready_tasks(
task_type, timestamp=max_ts, num_tasks=limit // 3
)
assert len(ready_tasks_both) <= limit // 3
for ready_task in ready_tasks_both:
assert ready_task["next_run"] <= max_ts
assert ready_task in ready_tasks[: limit // 3]
def _priority_ratio(self, swh_scheduler):
return swh_scheduler.get_priority_ratios()
def test_peek_ready_tasks_mixed_priorities(self, swh_scheduler):
priority_ratio = self._priority_ratio(swh_scheduler)
self._create_task_types(swh_scheduler)
t = utcnow()
task_type = TEMPLATES["git"]["type"]
num_tasks_priority = 100
num_tasks_no_priority = 100
# Create tasks with and without priorities
tasks = tasks_from_template(
TEMPLATES["git"],
t,
num=num_tasks_no_priority,
num_priority=num_tasks_priority,
priorities=priority_ratio,
)
random.shuffle(tasks)
swh_scheduler.create_tasks(tasks)
# take all available tasks
ready_tasks = swh_scheduler.peek_ready_tasks(task_type)
assert len(ready_tasks) == len(tasks)
assert num_tasks_priority + num_tasks_no_priority == len(ready_tasks)
count_tasks_per_priority = defaultdict(int)
for task in ready_tasks:
priority = task.get("priority")
if priority:
count_tasks_per_priority[priority] += 1
assert dict(count_tasks_per_priority) == {
priority: int(ratio * num_tasks_priority)
for priority, ratio in priority_ratio.items()
}
# Only get some ready tasks
num_tasks = random.randrange(5, 5 + num_tasks_no_priority // 2)
num_tasks_priority = random.randrange(5, num_tasks_priority // 2)
ready_tasks_limited = swh_scheduler.peek_ready_tasks(
task_type, num_tasks=num_tasks, num_tasks_priority=num_tasks_priority
)
count_tasks_per_priority = defaultdict(int)
for task in ready_tasks_limited:
priority = task.get("priority")
count_tasks_per_priority[priority] += 1
import math
for priority, ratio in priority_ratio.items():
expected_count = math.ceil(ratio * num_tasks_priority)
actual_prio = count_tasks_per_priority[priority]
assert actual_prio == expected_count or actual_prio == expected_count + 1
assert count_tasks_per_priority[None] == num_tasks
def test_grab_ready_tasks(self, swh_scheduler):
priority_ratio = self._priority_ratio(swh_scheduler)
self._create_task_types(swh_scheduler)
t = utcnow()
task_type = TEMPLATES["git"]["type"]
num_tasks_priority = 100
num_tasks_no_priority = 100
# Create tasks with and without priorities
tasks = tasks_from_template(
TEMPLATES["git"],
t,
num=num_tasks_no_priority,
num_priority=num_tasks_priority,
priorities=priority_ratio,
)
random.shuffle(tasks)
swh_scheduler.create_tasks(tasks)
first_ready_tasks = swh_scheduler.peek_ready_tasks(
task_type, num_tasks=10, num_tasks_priority=10
)
grabbed_tasks = swh_scheduler.grab_ready_tasks(
task_type, num_tasks=10, num_tasks_priority=10
)
for peeked, grabbed in zip(first_ready_tasks, grabbed_tasks):
assert peeked["status"] == "next_run_not_scheduled"
del peeked["status"]
assert grabbed["status"] == "next_run_scheduled"
del grabbed["status"]
assert peeked == grabbed
assert peeked["priority"] == grabbed["priority"]
def test_get_tasks(self, swh_scheduler):
self._create_task_types(swh_scheduler)
t = utcnow()
tasks = tasks_from_template(TEMPLATES["git"], t, 100)
tasks = swh_scheduler.create_tasks(tasks)
random.shuffle(tasks)
while len(tasks) > 1:
length = random.randrange(1, len(tasks))
cur_tasks = sorted(tasks[:length], key=lambda x: x["id"])
tasks[:length] = []
ret = swh_scheduler.get_tasks(task["id"] for task in cur_tasks)
# result is not guaranteed to be sorted
ret.sort(key=lambda x: x["id"])
assert ret == cur_tasks
def test_search_tasks(self, swh_scheduler):
def make_real_dicts(lst):
"""RealDictRow is not a real dict."""
return [dict(d.items()) for d in lst]
self._create_task_types(swh_scheduler)
t = utcnow()
tasks = tasks_from_template(TEMPLATES["git"], t, 100)
tasks = swh_scheduler.create_tasks(tasks)
assert make_real_dicts(swh_scheduler.search_tasks()) == make_real_dicts(tasks)
def assert_filtered_task_ok(
self, task: Dict[str, Any], after: datetime.datetime, before: datetime.datetime
) -> None:
"""Ensure filtered tasks have the right expected properties
(within the range, recurring disabled, etc.)
"""
started = task["started"]
date = started if started is not None else task["scheduled"]
assert after <= date and date <= before
if task["task_policy"] == "oneshot":
assert task["task_status"] in ["completed", "disabled"]
if task["task_policy"] == "recurring":
assert task["task_status"] in ["disabled"]
def test_filter_task_to_archive(self, swh_scheduler):
"""Filtering only list disabled recurring or completed oneshot tasks
"""
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12)
total_tasks = len(recurring) + len(oneshots)
# simulate scheduling tasks
pending_tasks = swh_scheduler.create_tasks(recurring + oneshots)
backend_tasks = [
{
"task": task["id"],
"backend_id": str(uuid.uuid4()),
"scheduled": utcnow(),
}
for task in pending_tasks
]
swh_scheduler.mass_schedule_task_runs(backend_tasks)
# we simulate that the tasks are being done
_tasks = []
for task in backend_tasks:
t = swh_scheduler.end_task_run(task["backend_id"], status="eventful")
_tasks.append(t)
# Randomly update task's status per policy
status_per_policy = {"recurring": 0, "oneshot": 0}
status_choice = {
# policy: [tuple (1-for-filtering, 'associated-status')]
"recurring": [
(1, "disabled"),
(0, "completed"),
(0, "next_run_not_scheduled"),
],
"oneshot": [
(0, "next_run_not_scheduled"),
(1, "disabled"),
(1, "completed"),
],
}
tasks_to_update = defaultdict(list)
_task_ids = defaultdict(list)
# randomly 'disable' recurring tasks or 'complete' oneshot tasks
for task in pending_tasks:
policy = task["policy"]
_task_ids[policy].append(task["id"])
status = random.choice(status_choice[policy])
if status[0] != 1:
continue
# elected for filtering
status_per_policy[policy] += status[0]
tasks_to_update[policy].append(task["id"])
swh_scheduler.disable_tasks(tasks_to_update["recurring"])
# hack: change the status to something other than completed/disabled
swh_scheduler.set_status_tasks(
_task_ids["oneshot"], status="next_run_not_scheduled"
)
# complete the tasks to update
swh_scheduler.set_status_tasks(tasks_to_update["oneshot"], status="completed")
total_tasks_filtered = (
status_per_policy["recurring"] + status_per_policy["oneshot"]
)
# no pagination scenario
# retrieve tasks to archive
after = _time - ONEDAY
after_ts = after.strftime("%Y-%m-%d")
before = utcnow() + ONEDAY
before_ts = before.strftime("%Y-%m-%d")
tasks_result = swh_scheduler.filter_task_to_archive(
after_ts=after_ts, before_ts=before_ts, limit=total_tasks
)
tasks_to_archive = tasks_result["tasks"]
assert len(tasks_to_archive) == total_tasks_filtered
assert tasks_result.get("next_page_token") is None
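# Each archived task should be within the date range and carry the expected
# status; count them per policy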
actual_filtered_per_status = {"recurring": 0, "oneshot": 0}
for task in tasks_to_archive:
self.assert_filtered_task_ok(task, after, before)
actual_filtered_per_status[task["task_policy"]] += 1
assert actual_filtered_per_status == status_per_policy
# pagination scenario
nb_tasks = 3
tasks_result = swh_scheduler.filter_task_to_archive(
after_ts=after_ts, before_ts=before_ts, limit=nb_tasks
)
tasks_to_archive2 = tasks_result["tasks"]
assert len(tasks_to_archive2) == nb_tasks
next_page_token = tasks_result["next_page_token"]
assert next_page_token is not None
all_tasks = tasks_to_archive2
while next_page_token is not None: # Retrieve paginated results
tasks_result = swh_scheduler.filter_task_to_archive(
after_ts=after_ts,
before_ts=before_ts,
limit=nb_tasks,
page_token=next_page_token,
)
tasks_to_archive2 = tasks_result["tasks"]
assert len(tasks_to_archive2) <= nb_tasks
all_tasks.extend(tasks_to_archive2)
next_page_token = tasks_result.get("next_page_token")
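# All paginated pages together should yield the same per-policy counts as the
# non-paginated query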
actual_filtered_per_status = {"recurring": 0, "oneshot": 0}
for task in all_tasks:
self.assert_filtered_task_ok(task, after, before)
actual_filtered_per_status[task["task_policy"]] += 1
assert actual_filtered_per_status == status_per_policy
def test_delete_archived_tasks(self, swh_scheduler):
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12)
total_tasks = len(recurring) + len(oneshots)
pending_tasks = swh_scheduler.create_tasks(recurring + oneshots)
backend_tasks = [
{
"task": task["id"],
"backend_id": str(uuid.uuid4()),
"scheduled": utcnow(),
}
for task in pending_tasks
]
swh_scheduler.mass_schedule_task_runs(backend_tasks)
_tasks = []
percent = random.randint(0, 100)  # random threshold for selecting tasks to delete
for task in backend_tasks:
t = swh_scheduler.end_task_run(task["backend_id"], status="eventful")
c = random.randint(0, 100)
if c <= percent:
_tasks.append({"task_id": t["task"], "task_run_id": t["id"]})
swh_scheduler.delete_archived_tasks(_tasks)
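# Deleted tasks and their runs should no longer show up in the scheduler db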
all_tasks = [task["id"] for task in swh_scheduler.search_tasks()]
tasks_count = len(all_tasks)
tasks_run_count = len(swh_scheduler.get_task_runs(all_tasks))
assert tasks_count == total_tasks - len(_tasks)
assert tasks_run_count == total_tasks - len(_tasks)
def test_get_task_runs_no_task(self, swh_scheduler):
"""No task exist in the scheduler's db, get_task_runs() should always return an
empty list.
"""
assert not swh_scheduler.get_task_runs(task_ids=())
assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3))
assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3), limit=10)
def test_get_task_runs_no_task_executed(self, swh_scheduler):
"""No task has been executed yet, get_task_runs() should always return an empty
list.
"""
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12)
swh_scheduler.create_tasks(recurring + oneshots)
assert not swh_scheduler.get_task_runs(task_ids=())
assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3))
assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3), limit=10)
def test_get_task_runs_with_scheduled(self, swh_scheduler):
"""Some tasks have been scheduled but not executed yet, get_task_runs() should
not return an empty list. limit should behave as expected.
"""
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12)
total_tasks = len(recurring) + len(oneshots)
pending_tasks = swh_scheduler.create_tasks(recurring + oneshots)
backend_tasks = [
{
"task": task["id"],
"backend_id": str(uuid.uuid4()),
"scheduled": utcnow(),
}
for task in pending_tasks
]
swh_scheduler.mass_schedule_task_runs(backend_tasks)
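# An unknown task id should yield no runs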
assert not swh_scheduler.get_task_runs(task_ids=[total_tasks + 1])
btask = backend_tasks[0]
runs = swh_scheduler.get_task_runs(task_ids=[btask["task"]])
assert len(runs) == 1
run = runs[0]
assert subdict(run, excl=("id",)) == {
"task": btask["task"],
"backend_id": btask["backend_id"],
"scheduled": btask["scheduled"],
"started": None,
"ended": None,
"metadata": None,
"status": "scheduled",
}
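# limit should cap the number of returned runs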
runs = swh_scheduler.get_task_runs(
task_ids=[bt["task"] for bt in backend_tasks], limit=2
)
assert len(runs) == 2
runs = swh_scheduler.get_task_runs(
task_ids=[bt["task"] for bt in backend_tasks]
)
assert len(runs) == total_tasks
keys = ("task", "backend_id", "scheduled")
assert (
sorted([subdict(x, keys) for x in runs], key=lambda x: x["task"])
== backend_tasks
)
def test_get_task_runs_with_executed(self, swh_scheduler):
"""Some tasks have been executed, get_task_runs() should
not return an empty list. limit should behave as expected.
"""
self._create_task_types(swh_scheduler)
_time = utcnow()
recurring = tasks_from_template(TEMPLATES["git"], _time, 12)
oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12)
pending_tasks = swh_scheduler.create_tasks(recurring + oneshots)
backend_tasks = [
{
"task": task["id"],
"backend_id": str(uuid.uuid4()),
"scheduled": utcnow(),
}
for task in pending_tasks
]
swh_scheduler.mass_schedule_task_runs(backend_tasks)
btask = backend_tasks[0]
ts = utcnow()
swh_scheduler.start_task_run(
btask["backend_id"], metadata={"something": "stupid"}, timestamp=ts
)
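# The run should now be reported as started, with the provided timestamp and metadata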
runs = swh_scheduler.get_task_runs(task_ids=[btask["task"]])
assert len(runs) == 1
assert subdict(runs[0], excl=("id",)) == {
"task": btask["task"],
"backend_id": btask["backend_id"],
"scheduled": btask["scheduled"],
"started": ts,
"ended": None,
"metadata": {"something": "stupid"},
"status": "started",
}
ts2 = utcnow()
swh_scheduler.end_task_run(
btask["backend_id"],
metadata={"other": "stuff"},
timestamp=ts2,
status="eventful",
)
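# start and end metadata should be merged, and the final status reported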
runs = swh_scheduler.get_task_runs(task_ids=[btask["task"]])
assert len(runs) == 1
assert subdict(runs[0], excl=("id",)) == {
"task": btask["task"],
"backend_id": btask["backend_id"],
"scheduled": btask["scheduled"],
"started": ts,
"ended": ts2,
"metadata": {"something": "stupid", "other": "stuff"},
"status": "eventful",
}
def test_get_or_create_lister(self, swh_scheduler):
db_listers = []
for lister_args in LISTERS:
db_listers.append(swh_scheduler.get_or_create_lister(**lister_args))
for lister, lister_args in zip(db_listers, LISTERS):
assert lister.name == lister_args["name"]
assert lister.instance_name == lister_args.get("instance_name", "")
lister_get_again = swh_scheduler.get_or_create_lister(
lister.name, lister.instance_name
)
assert lister == lister_get_again
def test_get_lister(self, swh_scheduler):
for lister_args in LISTERS:
assert swh_scheduler.get_lister(**lister_args) is None
db_listers = []
for lister_args in LISTERS:
db_listers.append(swh_scheduler.get_or_create_lister(**lister_args))
for lister, lister_args in zip(db_listers, LISTERS):
lister_get_again = swh_scheduler.get_lister(
lister.name, lister.instance_name
)
assert lister == lister_get_again
def test_update_lister(self, swh_scheduler, stored_lister):
lister = attr.evolve(stored_lister, current_state={"updated": "now"})
updated_lister = swh_scheduler.update_lister(lister)
assert updated_lister.updated > lister.updated
assert updated_lister == attr.evolve(lister, updated=updated_lister.updated)
def test_update_lister_stale(self, swh_scheduler, stored_lister):
swh_scheduler.update_lister(stored_lister)
with pytest.raises(StaleData) as exc:
swh_scheduler.update_lister(stored_lister)
assert "state not updated" in exc.value.args[0]
def test_record_listed_origins(self, swh_scheduler, listed_origins):
ret = swh_scheduler.record_listed_origins(listed_origins)
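# The recorded origins should match the input, and newly inserted origins
# have first_seen == last_seen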
assert set(returned.url for returned in ret) == set(
origin.url for origin in listed_origins
)
assert all(origin.first_seen == origin.last_seen for origin in ret)
def test_record_listed_origins_upsert(self, swh_scheduler, listed_origins):
# First, insert `cutoff` origins
cutoff = 100
assert cutoff < len(listed_origins)
ret = swh_scheduler.record_listed_origins(listed_origins[:cutoff])
assert len(ret) == cutoff
# Then, insert all origins, including the `cutoff` first.
ret = swh_scheduler.record_listed_origins(listed_origins)
assert len(ret) == len(listed_origins)
# Two different "first seen" values
assert len(set(origin.first_seen for origin in ret)) == 2
# But a single "last seen" value
assert len(set(origin.last_seen for origin in ret)) == 1
def test_get_listed_origins_exact(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
for i, origin in enumerate(listed_origins):
ret = swh_scheduler.get_listed_origins(
lister_id=origin.lister_id, url=origin.url
)
assert ret.next_page_token is None
- assert len(ret.origins) == 1
- assert ret.origins[0].lister_id == origin.lister_id
- assert ret.origins[0].url == origin.url
+ assert len(ret.results) == 1
+ assert ret.results[0].lister_id == origin.lister_id
+ assert ret.results[0].url == origin.url
@pytest.mark.parametrize("num_origins,limit", [(20, 6), (5, 42), (20, 20)])
def test_get_listed_origins_limit(
self, swh_scheduler, listed_origins, num_origins, limit
) -> None:
added_origins = sorted(
listed_origins[:num_origins], key=lambda o: (o.lister_id, o.url)
)
swh_scheduler.record_listed_origins(added_origins)
returned_origins: List[ListedOrigin] = []
call_count = 0
next_page_token: Optional[ListedOriginPageToken] = None
while True:
call_count += 1
ret = swh_scheduler.get_listed_origins(
lister_id=listed_origins[0].lister_id,
limit=limit,
page_token=next_page_token,
)
- returned_origins.extend(ret.origins)
+ returned_origins.extend(ret.results)
next_page_token = ret.next_page_token
if next_page_token is None:
break
assert call_count == (num_origins // limit) + 1
assert len(returned_origins) == num_origins
assert [(origin.lister_id, origin.url) for origin in returned_origins] == [
(origin.lister_id, origin.url) for origin in added_origins
]
def test_get_listed_origins_all(self, swh_scheduler, listed_origins) -> None:
swh_scheduler.record_listed_origins(listed_origins)
ret = swh_scheduler.get_listed_origins(limit=len(listed_origins) + 1)
assert ret.next_page_token is None
- assert len(ret.origins) == len(listed_origins)
+ assert len(ret.results) == len(listed_origins)
def _grab_next_visits_setup(self, swh_scheduler, listed_origins_by_type):
"""Basic origins setup for scheduling policy tests"""
visit_type = next(iter(listed_origins_by_type))
origins = listed_origins_by_type[visit_type][:100]
assert len(origins) > 0
recorded_origins = swh_scheduler.record_listed_origins(origins)
return visit_type, recorded_origins
def _check_grab_next_visit(self, swh_scheduler, visit_type, policy, expected):
"""Calls grab_next_visits with the passed policy, and check that all
the origins returned are the expected ones (in the same order), and
that no extra origins are returned. Also checks the origin visits have
been marked as scheduled."""
assert len(expected) != 0
before = utcnow()
ret = swh_scheduler.grab_next_visits(
visit_type=visit_type,
# Request one more than expected to check that no extra origin is returned
count=len(expected) + 1,
policy=policy,
)
after = utcnow()
assert ret == expected
visit_stats_list = swh_scheduler.origin_visit_stats_get(
[(origin.url, origin.visit_type) for origin in expected]
)
assert len(visit_stats_list) == len(expected)
for visit_stats in visit_stats_list:
# Check that last_scheduled got updated
assert before <= visit_stats.last_scheduled <= after
def test_grab_next_visits_oldest_scheduled_first(
self, swh_scheduler, listed_origins_by_type,
):
visit_type, origins = self._grab_next_visits_setup(
swh_scheduler, listed_origins_by_type
)
# Give all origins but one a last_scheduled date
base_date = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
visit_stats = [
OriginVisitStats(
url=origin.url,
visit_type=origin.visit_type,
last_snapshot=None,
last_eventful=None,
last_uneventful=None,
last_failed=None,
last_notfound=None,
last_scheduled=base_date - datetime.timedelta(seconds=i),
)
for i, origin in enumerate(origins[1:])
]
swh_scheduler.origin_visit_stats_upsert(visit_stats)
# We expect to retrieve the origin with a NULL last_scheduled
# as well as those with the oldest values (i.e. the last ones), in order.
expected = [origins[0]] + origins[1:][::-1]
self._check_grab_next_visit(
swh_scheduler,
visit_type=visit_type,
policy="oldest_scheduled_first",
expected=expected,
)
def test_grab_next_visits_never_visited_oldest_update_first(
self, swh_scheduler, listed_origins_by_type,
):
visit_type, origins = self._grab_next_visits_setup(
swh_scheduler, listed_origins_by_type
)
# Update known origins with a `last_update` field that we control
base_date = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
updated_origins = [
attr.evolve(origin, last_update=base_date - datetime.timedelta(seconds=i))
for i, origin in enumerate(origins)
]
updated_origins = swh_scheduler.record_listed_origins(updated_origins)
# We expect to retrieve origins with the oldest update date, that is
# origins at the end of our updated_origins list.
expected_origins = sorted(updated_origins, key=lambda o: o.last_update)
self._check_grab_next_visit(
swh_scheduler,
visit_type=visit_type,
policy="never_visited_oldest_update_first",
expected=expected_origins,
)
def test_grab_next_visits_already_visited_order_by_lag(
self, swh_scheduler, listed_origins_by_type,
):
visit_type, origins = self._grab_next_visits_setup(
swh_scheduler, listed_origins_by_type
)
# Update known origins with a `last_update` field that we control
base_date = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
updated_origins = [
attr.evolve(origin, last_update=base_date - datetime.timedelta(seconds=i))
for i, origin in enumerate(origins)
]
updated_origins = swh_scheduler.record_listed_origins(updated_origins)
# Update the visit stats with a known visit at a controlled date for
# half the origins. Pick the date in the middle of the
# updated_origins' `last_update` range
visit_date = updated_origins[len(updated_origins) // 2].last_update
visited_origins = updated_origins[::2]
visit_stats = [
OriginVisitStats(
url=origin.url,
visit_type=origin.visit_type,
last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
last_eventful=visit_date,
last_uneventful=None,
last_failed=None,
last_notfound=None,
)
for origin in visited_origins
]
swh_scheduler.origin_visit_stats_upsert(visit_stats)
# We expect to retrieve visited origins with the largest lag, but only
# those which haven't been visited since their last update
expected_origins = sorted(
[origin for origin in visited_origins if origin.last_update > visit_date],
key=lambda o: visit_date - o.last_update,
)
self._check_grab_next_visit(
swh_scheduler,
visit_type=visit_type,
policy="already_visited_order_by_lag",
expected=expected_origins,
)
def test_grab_next_visits_underflow(self, swh_scheduler, listed_origins_by_type):
"""Check that grab_next_visits works when there not enough origins in
the database"""
visit_type = next(iter(listed_origins_by_type))
# Only add 5 origins to the database
origins = listed_origins_by_type[visit_type][:5]
assert origins
swh_scheduler.record_listed_origins(origins)
ret = swh_scheduler.grab_next_visits(
visit_type, len(origins) + 2, policy="oldest_scheduled_first"
)
assert len(ret) == 5
def test_grab_next_visits_unknown_policy(self, swh_scheduler):
unknown_policy = "non_existing_policy"
NUM_RESULTS = 5
with pytest.raises(UnknownPolicy, match=unknown_policy):
swh_scheduler.grab_next_visits("type", NUM_RESULTS, policy=unknown_policy)
def _create_task_types(self, scheduler):
for tt in TASK_TYPES.values():
scheduler.create_task_type(tt)
def test_origin_visit_stats_get_empty(self, swh_scheduler) -> None:
assert swh_scheduler.origin_visit_stats_get([]) == []
def test_origin_visit_stats_upsert(self, swh_scheduler) -> None:
eventful_date = utcnow()
url = "https://github.com/test"
visit_stats = OriginVisitStats(
url=url,
visit_type="git",
last_eventful=eventful_date,
last_uneventful=None,
last_failed=None,
last_notfound=None,
)
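# Upserting the same stats twice should be idempotent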
swh_scheduler.origin_visit_stats_upsert([visit_stats])
swh_scheduler.origin_visit_stats_upsert([visit_stats])
assert swh_scheduler.origin_visit_stats_get([(url, "git")]) == [visit_stats]
assert swh_scheduler.origin_visit_stats_get([(url, "svn")]) == []
uneventful_date = utcnow()
visit_stats = OriginVisitStats(
url=url,
visit_type="git",
last_eventful=None,
last_uneventful=uneventful_date,
last_failed=None,
last_notfound=None,
)
swh_scheduler.origin_visit_stats_upsert([visit_stats])
uneventful_visits = swh_scheduler.origin_visit_stats_get([(url, "git")])
expected_visit_stats = OriginVisitStats(
url=url,
visit_type="git",
last_eventful=eventful_date,
last_uneventful=uneventful_date,
last_failed=None,
last_notfound=None,
)
assert uneventful_visits == [expected_visit_stats]
failed_date = utcnow()
visit_stats = OriginVisitStats(
url=url,
visit_type="git",
last_eventful=None,
last_uneventful=None,
last_failed=failed_date,
last_notfound=None,
)
swh_scheduler.origin_visit_stats_upsert([visit_stats])
failed_visits = swh_scheduler.origin_visit_stats_get([(url, "git")])
expected_visit_stats = OriginVisitStats(
url=url,
visit_type="git",
last_eventful=eventful_date,
last_uneventful=uneventful_date,
last_failed=failed_date,
last_notfound=None,
)
assert failed_visits == [expected_visit_stats]
def test_origin_visit_stats_upsert_with_snapshot(self, swh_scheduler) -> None:
eventful_date = utcnow()
url = "https://github.com/666/test"
visit_stats = OriginVisitStats(
url=url,
visit_type="git",
last_eventful=eventful_date,
last_uneventful=None,
last_failed=None,
last_notfound=None,
last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
)
swh_scheduler.origin_visit_stats_upsert([visit_stats])
assert swh_scheduler.origin_visit_stats_get([(url, "git")]) == [visit_stats]
assert swh_scheduler.origin_visit_stats_get([(url, "svn")]) == []
def test_origin_visit_stats_upsert_messing_with_time(self, swh_scheduler) -> None:
url = "interesting-origin"
# Let's play with dates...
date2 = utcnow()
date1 = date2 - ONEDAY
date0 = date1 - ONEDAY
assert date0 < date1 < date2
snapshot2 = hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd")
snapshot0 = hash_to_bytes("fffcc0710eb6cf9efd5b920a8453e1e07157bfff")
visit_stats0 = OriginVisitStats(
url=url,
visit_type="git",
last_eventful=date2,
last_uneventful=None,
last_failed=None,
last_notfound=None,
last_snapshot=snapshot2,
)
swh_scheduler.origin_visit_stats_upsert([visit_stats0])
actual_visit_stats0 = swh_scheduler.origin_visit_stats_get([(url, "git")])[0]
assert actual_visit_stats0 == visit_stats0
visit_stats2 = OriginVisitStats(
url=url,
visit_type="git",
last_eventful=None,
last_uneventful=date1,
last_notfound=None,
last_failed=None,
)
swh_scheduler.origin_visit_stats_upsert([visit_stats2])
actual_visit_stats2 = swh_scheduler.origin_visit_stats_get([(url, "git")])[0]
assert actual_visit_stats2 == attr.evolve(
actual_visit_stats0, last_uneventful=date1
)
# What happens when upserting with a past date?
# date0 < date2, so this visit stats update should be dismissed,
# and the associated "eventful" snapshot should be dismissed as well
visit_stats1 = OriginVisitStats(
url=url,
visit_type="git",
last_eventful=date0,
last_uneventful=None,
last_failed=None,
last_notfound=None,
last_snapshot=snapshot0,
)
swh_scheduler.origin_visit_stats_upsert([visit_stats1])
actual_visit_stats1 = swh_scheduler.origin_visit_stats_get([(url, "git")])[0]
assert actual_visit_stats1 == attr.evolve(
actual_visit_stats2, last_eventful=date2
)
def test_origin_visit_stats_upsert_batch(self, swh_scheduler) -> None:
"""Batch upsert is ok"""
visit_stats = [
OriginVisitStats(
url="foo",
visit_type="git",
last_eventful=utcnow(),
last_uneventful=None,
last_failed=None,
last_notfound=None,
last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"),
),
OriginVisitStats(
url="bar",
visit_type="git",
last_eventful=None,
last_uneventful=utcnow(),
last_notfound=None,
last_failed=None,
last_snapshot=hash_to_bytes("fffcc0710eb6cf9efd5b920a8453e1e07157bfff"),
),
]
swh_scheduler.origin_visit_stats_upsert(visit_stats)
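# Every upserted (url, visit_type) pair should be retrievable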
for visit_stat in swh_scheduler.origin_visit_stats_get(
[(vs.url, vs.visit_type) for vs in visit_stats]
):
assert visit_stat is not None
def test_origin_visit_stats_upsert_cardinality_failing(self, swh_scheduler) -> None:
"""Batch upsert does not support altering multiple times the same origin-visit-status
"""
with pytest.raises(SchedulerException, match="CardinalityViolation"):
swh_scheduler.origin_visit_stats_upsert(
[
OriginVisitStats(
url="foo",
visit_type="git",
last_eventful=None,
last_uneventful=utcnow(),
last_notfound=None,
last_failed=None,
last_snapshot=None,
),
OriginVisitStats(
url="foo",
visit_type="git",
last_eventful=None,
last_uneventful=utcnow(),
last_notfound=None,
last_failed=None,
last_snapshot=None,
),
]
)
def test_metrics_origins_known(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
ret = swh_scheduler.update_metrics()
assert sum(metric.origins_known for metric in ret) == len(listed_origins)
def test_metrics_origins_enabled(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
disabled_origin = attr.evolve(listed_origins[0], enabled=False)
swh_scheduler.record_listed_origins([disabled_origin])
ret = swh_scheduler.update_metrics(lister_id=disabled_origin.lister_id)
for metric in ret:
if metric.visit_type == disabled_origin.visit_type:
# We disabled one of these origins
assert metric.origins_known - metric.origins_enabled == 1
else:
# But these are still all enabled
assert metric.origins_known == metric.origins_enabled
def test_metrics_origins_never_visited(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
# Pretend that we've recorded a visit on one origin
visited_origin = listed_origins[0]
swh_scheduler.origin_visit_stats_upsert(
[
OriginVisitStats(
url=visited_origin.url,
visit_type=visited_origin.visit_type,
last_eventful=utcnow(),
last_uneventful=None,
last_failed=None,
last_notfound=None,
last_snapshot=hash_to_bytes(
"d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"
),
),
]
)
ret = swh_scheduler.update_metrics(lister_id=visited_origin.lister_id)
for metric in ret:
if metric.visit_type == visited_origin.visit_type:
# We visited one of these origins
assert metric.origins_known - metric.origins_never_visited == 1
else:
# But none of these have been visited
assert metric.origins_known == metric.origins_never_visited
def test_metrics_origins_with_pending_changes(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
# Pretend that we've recorded a visit on one origin, in the past with
# respect to the "last update" time for the origin
visited_origin = listed_origins[0]
assert visited_origin.last_update is not None
swh_scheduler.origin_visit_stats_upsert(
[
OriginVisitStats(
url=visited_origin.url,
visit_type=visited_origin.visit_type,
last_eventful=visited_origin.last_update
- datetime.timedelta(days=1),
last_uneventful=None,
last_failed=None,
last_notfound=None,
last_snapshot=hash_to_bytes(
"d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"
),
),
]
)
ret = swh_scheduler.update_metrics(lister_id=visited_origin.lister_id)
for metric in ret:
if metric.visit_type == visited_origin.visit_type:
# We visited one of these origins, in the past
assert metric.origins_with_pending_changes == 1
else:
# But none of these have been visited
assert metric.origins_with_pending_changes == 0
def test_update_metrics_explicit_lister(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
fake_uuid = uuid.uuid4()
assert all(fake_uuid != origin.lister_id for origin in listed_origins)
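# Updating metrics for an unknown lister should return no metrics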
ret = swh_scheduler.update_metrics(lister_id=fake_uuid)
assert len(ret) == 0
def test_update_metrics_explicit_timestamp(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
ts = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
ret = swh_scheduler.update_metrics(timestamp=ts)
assert all(metric.last_update == ts for metric in ret)
def test_update_metrics_twice(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
ts = utcnow()
ret = swh_scheduler.update_metrics(timestamp=ts)
assert all(metric.last_update == ts for metric in ret)
second_ts = ts + datetime.timedelta(seconds=1)
ret = swh_scheduler.update_metrics(timestamp=second_ts)
assert all(metric.last_update == second_ts for metric in ret)
def test_get_metrics(self, swh_scheduler, listed_origins):
swh_scheduler.record_listed_origins(listed_origins)
updated = swh_scheduler.update_metrics()
retrieved = swh_scheduler.get_metrics()
assert_metrics_equal(updated, retrieved)
def test_get_metrics_by_lister(self, swh_scheduler, listed_origins):
lister_id = listed_origins[0].lister_id
assert lister_id is not None
swh_scheduler.record_listed_origins(listed_origins)
updated = swh_scheduler.update_metrics()
retrieved = swh_scheduler.get_metrics(lister_id=lister_id)
assert len(retrieved) > 0
assert_metrics_equal(
[metric for metric in updated if metric.lister_id == lister_id], retrieved
)
def test_get_metrics_by_visit_type(self, swh_scheduler, listed_origins):
visit_type = listed_origins[0].visit_type
assert visit_type is not None
swh_scheduler.record_listed_origins(listed_origins)
updated = swh_scheduler.update_metrics()
retrieved = swh_scheduler.get_metrics(visit_type=visit_type)
assert len(retrieved) > 0
assert_metrics_equal(
[metric for metric in updated if metric.visit_type == visit_type], retrieved
)
diff --git a/swh/scheduler/tests/test_simulator.py b/swh/scheduler/tests/test_simulator.py
index a93542e..7c7aaca 100644
--- a/swh/scheduler/tests/test_simulator.py
+++ b/swh/scheduler/tests/test_simulator.py
@@ -1,53 +1,53 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import pytest
+from swh.core.api.classes import stream_results
import swh.scheduler.simulator as simulator
from swh.scheduler.tests.common import TASK_TYPES
NUM_ORIGINS = 42
TEST_RUNTIME = 1000
def test_fill_test_data(swh_scheduler):
for task_type in TASK_TYPES.values():
swh_scheduler.create_task_type(task_type)
simulator.fill_test_data(swh_scheduler, num_origins=NUM_ORIGINS)
- res = swh_scheduler.get_listed_origins()
- assert len(res.origins) == NUM_ORIGINS
- assert res.next_page_token is None
+ origins = list(stream_results(swh_scheduler.get_listed_origins))
+ assert len(origins) == NUM_ORIGINS
res = swh_scheduler.search_tasks()
assert len(res) == NUM_ORIGINS
@pytest.mark.parametrize("policy", ("oldest_scheduled_first",))
def test_run_origin_scheduler(swh_scheduler, policy):
for task_type in TASK_TYPES.values():
swh_scheduler.create_task_type(task_type)
simulator.fill_test_data(swh_scheduler, num_origins=NUM_ORIGINS)
simulator.run(
swh_scheduler,
scheduler_type="origin_scheduler",
policy=policy,
runtime=TEST_RUNTIME,
)
def test_run_task_scheduler(swh_scheduler):
for task_type in TASK_TYPES.values():
swh_scheduler.create_task_type(task_type)
simulator.fill_test_data(swh_scheduler, num_origins=NUM_ORIGINS)
simulator.run(
swh_scheduler,
scheduler_type="task_scheduler",
policy=None,
runtime=TEST_RUNTIME,
)