diff --git a/PKG-INFO b/PKG-INFO index a625ff7..511eb74 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,32 +1,32 @@ Metadata-Version: 2.1 Name: swh.scheduler -Version: 0.9.2 +Version: 0.10.0 Summary: Software Heritage Scheduler Home-page: https://forge.softwareheritage.org/diffusion/DSCH/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-scheduler Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-scheduler/ Description: swh-scheduler ============= Job scheduler for the Software Heritage project. Task manager for asynchronous/delayed tasks, used for both recurrent (e.g., listing a forge, loading new stuff from a Git repository) and one-off activities (e.g., loading a specific version of a source package). Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/markdown Provides-Extra: testing Provides-Extra: journal Provides-Extra: simulator diff --git a/docs/simulator.rst b/docs/simulator.rst index 923d71a..979a389 100644 --- a/docs/simulator.rst +++ b/docs/simulator.rst @@ -1,65 +1,80 @@ .. _swh-scheduler-simulator: Software Heritage Scheduler Simulator ===================================== This component simulates the interaction between the scheduling and loading infrastructure of Software Heritage. This allows quick(er) development of new task scheduling policies without having to wait for the actual infrastructure to perform (heavy) loading tasks. Simulator components -------------------- - real instance of the scheduler database - simulated task queues: replaces RabbitMQ with simple in-memory structures - simulated workers: replaces Celery with simple while loops - simulated load tasks: replaces loaders with noops that take a certain time, and generate synthetic OriginVisitStatus objects - simulated archive -> scheduler feedback loop: OriginVisitStatus objects are pushed to a simple queue which gets processed by the scheduler journal client's process function directly (instead of going through swh.storage and swh.journal (kafka)) In short, only the scheduler database and scheduler logic is kept; every other component (RabbitMQ, Celery, Kafka, SWH loaders, SWH storage) is either replaced with an barebones in-process utility, or removed entirely. Installing the simulator ------------------------ The simulator depends on SimPy and other specific libraries. To install them, please use: .. code-block:: bash pip install 'swh.scheduler[simulator]' Running the simulator --------------------- The simulator uses a real instance of the scheduler database, which is (at least for now) persistent across runs of the simulator. You need to set that up beforehand: .. code-block:: bash # if you want to use a temporary instance of postgresql eval `pifpaf run postgresql` # Set this variable for the simulator to know which db to connect to. pifpaf # sets other variables like PGPORT, PGHOST, ... export PGDATABASE=swh-scheduler # Create/initialize the scheduler database swh db create scheduler -d $PGDATABASE swh db init scheduler -d $PGDATABASE # This generates some data in the scheduler database. You can also feed the # database with more realistic data, e.g. from a lister or from a dump of the # production database. swh scheduler -d "dbname=$PGDATABASE" simulator fill-test-data # Run the simulator itself, interacting with the scheduler database you've # just seeded. swh scheduler -d "dbname=$PGDATABASE" simulator run --scheduler origin_scheduler + + +Origin model +------------ + +The origin model is how we represent the behaviors of origins: when they are +created/discovered, how many commits they get and when, and when they fail to load. + +For now it is only a simple approximation designed to exercise simple cases: +origin creation/discovery, a continuous stream of commits, and failure if they have +too many commits to load at once. +For details, see :py:`swh.scheduler.simulator.origins`. + +To keep the simulation fast enough, each origin's state is kept in memory, so the +simulator process will linearly increase in memory usage as it runs. diff --git a/swh.scheduler.egg-info/PKG-INFO b/swh.scheduler.egg-info/PKG-INFO index a625ff7..511eb74 100644 --- a/swh.scheduler.egg-info/PKG-INFO +++ b/swh.scheduler.egg-info/PKG-INFO @@ -1,32 +1,32 @@ Metadata-Version: 2.1 Name: swh.scheduler -Version: 0.9.2 +Version: 0.10.0 Summary: Software Heritage Scheduler Home-page: https://forge.softwareheritage.org/diffusion/DSCH/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-scheduler Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-scheduler/ Description: swh-scheduler ============= Job scheduler for the Software Heritage project. Task manager for asynchronous/delayed tasks, used for both recurrent (e.g., listing a forge, loading new stuff from a Git repository) and one-off activities (e.g., loading a specific version of a source package). Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/markdown Provides-Extra: testing Provides-Extra: journal Provides-Extra: simulator diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py index 04220b9..41f1cca 100644 --- a/swh/scheduler/backend.py +++ b/swh/scheduler/backend.py @@ -1,959 +1,998 @@ # Copyright (C) 2015-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json import logging from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from uuid import UUID import attr from psycopg2.errors import CardinalityViolation import psycopg2.extras import psycopg2.pool from swh.core.db import BaseDb from swh.core.db.common import db_transaction from swh.scheduler.utils import utcnow from .exc import SchedulerException, StaleData, UnknownPolicy from .interface import ListedOriginPageToken, PaginatedListedOriginList from .model import ListedOrigin, Lister, OriginVisitStats, SchedulerMetrics logger = logging.getLogger(__name__) psycopg2.extensions.register_adapter(dict, psycopg2.extras.Json) psycopg2.extras.register_uuid() def format_query(query, keys): """Format a query with the given keys""" query_keys = ", ".join(keys) placeholders = ", ".join(["%s"] * len(keys)) return query.format(keys=query_keys, placeholders=placeholders) class SchedulerBackend: """Backend for the Software Heritage scheduling database. """ def __init__(self, db, min_pool_conns=1, max_pool_conns=10): """ Args: db_conn: either a libpq connection string, or a psycopg2 connection """ if isinstance(db, psycopg2.extensions.connection): self._pool = None self._db = BaseDb(db) else: self._pool = psycopg2.pool.ThreadedConnectionPool( min_pool_conns, max_pool_conns, db, cursor_factory=psycopg2.extras.RealDictCursor, ) self._db = None def get_db(self): if self._db: return self._db return BaseDb.from_pool(self._pool) def put_db(self, db): if db is not self._db: db.put_conn() task_type_keys = [ "type", "description", "backend_name", "default_interval", "min_interval", "max_interval", "backoff_factor", "max_queue_length", "num_retries", "retry_delay", ] @db_transaction() def create_task_type(self, task_type, db=None, cur=None): """Create a new task type ready for scheduling. Args: task_type (dict): a dictionary with the following keys: - type (str): an identifier for the task type - description (str): a human-readable description of what the task does - backend_name (str): the name of the task in the job-scheduling backend - default_interval (datetime.timedelta): the default interval between two task runs - min_interval (datetime.timedelta): the minimum interval between two task runs - max_interval (datetime.timedelta): the maximum interval between two task runs - backoff_factor (float): the factor by which the interval changes at each run - max_queue_length (int): the maximum length of the task queue for this task type """ keys = [key for key in self.task_type_keys if key in task_type] query = format_query( """insert into task_type ({keys}) values ({placeholders}) on conflict do nothing""", keys, ) cur.execute(query, [task_type[key] for key in keys]) @db_transaction() def get_task_type(self, task_type_name, db=None, cur=None): """Retrieve the task type with id task_type_name""" query = format_query( "select {keys} from task_type where type=%s", self.task_type_keys, ) cur.execute(query, (task_type_name,)) return cur.fetchone() @db_transaction() def get_task_types(self, db=None, cur=None): """Retrieve all registered task types""" query = format_query("select {keys} from task_type", self.task_type_keys,) cur.execute(query) return cur.fetchall() @db_transaction() def get_lister( self, name: str, instance_name: Optional[str] = None, db=None, cur=None ) -> Optional[Lister]: """Retrieve information about the given instance of the lister from the database. """ if instance_name is None: instance_name = "" select_cols = ", ".join(Lister.select_columns()) query = f""" select {select_cols} from listers where (name, instance_name) = (%s, %s) """ cur.execute(query, (name, instance_name)) ret = cur.fetchone() if not ret: return None return Lister(**ret) @db_transaction() def get_or_create_lister( self, name: str, instance_name: Optional[str] = None, db=None, cur=None ) -> Lister: """Retrieve information about the given instance of the lister from the database, or create the entry if it did not exist. """ if instance_name is None: instance_name = "" select_cols = ", ".join(Lister.select_columns()) insert_cols, insert_meta = ( ", ".join(tup) for tup in Lister.insert_columns_and_metavars() ) query = f""" with added as ( insert into listers ({insert_cols}) values ({insert_meta}) on conflict do nothing returning {select_cols} ) select {select_cols} from added union all select {select_cols} from listers where (name, instance_name) = (%(name)s, %(instance_name)s); """ cur.execute(query, attr.asdict(Lister(name=name, instance_name=instance_name))) return Lister(**cur.fetchone()) @db_transaction() def update_lister(self, lister: Lister, db=None, cur=None) -> Lister: """Update the state for the given lister instance in the database. Returns: a new Lister object, with all fields updated from the database Raises: StaleData if the `updated` timestamp for the lister instance in database doesn't match the one passed by the user. """ select_cols = ", ".join(Lister.select_columns()) set_vars = ", ".join( f"{col} = {meta}" for col, meta in zip(*Lister.insert_columns_and_metavars()) ) query = f"""update listers set {set_vars} where id=%(id)s and updated=%(updated)s returning {select_cols}""" cur.execute(query, attr.asdict(lister)) updated = cur.fetchone() if not updated: raise StaleData("Stale data; Lister state not updated") return Lister(**updated) @db_transaction() def record_listed_origins( self, listed_origins: Iterable[ListedOrigin], db=None, cur=None ) -> List[ListedOrigin]: """Record a set of origins that a lister has listed. This performs an "upsert": origins with the same (lister_id, url, visit_type) values are updated with new values for extra_loader_arguments, last_update and last_seen. """ pk_cols = ListedOrigin.primary_key_columns() select_cols = ListedOrigin.select_columns() insert_cols, insert_meta = ListedOrigin.insert_columns_and_metavars() upsert_cols = [col for col in insert_cols if col not in pk_cols] upsert_set = ", ".join(f"{col} = EXCLUDED.{col}" for col in upsert_cols) query = f"""INSERT into listed_origins ({", ".join(insert_cols)}) VALUES %s ON CONFLICT ({", ".join(pk_cols)}) DO UPDATE SET {upsert_set} RETURNING {", ".join(select_cols)} """ ret = psycopg2.extras.execute_values( cur=cur, sql=query, argslist=(attr.asdict(origin) for origin in listed_origins), template=f"({', '.join(insert_meta)})", page_size=1000, fetch=True, ) return [ListedOrigin(**d) for d in ret] @db_transaction() def get_listed_origins( self, lister_id: Optional[UUID] = None, url: Optional[str] = None, limit: int = 1000, page_token: Optional[ListedOriginPageToken] = None, db=None, cur=None, ) -> PaginatedListedOriginList: """Get information on the listed origins matching either the `url` or `lister_id`, or both arguments. """ query_filters: List[str] = [] query_params: List[Union[int, str, UUID, Tuple[UUID, str]]] = [] if lister_id: query_filters.append("lister_id = %s") query_params.append(lister_id) if url is not None: query_filters.append("url = %s") query_params.append(url) if page_token is not None: query_filters.append("(lister_id, url) > %s") # the typeshed annotation for tuple() is too strict. query_params.append(tuple(page_token)) # type: ignore query_params.append(limit) select_cols = ", ".join(ListedOrigin.select_columns()) if query_filters: where_clause = "where %s" % (" and ".join(query_filters)) else: where_clause = "" query = f"""SELECT {select_cols} from listed_origins {where_clause} ORDER BY lister_id, url LIMIT %s""" cur.execute(query, tuple(query_params)) origins = [ListedOrigin(**d) for d in cur] if len(origins) == limit: page_token = (str(origins[-1].lister_id), origins[-1].url) else: page_token = None return PaginatedListedOriginList(origins, page_token) @db_transaction() def grab_next_visits( - self, visit_type: str, count: int, policy: str, db=None, cur=None, + self, + visit_type: str, + count: int, + policy: str, + timestamp: Optional[datetime.datetime] = None, + db=None, + cur=None, ) -> List[ListedOrigin]: """Get at most the `count` next origins that need to be visited with the `visit_type` loader according to the given scheduling `policy`. - This will mark the origins as "being visited" in the listed_origins + This will mark the origins as scheduled in the origin_visit_stats table, to avoid scheduling multiple visits to the same origin. + + Arguments: + visit_type: type of visits to schedule + count: number of visits to schedule + policy: the scheduling policy used to select which visits to schedule + timestamp: the mocked timestamp at which we're recording that the visits are + being scheduled (defaults to the current time) """ + if timestamp is None: + timestamp = utcnow() + origin_select_cols = ", ".join(ListedOrigin.select_columns()) - # TODO: filter on last_scheduled "too recent" to avoid always - # re-scheduling the same tasks. - where_clauses = [ - "enabled", # "NOT enabled" = the lister said the origin no longer exists - "visit_type = %s", - ] + query_args: List[Any] = [] + + where_clauses = [] + + # "NOT enabled" = the lister said the origin no longer exists + where_clauses.append("enabled") + + # Only schedule visits of the given type + where_clauses.append("visit_type = %s") + query_args.append(visit_type) + + # Don't re-schedule visits if they're already scheduled but we haven't + # recorded a result yet, unless they've been scheduled more than a week + # ago (it probably means we've lost them in flight somewhere). + where_clauses.append( + """origin_visit_stats.last_scheduled IS NULL + OR origin_visit_stats.last_scheduled < GREATEST( + %s - '7 day'::interval, + origin_visit_stats.last_eventful, + origin_visit_stats.last_uneventful, + origin_visit_stats.last_failed, + origin_visit_stats.last_notfound + ) + """ + ) + query_args.append(timestamp) + if policy == "oldest_scheduled_first": order_by = "origin_visit_stats.last_scheduled NULLS FIRST" elif policy == "never_visited_oldest_update_first": # never visited origins have a NULL last_snapshot where_clauses.append("origin_visit_stats.last_snapshot IS NULL") # order by increasing last_update (oldest first) where_clauses.append("listed_origins.last_update IS NOT NULL") order_by = "listed_origins.last_update" elif policy == "already_visited_order_by_lag": # TODO: store "visit lag" in a materialized view? # visited origins have a NOT NULL last_snapshot where_clauses.append("origin_visit_stats.last_snapshot IS NOT NULL") # ignore origins we have visited after the known last update where_clauses.append("listed_origins.last_update IS NOT NULL") where_clauses.append( """ listed_origins.last_update > GREATEST( origin_visit_stats.last_eventful, origin_visit_stats.last_uneventful ) """ ) # order by decreasing visit lag order_by = """\ listed_origins.last_update - GREATEST( origin_visit_stats.last_eventful, origin_visit_stats.last_uneventful ) DESC """ else: raise UnknownPolicy(f"Unknown scheduling policy {policy}") select_query = f""" SELECT {origin_select_cols} FROM listed_origins LEFT JOIN origin_visit_stats USING (url, visit_type) WHERE - {" AND ".join(where_clauses)} + ({") AND (".join(where_clauses)}) ORDER BY {order_by} LIMIT %s """ + query_args.append(count) query = f""" WITH selected_origins AS ( {select_query} ), update_stats AS ( INSERT INTO origin_visit_stats ( url, visit_type, last_scheduled ) SELECT - url, visit_type, now() + url, visit_type, %s FROM selected_origins ON CONFLICT (url, visit_type) DO UPDATE SET last_scheduled = GREATEST( origin_visit_stats.last_scheduled, EXCLUDED.last_scheduled ) ) SELECT * FROM selected_origins """ + query_args.append(timestamp) - cur.execute(query, (visit_type, count)) + cur.execute(query, tuple(query_args)) return [ListedOrigin(**d) for d in cur] task_create_keys = [ "type", "arguments", "next_run", "policy", "status", "retries_left", "priority", ] task_keys = task_create_keys + ["id", "current_interval"] @db_transaction() def create_tasks(self, tasks, policy="recurring", db=None, cur=None): """Create new tasks. Args: tasks (list): each task is a dictionary with the following keys: - type (str): the task type - arguments (dict): the arguments for the task runner, keys: - args (list of str): arguments - kwargs (dict str -> str): keyword arguments - next_run (datetime.datetime): the next scheduled run for the task Returns: a list of created tasks. """ cur.execute("select swh_scheduler_mktemp_task()") db.copy_to( tasks, "tmp_task", self.task_create_keys, default_values={"policy": policy, "status": "next_run_not_scheduled"}, cur=cur, ) query = format_query( "select {keys} from swh_scheduler_create_tasks_from_temp()", self.task_keys, ) cur.execute(query) return cur.fetchall() @db_transaction() def set_status_tasks( self, task_ids, status="disabled", next_run=None, db=None, cur=None ): """Set the tasks' status whose ids are listed. If given, also set the next_run date. """ if not task_ids: return query = ["UPDATE task SET status = %s"] args = [status] if next_run: query.append(", next_run = %s") args.append(next_run) query.append(" WHERE id IN %s") args.append(tuple(task_ids)) cur.execute("".join(query), args) @db_transaction() def disable_tasks(self, task_ids, db=None, cur=None): """Disable the tasks whose ids are listed.""" return self.set_status_tasks(task_ids, db=db, cur=cur) @db_transaction() def search_tasks( self, task_id=None, task_type=None, status=None, priority=None, policy=None, before=None, after=None, limit=None, db=None, cur=None, ): """Search tasks from selected criterions""" where = [] args = [] if task_id: if isinstance(task_id, (str, int)): where.append("id = %s") else: where.append("id in %s") task_id = tuple(task_id) args.append(task_id) if task_type: if isinstance(task_type, str): where.append("type = %s") else: where.append("type in %s") task_type = tuple(task_type) args.append(task_type) if status: if isinstance(status, str): where.append("status = %s") else: where.append("status in %s") status = tuple(status) args.append(status) if priority: if isinstance(priority, str): where.append("priority = %s") else: priority = tuple(priority) where.append("priority in %s") args.append(priority) if policy: where.append("policy = %s") args.append(policy) if before: where.append("next_run <= %s") args.append(before) if after: where.append("next_run >= %s") args.append(after) query = "select * from task" if where: query += " where " + " and ".join(where) if limit: query += " limit %s :: bigint" args.append(limit) cur.execute(query, args) return cur.fetchall() @db_transaction() def get_tasks(self, task_ids, db=None, cur=None): """Retrieve the info of tasks whose ids are listed.""" query = format_query("select {keys} from task where id in %s", self.task_keys) cur.execute(query, (tuple(task_ids),)) return cur.fetchall() @db_transaction() def peek_ready_tasks( self, task_type, timestamp=None, num_tasks=None, num_tasks_priority=None, db=None, cur=None, ): """Fetch the list of ready tasks Args: task_type (str): filtering task per their type timestamp (datetime.datetime): peek tasks that need to be executed before that timestamp num_tasks (int): only peek at num_tasks tasks (with no priority) num_tasks_priority (int): only peek at num_tasks_priority tasks (with priority) Returns: a list of tasks """ if timestamp is None: timestamp = utcnow() cur.execute( """select * from swh_scheduler_peek_ready_tasks( %s, %s, %s :: bigint, %s :: bigint)""", (task_type, timestamp, num_tasks, num_tasks_priority), ) logger.debug("PEEK %s => %s" % (task_type, cur.rowcount)) return cur.fetchall() @db_transaction() def grab_ready_tasks( self, task_type, timestamp=None, num_tasks=None, num_tasks_priority=None, db=None, cur=None, ): """Fetch the list of ready tasks, and mark them as scheduled Args: task_type (str): filtering task per their type timestamp (datetime.datetime): grab tasks that need to be executed before that timestamp num_tasks (int): only grab num_tasks tasks (with no priority) num_tasks_priority (int): only grab oneshot num_tasks tasks (with priorities) Returns: a list of tasks """ if timestamp is None: timestamp = utcnow() cur.execute( """select * from swh_scheduler_grab_ready_tasks( %s, %s, %s :: bigint, %s :: bigint)""", (task_type, timestamp, num_tasks, num_tasks_priority), ) logger.debug("GRAB %s => %s" % (task_type, cur.rowcount)) return cur.fetchall() task_run_create_keys = ["task", "backend_id", "scheduled", "metadata"] @db_transaction() def schedule_task_run( self, task_id, backend_id, metadata=None, timestamp=None, db=None, cur=None ): """Mark a given task as scheduled, adding a task_run entry in the database. Args: task_id (int): the identifier for the task being scheduled backend_id (str): the identifier of the job in the backend metadata (dict): metadata to add to the task_run entry timestamp (datetime.datetime): the instant the event occurred Returns: a fresh task_run entry """ if metadata is None: metadata = {} if timestamp is None: timestamp = utcnow() cur.execute( "select * from swh_scheduler_schedule_task_run(%s, %s, %s, %s)", (task_id, backend_id, metadata, timestamp), ) return cur.fetchone() @db_transaction() def mass_schedule_task_runs(self, task_runs, db=None, cur=None): """Schedule a bunch of task runs. Args: task_runs (list): a list of dicts with keys: - task (int): the identifier for the task being scheduled - backend_id (str): the identifier of the job in the backend - metadata (dict): metadata to add to the task_run entry - scheduled (datetime.datetime): the instant the event occurred Returns: None """ cur.execute("select swh_scheduler_mktemp_task_run()") db.copy_to(task_runs, "tmp_task_run", self.task_run_create_keys, cur=cur) cur.execute("select swh_scheduler_schedule_task_run_from_temp()") @db_transaction() def start_task_run( self, backend_id, metadata=None, timestamp=None, db=None, cur=None ): """Mark a given task as started, updating the corresponding task_run entry in the database. Args: backend_id (str): the identifier of the job in the backend metadata (dict): metadata to add to the task_run entry timestamp (datetime.datetime): the instant the event occurred Returns: the updated task_run entry """ if metadata is None: metadata = {} if timestamp is None: timestamp = utcnow() cur.execute( "select * from swh_scheduler_start_task_run(%s, %s, %s)", (backend_id, metadata, timestamp), ) return cur.fetchone() @db_transaction() def end_task_run( self, backend_id, status, metadata=None, timestamp=None, result=None, db=None, cur=None, ): """Mark a given task as ended, updating the corresponding task_run entry in the database. Args: backend_id (str): the identifier of the job in the backend status (str): how the task ended; one of: 'eventful', 'uneventful', 'failed' metadata (dict): metadata to add to the task_run entry timestamp (datetime.datetime): the instant the event occurred Returns: the updated task_run entry """ if metadata is None: metadata = {} if timestamp is None: timestamp = utcnow() cur.execute( "select * from swh_scheduler_end_task_run(%s, %s, %s, %s)", (backend_id, status, metadata, timestamp), ) return cur.fetchone() @db_transaction() def filter_task_to_archive( self, after_ts: str, before_ts: str, limit: int = 10, page_token: Optional[str] = None, db=None, cur=None, ) -> Dict[str, Any]: """Compute the tasks to archive within the datetime interval [after_ts, before_ts[. The method returns a paginated result. Returns: dict with the following keys: - **next_page_token**: opaque token to be used as `page_token` to retrieve the next page of result. If absent, there is no more pages to gather. - **tasks**: list of task dictionaries with the following keys: **id** (str): origin task id **started** (Optional[datetime]): started date **scheduled** (datetime): scheduled date **arguments** (json dict): task's arguments ... """ assert not page_token or isinstance(page_token, str) last_id = -1 if page_token is None else int(page_token) tasks = [] cur.execute( "select * from swh_scheduler_task_to_archive(%s, %s, %s, %s)", (after_ts, before_ts, last_id, limit + 1), ) for row in cur: task = dict(row) # nested type index does not accept bare values # transform it as a dict to comply with this task["arguments"]["args"] = { i: v for i, v in enumerate(task["arguments"]["args"]) } kwargs = task["arguments"]["kwargs"] task["arguments"]["kwargs"] = json.dumps(kwargs) tasks.append(task) if len(tasks) >= limit + 1: # remains data, add pagination information result = { "tasks": tasks[:limit], "next_page_token": str(tasks[-1]["task_id"]), } else: result = {"tasks": tasks} return result @db_transaction() def delete_archived_tasks(self, task_ids, db=None, cur=None): """Delete archived tasks as much as possible. Only the task_ids whose complete associated task_run have been cleaned up will be. """ _task_ids = _task_run_ids = [] for task_id in task_ids: _task_ids.append(task_id["task_id"]) _task_run_ids.append(task_id["task_run_id"]) cur.execute( "select * from swh_scheduler_delete_archived_tasks(%s, %s)", (_task_ids, _task_run_ids), ) task_run_keys = [ "id", "task", "backend_id", "scheduled", "started", "ended", "metadata", "status", ] @db_transaction() def get_task_runs(self, task_ids, limit=None, db=None, cur=None): """Search task run for a task id""" where = [] args = [] if task_ids: if isinstance(task_ids, (str, int)): where.append("task = %s") else: where.append("task in %s") task_ids = tuple(task_ids) args.append(task_ids) else: return () query = "select * from task_run where " + " and ".join(where) if limit: query += " limit %s :: bigint" args.append(limit) cur.execute(query, args) return cur.fetchall() @db_transaction() def get_priority_ratios(self, db=None, cur=None): cur.execute("select id, ratio from priority_ratio") return {row["id"]: row["ratio"] for row in cur.fetchall()} @db_transaction() def origin_visit_stats_upsert( self, origin_visit_stats: Iterable[OriginVisitStats], db=None, cur=None ) -> None: pk_cols = OriginVisitStats.primary_key_columns() insert_cols, insert_meta = OriginVisitStats.insert_columns_and_metavars() query = f""" INSERT into origin_visit_stats AS ovi ({", ".join(insert_cols)}) VALUES %s ON CONFLICT ({", ".join(pk_cols)}) DO UPDATE SET last_eventful = coalesce(excluded.last_eventful, ovi.last_eventful), last_uneventful = coalesce(excluded.last_uneventful, ovi.last_uneventful), last_failed = coalesce(excluded.last_failed, ovi.last_failed), last_notfound = coalesce(excluded.last_notfound, ovi.last_notfound), last_snapshot = coalesce(excluded.last_snapshot, ovi.last_snapshot) """ # noqa try: psycopg2.extras.execute_values( cur=cur, sql=query, argslist=( attr.asdict(visit_stats) for visit_stats in origin_visit_stats ), template=f"({', '.join(insert_meta)})", page_size=1000, fetch=False, ) except CardinalityViolation as e: raise SchedulerException(repr(e)) @db_transaction() def origin_visit_stats_get( self, ids: Iterable[Tuple[str, str]], db=None, cur=None ) -> List[OriginVisitStats]: if not ids: return [] primary_keys = tuple((origin, visit_type) for (origin, visit_type) in ids) query = format_query( """ SELECT {keys} FROM (VALUES %s) as stats(url, visit_type) INNER JOIN origin_visit_stats USING (url, visit_type) """, OriginVisitStats.select_columns(), ) psycopg2.extras.execute_values(cur=cur, sql=query, argslist=primary_keys) return [OriginVisitStats(**row) for row in cur.fetchall()] @db_transaction() def update_metrics( self, lister_id: Optional[UUID] = None, timestamp: Optional[datetime.datetime] = None, db=None, cur=None, ) -> List[SchedulerMetrics]: """Update the performance metrics of this scheduler instance. Returns the updated metrics. Args: lister_id: if passed, update the metrics only for this lister instance timestamp: if passed, the date at which we're updating the metrics, defaults to the database NOW() """ query = format_query( "SELECT {keys} FROM update_metrics(%s, %s)", SchedulerMetrics.select_columns(), ) cur.execute(query, (lister_id, timestamp)) return [SchedulerMetrics(**row) for row in cur.fetchall()] @db_transaction() def get_metrics( self, lister_id: Optional[UUID] = None, visit_type: Optional[str] = None, db=None, cur=None, ) -> List[SchedulerMetrics]: """Retrieve the performance metrics of this scheduler instance. Args: lister_id: filter the metrics for this lister instance only visit_type: filter the metrics for this visit type only """ where_filters = [] where_args = [] if lister_id: where_filters.append("lister_id = %s") where_args.append(str(lister_id)) if visit_type: where_filters.append("visit_type = %s") where_args.append(visit_type) where_clause = "" if where_filters: where_clause = f"where {' and '.join(where_filters)}" query = format_query( "SELECT {keys} FROM scheduler_metrics %s" % where_clause, SchedulerMetrics.select_columns(), ) cur.execute(query, tuple(where_args)) return [SchedulerMetrics(**row) for row in cur.fetchall()] diff --git a/swh/scheduler/celery_backend/config.py b/swh/scheduler/celery_backend/config.py index 7f01219..20b21bd 100644 --- a/swh/scheduler/celery_backend/config.py +++ b/swh/scheduler/celery_backend/config.py @@ -1,342 +1,342 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import functools import logging import os from time import monotonic as _monotonic import traceback from typing import Any, Dict import urllib.parse from celery import Celery from celery.signals import celeryd_after_setup, setup_logging, worker_init from celery.utils.log import ColorFormatter from celery.worker.control import Panel from kombu import Exchange, Queue import pkg_resources import requests from swh.core.config import load_named_config, merge_configs from swh.core.sentry import init_sentry from swh.scheduler import CONFIG as SWH_CONFIG try: from swh.core.logger import JournalHandler except ImportError: JournalHandler = None # type: ignore DEFAULT_CONFIG_NAME = "worker" CONFIG_NAME_ENVVAR = "SWH_WORKER_INSTANCE" CONFIG_NAME_TEMPLATE = "worker/%s" DEFAULT_CONFIG = { "task_broker": ("str", "amqp://guest@localhost//"), "task_modules": ("list[str]", []), "task_queues": ("list[str]", []), "task_soft_time_limit": ("int", 0), } logger = logging.getLogger(__name__) # Celery eats tracebacks in signal callbacks, this decorator catches # and prints them. # Also tries to notify Sentry if possible. def _print_errors(f): @functools.wraps(f) def newf(*args, **kwargs): try: return f(*args, **kwargs) except Exception as exc: traceback.print_exc() try: import sentry_sdk sentry_sdk.capture_exception(exc) except Exception: traceback.print_exc() return newf @setup_logging.connect @_print_errors def setup_log_handler( loglevel=None, logfile=None, format=None, colorize=None, log_console=None, log_journal=None, **kwargs, ): """Setup logging according to Software Heritage preferences. We use the command-line loglevel for tasks only, as we never really care about the debug messages from celery. """ if loglevel is None: loglevel = logging.DEBUG if isinstance(loglevel, str): loglevel = logging._nameToLevel[loglevel] formatter = logging.Formatter(format) root_logger = logging.getLogger("") root_logger.setLevel(logging.INFO) log_target = os.environ.get("SWH_LOG_TARGET", "console") if log_target == "console": log_console = True elif log_target == "journal": log_journal = True # this looks for log levels *higher* than DEBUG if loglevel <= logging.DEBUG and log_console is None: log_console = True if log_console: color_formatter = ColorFormatter(format) if colorize else formatter console = logging.StreamHandler() console.setLevel(logging.DEBUG) console.setFormatter(color_formatter) root_logger.addHandler(console) if log_journal: if not JournalHandler: root_logger.warning( "JournalHandler is not available, skipping. " "Please install swh-core[logging]." ) else: systemd_journal = JournalHandler() systemd_journal.setLevel(logging.DEBUG) systemd_journal.setFormatter(formatter) root_logger.addHandler(systemd_journal) logging.getLogger("celery").setLevel(logging.INFO) # Silence amqp heartbeat_tick messages logger = logging.getLogger("amqp") logger.addFilter(lambda record: not record.msg.startswith("heartbeat_tick")) logger.setLevel(loglevel) # Silence useless "Starting new HTTP connection" messages logging.getLogger("urllib3").setLevel(logging.WARNING) # Completely disable azure logspam azure_logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy") azure_logger.setLevel(logging.WARNING) logging.getLogger("swh").setLevel(loglevel) # get_task_logger makes the swh tasks loggers children of celery.task logging.getLogger("celery.task").setLevel(loglevel) return loglevel @celeryd_after_setup.connect @_print_errors def setup_queues_and_tasks(sender, instance, **kwargs): """Signal called on worker start. This automatically registers swh.scheduler.task.Task subclasses as available celery tasks. This also subscribes the worker to the "implicit" per-task queues defined for these task classes. """ logger.info("Setup Queues & Tasks for %s", sender) instance.app.conf["worker_name"] = sender @worker_init.connect @_print_errors def on_worker_init(*args, **kwargs): try: from sentry_sdk.integrations.celery import CeleryIntegration except ImportError: integrations = [] else: integrations = [CeleryIntegration()] sentry_dsn = None # will be set in `init_sentry` function init_sentry(sentry_dsn, integrations=integrations) @Panel.register def monotonic(state): """Get the current value for the monotonic clock""" return {"monotonic": _monotonic()} def route_for_task(name, args, kwargs, options, task=None, **kw): """Route tasks according to the task_queue attribute in the task class""" if name is not None and name.startswith("swh."): return {"queue": name} def get_queue_stats(app, queue_name): """Get the statistics regarding a queue on the broker. Arguments: queue_name: name of the queue to check Returns a dictionary raw from the RabbitMQ management API; or `None` if the current configuration does not use RabbitMQ. Interesting keys: - Consumers (number of consumers for the queue) - messages (number of messages in queue) - messages_unacknowledged (number of messages currently being processed) Documentation: https://www.rabbitmq.com/management.html#http-api """ conn_info = app.connection().info() if conn_info["transport"] == "memory": # We're running in a test environment, without RabbitMQ. return None url = "http://{hostname}:{port}/api/queues/{vhost}/{queue}".format( hostname=conn_info["hostname"], port=conn_info["port"] + 10000, vhost=urllib.parse.quote(conn_info["virtual_host"], safe=""), queue=urllib.parse.quote(queue_name, safe=""), ) credentials = (conn_info["userid"], conn_info["password"]) r = requests.get(url, auth=credentials) if r.status_code == 404: return {} if r.status_code != 200: raise ValueError( "Got error %s when reading queue stats: %s" % (r.status_code, r.json()) ) return r.json() def get_queue_length(app, queue_name): """Shortcut to get a queue's length""" stats = get_queue_stats(app, queue_name) if stats: return stats.get("messages") def register_task_class(app, name, cls): """Register a class-based task under the given name""" if name in app.tasks: return task_instance = cls() task_instance.name = name app.register_task(task_instance) INSTANCE_NAME = os.environ.get(CONFIG_NAME_ENVVAR) CONFIG_NAME = os.environ.get("SWH_CONFIG_FILENAME") CONFIG = {} # type: Dict[str, Any] if CONFIG_NAME: # load the celery config from the main config file given as # SWH_CONFIG_FILENAME environment variable. # This is expected to have a [celery] section in which we have the # celery specific configuration. SWH_CONFIG.clear() SWH_CONFIG.update(load_named_config(CONFIG_NAME)) CONFIG = SWH_CONFIG.get("celery", {}) if not CONFIG: # otherwise, back to compat config loading mechanism if INSTANCE_NAME: CONFIG_NAME = CONFIG_NAME_TEMPLATE % INSTANCE_NAME else: CONFIG_NAME = DEFAULT_CONFIG_NAME # Load the Celery config CONFIG = load_named_config(CONFIG_NAME, DEFAULT_CONFIG) CONFIG.setdefault("task_modules", []) # load tasks modules declared as plugin entry points for entrypoint in pkg_resources.iter_entry_points("swh.workers"): worker_registrer_fn = entrypoint.load() # The registry function is expected to return a dict which the 'tasks' key # is a string (or a list of strings) with the name of the python module in # which celery tasks are defined. task_modules = worker_registrer_fn().get("task_modules", []) CONFIG["task_modules"].extend(task_modules) # Celery Queues CELERY_QUEUES = [Queue("celery", Exchange("celery"), routing_key="celery")] CELERY_DEFAULT_CONFIG = dict( # Timezone configuration: all in UTC enable_utc=True, timezone="UTC", # Imported modules imports=CONFIG.get("task_modules", []), # Time (in seconds, or a timedelta object) for when after stored task # tombstones will be deleted. None means to never expire results. result_expires=None, # A string identifying the default serialization method to use. Can # be json (default), pickle, yaml, msgpack, or any custom # serialization methods that have been registered with task_serializer="msgpack", # Result serialization format result_serializer="msgpack", - # Late ack means the task messages will be acknowledged after the task has - # been executed, not just before, which is the default behavior. - task_acks_late=True, + # Acknowledge tasks as soon as they're received. We can do this as we have + # external monitoring to decide if we need to retry tasks. + task_acks_late=False, # A string identifying the default serialization method to use. # Can be pickle (default), json, yaml, msgpack or any custom serialization # methods that have been registered with kombu.serialization.registry accept_content=["msgpack", "json"], # If True the task will report its status as “started” # when the task is executed by a worker. task_track_started=True, # Default compression used for task messages. Can be gzip, bzip2 # (if available), or any custom compression schemes registered # in the Kombu compression registry. # result_compression='bzip2', # task_compression='bzip2', # Disable all rate limits, even if tasks has explicit rate limits set. # (Disabling rate limits altogether is recommended if you don’t have any # tasks using them.) worker_disable_rate_limits=True, # Task routing task_routes=route_for_task, # Allow pool restarts from remote worker_pool_restarts=True, # Do not prefetch tasks worker_prefetch_multiplier=1, # Send events worker_send_task_events=True, # Do not send useless task_sent events task_send_sent_event=False, ) def build_app(config=None): config = merge_configs( {k: v for (k, (_, v)) in DEFAULT_CONFIG.items()}, config or {} ) config["task_queues"] = CELERY_QUEUES + [ Queue(queue, Exchange(queue), routing_key=queue) for queue in config.get("task_queues", ()) ] logger.debug("Creating a Celery app with %s", config) # Instantiate the Celery app app = Celery(broker=config["task_broker"], task_cls="swh.scheduler.task:SWHTask") app.add_defaults(CELERY_DEFAULT_CONFIG) app.add_defaults(config) return app app = build_app(CONFIG) # XXX for BW compat Celery.get_queue_length = get_queue_length diff --git a/swh/scheduler/cli/simulator.py b/swh/scheduler/cli/simulator.py index 02280b7..beb7a8c 100644 --- a/swh/scheduler/cli/simulator.py +++ b/swh/scheduler/cli/simulator.py @@ -1,75 +1,87 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import time import click from . import cli @cli.group("simulator") def simulator(): """Scheduler simulator.""" pass @simulator.command("fill-test-data") @click.option( "--num-origins", "-n", type=int, default=100000, help="Number of listed origins to add", ) @click.pass_context def fill_test_data_command(ctx, num_origins): """Fill the scheduler with test data for simulation purposes.""" from swh.scheduler.simulator import fill_test_data click.echo(f"Filling {num_origins:,} listed origins data...") start = time.monotonic() fill_test_data(ctx.obj["scheduler"], num_origins=num_origins) runtime = time.monotonic() - start click.echo(f"Completed in {runtime:.2f} seconds") @simulator.command("run") @click.option( "--scheduler", "-s", type=click.Choice(["task_scheduler", "origin_scheduler"]), default="origin_scheduler", help="Scheduler to simulate", ) @click.option( "--policy", "-p", - type=click.Choice(["oldest_scheduled_first"]), default="oldest_scheduled_first", help="Scheduling policy to simulate (only for origin_scheduler)", ) @click.option("--runtime", "-t", type=float, help="Simulated runtime") +@click.option( + "--plots/--no-plots", + "-P", + "showplots", + help="Show results as plots (with plotille)", +) +@click.option( + "--csv", "-o", "csvfile", type=click.File("w"), help="Export results in a CSV file" +) @click.pass_context -def run_command(ctx, scheduler, policy, runtime): +def run_command(ctx, scheduler, policy, runtime, showplots, csvfile): """Run the scheduler simulator. By default, the simulation runs forever. You can cap the simulated runtime with the --runtime option, and you can always press Ctrl+C to interrupt the running simulation. 'task_scheduler' is the "classic" task-based scheduler; 'origin_scheduler' is the new origin-visit-aware simulator. The latter uses --policy to decide which origins to schedule first based on information from listers. """ from swh.scheduler.simulator import run policy = policy if scheduler == "origin_scheduler" else None - run( + report = run( scheduler=ctx.obj["scheduler"], scheduler_type=scheduler, policy=policy, runtime=runtime, ) + + print(report.format(with_plots=showplots)) + if csvfile is not None: + report.metrics_csv(csvfile) diff --git a/swh/scheduler/interface.py b/swh/scheduler/interface.py index 94c1841..afbcc1c 100644 --- a/swh/scheduler/interface.py +++ b/swh/scheduler/interface.py @@ -1,400 +1,411 @@ # Copyright (C) 2015-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from uuid import UUID from typing_extensions import Protocol, runtime_checkable from swh.core.api import remote_api_endpoint from swh.core.api.classes import PagedResult from swh.scheduler.model import ListedOrigin, Lister, OriginVisitStats, SchedulerMetrics ListedOriginPageToken = Tuple[str, str] class PaginatedListedOriginList(PagedResult[ListedOrigin, ListedOriginPageToken]): """A list of listed origins, with a continuation token""" def __init__( self, results: List[ListedOrigin], next_page_token: Union[None, ListedOriginPageToken, List[str]], ): parsed_next_page_token: Optional[Tuple[str, str]] = None if next_page_token is not None: if len(next_page_token) != 2: raise TypeError("Expected Tuple[str, str] or list of size 2.") parsed_next_page_token = tuple(next_page_token) # type: ignore super().__init__(results, parsed_next_page_token) @runtime_checkable class SchedulerInterface(Protocol): @remote_api_endpoint("task_type/create") def create_task_type(self, task_type): """Create a new task type ready for scheduling. Args: task_type (dict): a dictionary with the following keys: - type (str): an identifier for the task type - description (str): a human-readable description of what the task does - backend_name (str): the name of the task in the job-scheduling backend - default_interval (datetime.timedelta): the default interval between two task runs - min_interval (datetime.timedelta): the minimum interval between two task runs - max_interval (datetime.timedelta): the maximum interval between two task runs - backoff_factor (float): the factor by which the interval changes at each run - max_queue_length (int): the maximum length of the task queue for this task type """ ... @remote_api_endpoint("task_type/get") def get_task_type(self, task_type_name): """Retrieve the task type with id task_type_name""" ... @remote_api_endpoint("task_type/get_all") def get_task_types(self): """Retrieve all registered task types""" ... @remote_api_endpoint("task/create") def create_tasks(self, tasks, policy="recurring"): """Create new tasks. Args: tasks (list): each task is a dictionary with the following keys: - type (str): the task type - arguments (dict): the arguments for the task runner, keys: - args (list of str): arguments - kwargs (dict str -> str): keyword arguments - next_run (datetime.datetime): the next scheduled run for the task Returns: a list of created tasks. """ ... @remote_api_endpoint("task/set_status") def set_status_tasks(self, task_ids, status="disabled", next_run=None): """Set the tasks' status whose ids are listed. If given, also set the next_run date. """ ... @remote_api_endpoint("task/disable") def disable_tasks(self, task_ids): """Disable the tasks whose ids are listed.""" ... @remote_api_endpoint("task/search") def search_tasks( self, task_id=None, task_type=None, status=None, priority=None, policy=None, before=None, after=None, limit=None, ): """Search tasks from selected criterions""" ... @remote_api_endpoint("task/get") def get_tasks(self, task_ids): """Retrieve the info of tasks whose ids are listed.""" ... @remote_api_endpoint("task/peek_ready") def peek_ready_tasks( self, task_type, timestamp=None, num_tasks=None, num_tasks_priority=None, ): """Fetch the list of ready tasks Args: task_type (str): filtering task per their type timestamp (datetime.datetime): peek tasks that need to be executed before that timestamp num_tasks (int): only peek at num_tasks tasks (with no priority) num_tasks_priority (int): only peek at num_tasks_priority tasks (with priority) Returns: a list of tasks """ ... @remote_api_endpoint("task/grab_ready") def grab_ready_tasks( self, task_type, timestamp=None, num_tasks=None, num_tasks_priority=None, ): """Fetch the list of ready tasks, and mark them as scheduled Args: task_type (str): filtering task per their type timestamp (datetime.datetime): grab tasks that need to be executed before that timestamp num_tasks (int): only grab num_tasks tasks (with no priority) num_tasks_priority (int): only grab oneshot num_tasks tasks (with priorities) Returns: a list of tasks """ ... @remote_api_endpoint("task_run/schedule_one") def schedule_task_run(self, task_id, backend_id, metadata=None, timestamp=None): """Mark a given task as scheduled, adding a task_run entry in the database. Args: task_id (int): the identifier for the task being scheduled backend_id (str): the identifier of the job in the backend metadata (dict): metadata to add to the task_run entry timestamp (datetime.datetime): the instant the event occurred Returns: a fresh task_run entry """ ... @remote_api_endpoint("task_run/schedule") def mass_schedule_task_runs(self, task_runs): """Schedule a bunch of task runs. Args: task_runs (list): a list of dicts with keys: - task (int): the identifier for the task being scheduled - backend_id (str): the identifier of the job in the backend - metadata (dict): metadata to add to the task_run entry - scheduled (datetime.datetime): the instant the event occurred Returns: None """ ... @remote_api_endpoint("task_run/start") def start_task_run(self, backend_id, metadata=None, timestamp=None): """Mark a given task as started, updating the corresponding task_run entry in the database. Args: backend_id (str): the identifier of the job in the backend metadata (dict): metadata to add to the task_run entry timestamp (datetime.datetime): the instant the event occurred Returns: the updated task_run entry """ ... @remote_api_endpoint("task_run/end") def end_task_run( self, backend_id, status, metadata=None, timestamp=None, result=None, ): """Mark a given task as ended, updating the corresponding task_run entry in the database. Args: backend_id (str): the identifier of the job in the backend status (str): how the task ended; one of: 'eventful', 'uneventful', 'failed' metadata (dict): metadata to add to the task_run entry timestamp (datetime.datetime): the instant the event occurred Returns: the updated task_run entry """ ... @remote_api_endpoint("task/filter_for_archive") def filter_task_to_archive( self, after_ts: str, before_ts: str, limit: int = 10, page_token: Optional[str] = None, ) -> Dict[str, Any]: """Compute the tasks to archive within the datetime interval [after_ts, before_ts[. The method returns a paginated result. Returns: dict with the following keys: - **next_page_token**: opaque token to be used as `page_token` to retrieve the next page of result. If absent, there is no more pages to gather. - **tasks**: list of task dictionaries with the following keys: **id** (str): origin task id **started** (Optional[datetime]): started date **scheduled** (datetime): scheduled date **arguments** (json dict): task's arguments ... """ ... @remote_api_endpoint("task/delete_archived") def delete_archived_tasks(self, task_ids): """Delete archived tasks as much as possible. Only the task_ids whose complete associated task_run have been cleaned up will be. """ ... @remote_api_endpoint("task_run/get") def get_task_runs(self, task_ids, limit=None): """Search task run for a task id""" ... @remote_api_endpoint("lister/get") def get_lister( self, name: str, instance_name: Optional[str] = None ) -> Optional[Lister]: """Retrieve information about the given instance of the lister from the database. """ ... @remote_api_endpoint("lister/get_or_create") def get_or_create_lister( self, name: str, instance_name: Optional[str] = None ) -> Lister: """Retrieve information about the given instance of the lister from the database, or create the entry if it did not exist. """ ... @remote_api_endpoint("lister/update") def update_lister(self, lister: Lister) -> Lister: """Update the state for the given lister instance in the database. Returns: a new Lister object, with all fields updated from the database Raises: StaleData if the `updated` timestamp for the lister instance in database doesn't match the one passed by the user. """ ... @remote_api_endpoint("origins/record") def record_listed_origins( self, listed_origins: Iterable[ListedOrigin] ) -> List[ListedOrigin]: """Record a set of origins that a lister has listed. This performs an "upsert": origins with the same (lister_id, url, visit_type) values are updated with new values for extra_loader_arguments, last_update and last_seen. """ ... @remote_api_endpoint("origins/get") def get_listed_origins( self, lister_id: Optional[UUID] = None, url: Optional[str] = None, limit: int = 1000, page_token: Optional[ListedOriginPageToken] = None, ) -> PaginatedListedOriginList: """Get information on the listed origins matching either the `url` or `lister_id`, or both arguments. Use the `limit` and `page_token` arguments for continuation. The next page token, if any, is returned in the PaginatedListedOriginList object. """ ... @remote_api_endpoint("origins/grab_next") def grab_next_visits( - self, visit_type: str, count: int, policy: str + self, + visit_type: str, + count: int, + policy: str, + timestamp: Optional[datetime.datetime] = None, ) -> List[ListedOrigin]: """Get at most the `count` next origins that need to be visited with the `visit_type` loader according to the given scheduling `policy`. - This will mark the origins as "being visited" in the listed_origins + This will mark the origins as scheduled in the origin_visit_stats table, to avoid scheduling multiple visits to the same origin. + + Arguments: + visit_type: type of visits to schedule + count: number of visits to schedule + policy: the scheduling policy used to select which visits to schedule + timestamp: the mocked timestamp at which we're recording that the visits are + being scheduled (defaults to the current time) """ ... @remote_api_endpoint("priority_ratios/get") def get_priority_ratios(self): ... @remote_api_endpoint("visit_stats/upsert") def origin_visit_stats_upsert( self, origin_visit_stats: Iterable[OriginVisitStats] ) -> None: """Create a new origin visit stats """ ... @remote_api_endpoint("visit_stats/get") def origin_visit_stats_get( self, ids: Iterable[Tuple[str, str]] ) -> List[OriginVisitStats]: """Retrieve the stats for an origin with a given visit type If some visit_stats are not found, they are filtered out of the result. So the output list may be of length inferior to the length of the input list. """ ... @remote_api_endpoint("scheduler_metrics/update") def update_metrics( self, lister_id: Optional[UUID] = None, timestamp: Optional[datetime.datetime] = None, ) -> List[SchedulerMetrics]: """Update the performance metrics of this scheduler instance. Returns the updated metrics. Args: lister_id: if passed, update the metrics only for this lister instance timestamp: if passed, the date at which we're updating the metrics, defaults to the database NOW() """ ... @remote_api_endpoint("scheduler_metrics/get") def get_metrics( self, lister_id: Optional[UUID] = None, visit_type: Optional[str] = None ) -> List[SchedulerMetrics]: """Retrieve the performance metrics of this scheduler instance. Args: lister_id: filter the metrics for this lister instance only visit_type: filter the metrics for this visit type only """ ... diff --git a/swh/scheduler/simulator/__init__.py b/swh/scheduler/simulator/__init__.py index 1cde89f..7de91c8 100644 --- a/swh/scheduler/simulator/__init__.py +++ b/swh/scheduler/simulator/__init__.py @@ -1,167 +1,163 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """This package runs the scheduler in a simulated environment, to evaluate various metrics. See :ref:`swh-scheduler-simulator`. This module orchestrates of the simulator by initializing processes and connecting them together; these processes are defined in modules in the package and simulate/call specific components.""" from datetime import datetime, timedelta, timezone import logging from typing import Dict, Generator, Optional from simpy import Event from swh.scheduler.interface import SchedulerInterface -from swh.scheduler.model import ListedOrigin from . import origin_scheduler, task_scheduler from .common import Environment, Queue, SimulationReport, Task -from .origins import load_task_process +from .origins import generate_listed_origin, lister_process, load_task_process logger = logging.getLogger(__name__) def update_metrics_process( env: Environment, update_interval: int ) -> Generator[Event, None, None]: """Update the scheduler metrics every `update_interval` (simulated) seconds, and add them to the SimulationReport """ t0 = env.time while True: metrics = env.scheduler.update_metrics(timestamp=env.time) env.report.record_metrics(env.time, metrics) dt = env.time - t0 logger.info("time:%s visits:%s", dt, env.report.total_visits) yield env.timeout(update_interval) def worker_process( env: Environment, name: str, task_queue: Queue, status_queue: Queue ) -> Generator[Event, Task, None]: """A worker which consumes tasks from the input task_queue. Tasks themselves send OriginVisitStatus objects to the status_queue.""" logger.debug("%s worker %s: Start", env.time, name) while True: task = yield task_queue.get() logger.debug( "%s worker %s: Run task %s origin=%s", env.time, name, task.visit_type, task.origin, ) yield env.process(load_task_process(env, task, status_queue=status_queue)) def setup( env: Environment, scheduler_type: str, policy: Optional[str], workers_per_type: Dict[str, int], task_queue_capacity: int, min_batch_size: int, metrics_update_interval: int, ): task_queues = { visit_type: Queue(env, capacity=task_queue_capacity) for visit_type in workers_per_type } status_queue = Queue(env) if scheduler_type == "origin_scheduler": if policy is None: raise ValueError("origin_scheduler needs a scheduling policy") env.process( origin_scheduler.scheduler_runner_process( env, task_queues, policy, min_batch_size=min_batch_size ) ) env.process( origin_scheduler.scheduler_journal_client_process(env, status_queue) ) elif scheduler_type == "task_scheduler": if policy is not None: raise ValueError("task_scheduler doesn't support a scheduling policy") env.process( task_scheduler.scheduler_runner_process( env, task_queues, min_batch_size=min_batch_size ) ) env.process(task_scheduler.scheduler_listener_process(env, status_queue)) else: raise ValueError(f"Unknown scheduler type to simulate: {scheduler_type}") env.process(update_metrics_process(env, metrics_update_interval)) for visit_type, num_workers in workers_per_type.items(): task_queue = task_queues[visit_type] for i in range(num_workers): worker_name = f"worker-{visit_type}-{i}" env.process(worker_process(env, worker_name, task_queue, status_queue)) + lister = env.scheduler.get_or_create_lister(name="example") + assert lister.id + env.process(lister_process(env, lister.id)) + def fill_test_data(scheduler: SchedulerInterface, num_origins: int = 100000): """Fills the database with mock data to test the simulator.""" stored_lister = scheduler.get_or_create_lister(name="example") assert stored_lister.id is not None - origins = [ - ListedOrigin( - lister_id=stored_lister.id, - url=f"https://example.com/{i:04d}.git", - visit_type="git", - last_update=datetime(2020, 6, 15, 16, 0, 0, i, tzinfo=timezone.utc), - ) - for i in range(num_origins) - ] + # Generate 'num_origins' new origins + origins = [generate_listed_origin(stored_lister.id) for _ in range(num_origins)] scheduler.record_listed_origins(origins) scheduler.create_tasks( [ { **origin.as_task_dict(), "policy": "recurring", "next_run": origin.last_update, "interval": timedelta(days=64), } for origin in origins ] ) def run( scheduler: SchedulerInterface, scheduler_type: str, policy: Optional[str], runtime: Optional[int], ): NUM_WORKERS = 48 start_time = datetime.now(tz=timezone.utc) env = Environment(start_time=start_time) env.scheduler = scheduler env.report = SimulationReport() setup( env, scheduler_type=scheduler_type, policy=policy, workers_per_type={"git": NUM_WORKERS}, task_queue_capacity=10000, min_batch_size=1000, metrics_update_interval=3600, ) try: env.run(until=runtime) except KeyboardInterrupt: pass finally: end_time = env.time print("total simulated time:", end_time - start_time) metrics = env.scheduler.update_metrics(timestamp=end_time) env.report.record_metrics(end_time, metrics) - print(env.report.format()) + return env.report diff --git a/swh/scheduler/simulator/common.py b/swh/scheduler/simulator/common.py index 7a26341..c421d36 100644 --- a/swh/scheduler/simulator/common.py +++ b/swh/scheduler/simulator/common.py @@ -1,132 +1,191 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import csv from dataclasses import dataclass, field from datetime import datetime, timedelta import textwrap -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, TextIO, Tuple import uuid import plotille from simpy import Environment as _Environment from simpy import Store from swh.model.model import OriginVisitStatus from swh.scheduler.interface import SchedulerInterface from swh.scheduler.model import SchedulerMetrics @dataclass class SimulationReport: DURATION_THRESHOLD = 3600 """Max duration for histograms""" total_visits: int = 0 """Total count of finished visits""" visit_runtimes: Dict[Tuple[str, bool], List[float]] = field(default_factory=dict) """Collected visit runtimes for each (status, eventful) tuple""" - metrics: List[Tuple[datetime, List[SchedulerMetrics]]] = field(default_factory=list) - """Collected scheduler metrics for every timestamp""" + scheduler_metrics: List[Tuple[datetime, List[SchedulerMetrics]]] = field( + default_factory=list + ) + """Collected scheduler metrics + + This is a list of couples (timestamp, [SchedulerMetrics,]): the list of + scheduler metrics collected at given timestamp. + """ + + visit_metrics: List[Tuple[datetime, int]] = field(default_factory=list) + """Collected visit metrics over time""" + + latest_snapshots: Dict[Tuple[str, str], bytes] = field(default_factory=dict) + """Collected latest snapshots for origins""" + + def record_visit( + self, + origin: Tuple[str, str], + duration: float, + status: str, + snapshot=Optional[bytes], + ) -> None: + eventful = False + if status == "full": + eventful = snapshot != self.latest_snapshots.get(origin) + self.latest_snapshots[origin] = snapshot - def record_visit(self, duration: float, eventful: bool, status: str) -> None: self.total_visits += 1 self.visit_runtimes.setdefault((status, eventful), []).append(duration) - def record_metrics(self, timestamp: datetime, metrics: List[SchedulerMetrics]): - self.metrics.append((timestamp, metrics)) + def record_metrics( + self, timestamp: datetime, scheduler_metrics: List[SchedulerMetrics] + ): + self.scheduler_metrics.append((timestamp, scheduler_metrics)) + self.visit_metrics.append((timestamp, self.total_visits)) @property - def useless_visits(self): + def uneventful_visits(self): """Number of uneventful, full visits""" return len(self.visit_runtimes.get(("full", False), [])) def runtime_histogram(self, status: str, eventful: bool) -> str: runtimes = self.visit_runtimes.get((status, eventful), []) return plotille.hist( [runtime for runtime in runtimes if runtime <= self.DURATION_THRESHOLD] ) def metrics_plot(self) -> str: - timestamps, metric_lists = zip(*self.metrics) + timestamps, metric_lists = zip(*self.scheduler_metrics) known = [sum(m.origins_known for m in metrics) for metrics in metric_lists] never_visited = [ sum(m.origins_never_visited for m in metrics) for metrics in metric_lists ] figure = plotille.Figure() figure.x_label = "simulated time" figure.y_label = "origins" figure.scatter(timestamps, known, label="Known origins") figure.scatter(timestamps, never_visited, label="Origins never visited") + visit_timestamps, n_visits = zip(*self.visit_metrics) + figure.scatter(visit_timestamps, n_visits, label="Visits over time") + return figure.show(legend=True) - def format(self): + def metrics_csv(self, fobj: TextIO) -> None: + """Export scheduling metrics in a csv file""" + csv_writer = csv.writer(fobj) + csv_writer.writerow( + [ + "timestamp", + "known_origins", + "enabled_origins", + "never_visited_origins", + "origins_with_pending_changes", + ] + ) + + timestamps, metric_lists = zip(*self.scheduler_metrics) + known = (sum(m.origins_known for m in metrics) for metrics in metric_lists) + enabled = (sum(m.origins_enabled for m in metrics) for metrics in metric_lists) + never_visited = ( + sum(m.origins_never_visited for m in metrics) for metrics in metric_lists + ) + pending_changes = ( + sum(m.origins_with_pending_changes for m in metrics) + for metrics in metric_lists + ) + csv_writer.writerows( + zip(timestamps, known, enabled, never_visited, pending_changes) + ) + + def format(self, with_plots=True): full_visits = self.visit_runtimes.get(("full", True), []) - histogram = self.runtime_histogram("full", True) - plot = self.metrics_plot() long_tasks = sum(runtime > self.DURATION_THRESHOLD for runtime in full_visits) - return ( - textwrap.dedent( - f"""\ + output = textwrap.dedent( + f"""\ Total visits: {self.total_visits} - Useless visits: {self.useless_visits} + Uneventful visits: {self.uneventful_visits} Eventful visits: {len(full_visits)} Very long running tasks: {long_tasks} - Visit time histogram for eventful visits: - """ - ) - + histogram - + "\n" - + textwrap.dedent( - """\ - Metrics over time: """ - ) - + plot ) + if with_plots: + histogram = self.runtime_histogram("full", True) + plot = self.metrics_plot() + output += ( + "Visit time histogram for eventful visits:" + + histogram + + "\n" + + textwrap.dedent( + """\ + Metrics over time: + """ + ) + + plot + ) + return output @dataclass class Task: visit_type: str origin: str backend_id: uuid.UUID = field(default_factory=uuid.uuid4) @dataclass class TaskEvent: task: Task status: OriginVisitStatus eventful: bool = field(default=False) class Queue(Store): """Model a queue of objects to be passed between processes.""" def __len__(self): return len(self.items or []) def slots_remaining(self): return self.capacity - len(self) class Environment(_Environment): report: SimulationReport scheduler: SchedulerInterface def __init__(self, start_time: datetime): if start_time.tzinfo is None: raise ValueError("start_time must have timezone information") self.start_time = start_time super().__init__() @property def time(self): """Get the current simulated wall clock time""" return self.start_time + timedelta(seconds=self.now) diff --git a/swh/scheduler/simulator/origin_scheduler.py b/swh/scheduler/simulator/origin_scheduler.py index 3b9d59a..ca16912 100644 --- a/swh/scheduler/simulator/origin_scheduler.py +++ b/swh/scheduler/simulator/origin_scheduler.py @@ -1,68 +1,68 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """Agents using the new origin-aware scheduler.""" import logging from typing import Any, Dict, Generator, Iterator, List from simpy import Event from swh.scheduler.journal_client import process_journal_objects from .common import Environment, Queue, Task, TaskEvent logger = logging.getLogger(__name__) def scheduler_runner_process( env: Environment, task_queues: Dict[str, Queue], policy: str, min_batch_size: int ) -> Iterator[Event]: """Scheduler runner. Grabs next visits from the database according to the scheduling policy, and fills the task_queues accordingly.""" while True: for visit_type, queue in task_queues.items(): remaining = queue.slots_remaining() if remaining < min_batch_size: continue next_origins = env.scheduler.grab_next_visits( - visit_type, remaining, policy=policy + visit_type, remaining, policy=policy, timestamp=env.time ) logger.debug( "%s runner: running %s %s tasks", env.time, visit_type, len(next_origins), ) for origin in next_origins: yield queue.put(Task(visit_type=origin.visit_type, origin=origin.url)) yield env.timeout(10.0) def scheduler_journal_client_process( env: Environment, status_queue: Queue ) -> Generator[Event, TaskEvent, None]: """Scheduler journal client. Every once in a while, pulls `OriginVisitStatus`es from the status_queue to update the scheduler origin_visit_stats table.""" BATCH_SIZE = 100 statuses: List[Dict[str, Any]] = [] while True: task_event = yield status_queue.get() statuses.append(task_event.status.to_dict()) if len(statuses) < BATCH_SIZE: continue logger.debug( "%s journal client: processing %s statuses", env.time, len(statuses) ) process_journal_objects( {"origin_visit_status": statuses}, scheduler=env.scheduler ) statuses = [] diff --git a/swh/scheduler/simulator/origins.py b/swh/scheduler/simulator/origins.py index 3320790..a0ec081 100644 --- a/swh/scheduler/simulator/origins.py +++ b/swh/scheduler/simulator/origins.py @@ -1,130 +1,227 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """This module implements a model of the frequency of updates of an origin -and how long it takes to load it.""" +and how long it takes to load it. -from datetime import timedelta +For each origin, a commit frequency is chosen deterministically based on the +hash of its URL and assume all origins were created on an arbitrary epoch. +From this we compute a number of commits, that is the product of these two. + +And the run time of a load task is approximated as proportional to the number +of commits since the previous visit of the origin (possibly 0).""" + +from datetime import datetime, timedelta, timezone import hashlib import logging -import os -from typing import Iterator, Optional, Tuple +from typing import Dict, Generator, Iterator, List, Optional, Tuple +import uuid import attr from simpy import Event from swh.model.model import OriginVisitStatus -from swh.scheduler.model import OriginVisitStats +from swh.scheduler.model import ListedOrigin from .common import Environment, Queue, Task, TaskEvent logger = logging.getLogger(__name__) +_nb_generated_origins = 0 +_visit_times: Dict[Tuple[str, str], datetime] = {} +"""Cache of the time of the last visit of (visit_type, origin_url), +to spare an SQL query (high latency).""" + + +def generate_listed_origin( + lister_id: uuid.UUID, now: Optional[datetime] = None +) -> ListedOrigin: + """Returns a globally unique new origin. Seed the `last_update` value + according to the OriginModel and the passed timestamp. + + Arguments: + lister: instance of the lister that generated this origin + now: time of listing, to emulate last_update (defaults to :func:`datetime.now`) + """ + global _nb_generated_origins + _nb_generated_origins += 1 + assert _nb_generated_origins < 10 ** 6, "Too many origins!" + + if now is None: + now = datetime.now(tz=timezone.utc) + + url = f"https://example.com/{_nb_generated_origins:06d}.git" + visit_type = "git" + origin = OriginModel(visit_type, url) + + return ListedOrigin( + lister_id=lister_id, + url=url, + visit_type=visit_type, + last_update=origin.get_last_update(now), + ) + + class OriginModel: MIN_RUN_TIME = 0.5 """Minimal run time for a visit (retrieved from production data)""" MAX_RUN_TIME = 7200 """Max run time for a visit""" PER_COMMIT_RUN_TIME = 0.1 """Run time per commit""" + EPOCH = datetime(2015, 9, 1, 0, 0, 0, tzinfo=timezone.utc) + """The origin of all origins (at least according to Software Heritage)""" + def __init__(self, type: str, origin: str): self.type = type self.origin = origin def seconds_between_commits(self): """Returns a random 'average time between two commits' of this origin, used to estimate the run time of a load task, and how much the loading architecture is lagging behind origin updates.""" n_bytes = 2 num_buckets = 2 ** (8 * n_bytes) # Deterministic seed to generate "random" characteristics of this origin bucket = int.from_bytes( hashlib.md5(self.origin.encode()).digest()[0:n_bytes], "little" ) # minimum: 1 second (bucket == 0) # max: 10 years (bucket == num_buckets - 1) ten_y = 10 * 365 * 24 * 3600 return ten_y ** (bucket / num_buckets) # return 1 + (ten_y - 1) * (bucket / (num_buckets - 1)) + def get_last_update(self, now: datetime) -> datetime: + """Get the last_update value for this origin. + + We assume that the origin had its first commit at `EPOCH`, and that one + commit happened every `self.seconds_between_commits()`. This returns + the last commit date before or equal to `now`. + """ + _, time_since_last_commit = divmod( + (now - self.EPOCH).total_seconds(), self.seconds_between_commits() + ) + + return now - timedelta(seconds=time_since_last_commit) + + def get_current_snapshot_id(self, now: datetime) -> bytes: + """Get the current snapshot for this origin. + + To generate a snapshot id, we calculate the number of commits since the + EPOCH, and hash it alongside the origin type and url. + """ + commits_since_epoch, _ = divmod( + (now - self.EPOCH).total_seconds(), self.seconds_between_commits() + ) + return hashlib.sha1( + f"{self.type} {self.origin} {commits_since_epoch}".encode() + ).digest() + def load_task_characteristics( - self, env: Environment, stats: Optional[OriginVisitStats] - ) -> Tuple[float, bool, str]: - """Returns the (run_time, eventfulness, end_status) of the next + self, now: datetime + ) -> Tuple[float, str, Optional[bytes]]: + """Returns the (run_time, end_status, snapshot id) of the next origin visit.""" - if stats and stats.last_eventful: - time_since_last_successful_run = env.time - stats.last_eventful - else: - time_since_last_successful_run = timedelta(days=365) + current_snapshot = self.get_current_snapshot_id(now) + + key = (self.type, self.origin) + last_visit = _visit_times.get(key, now - timedelta(days=365)) + time_since_last_successful_run = now - last_visit + + _visit_times[key] = now seconds_between_commits = self.seconds_between_commits() + seconds_since_last_successful = time_since_last_successful_run.total_seconds() + n_commits = int(seconds_since_last_successful / seconds_between_commits) logger.debug( - "Interval between commits: %s", timedelta(seconds=seconds_between_commits) + "%s characteristics %s origin=%s: Interval: %s, n_commits: %s", + now, + self.type, + self.origin, + timedelta(seconds=seconds_between_commits), + n_commits, ) - seconds_since_last_successful = time_since_last_successful_run.total_seconds() - if seconds_since_last_successful < seconds_between_commits: - # No commits since last visit => uneventful - return (self.MIN_RUN_TIME, False, "full") + run_time = self.MIN_RUN_TIME + self.PER_COMMIT_RUN_TIME * n_commits + if run_time > self.MAX_RUN_TIME: + # Long visits usually fail + return (self.MAX_RUN_TIME, "partial", None) else: - n_commits = seconds_since_last_successful / seconds_between_commits - run_time = self.MIN_RUN_TIME + self.PER_COMMIT_RUN_TIME * n_commits - if run_time > self.MAX_RUN_TIME: - return (self.MAX_RUN_TIME, False, "partial") - else: - return (run_time, True, "full") + return (run_time, "full", current_snapshot) + + +def lister_process( + env: Environment, lister_id: uuid.UUID +) -> Generator[Event, Event, None]: + """Every hour, generate new origins and update the `last_update` field for + the ones this process generated in the past""" + NUM_NEW_ORIGINS = 100 + + origins: List[ListedOrigin] = [] + + while True: + updated_origins = [] + for origin in origins: + model = OriginModel(origin.visit_type, origin.url) + updated_origins.append( + attr.evolve(origin, last_update=model.get_last_update(env.time)) + ) + + origins = updated_origins + origins.extend( + generate_listed_origin(lister_id, now=env.time) + for _ in range(NUM_NEW_ORIGINS) + ) + + env.scheduler.record_listed_origins(origins) + + yield env.timeout(3600) def load_task_process( env: Environment, task: Task, status_queue: Queue ) -> Iterator[Event]: """A loading task. This pushes OriginVisitStatus objects to the status_queue to simulate the visible outcomes of the task. Uses the `load_task_duration` function to determine its run time. """ - # This is cheating; actual tasks access the state from the storage, not the - # scheduler - pk = task.origin, task.visit_type - visit_stats = env.scheduler.origin_visit_stats_get([pk]) - stats: Optional[OriginVisitStats] = visit_stats[0] if len(visit_stats) > 0 else None - last_snapshot = stats.last_snapshot if stats else None - status = OriginVisitStatus( origin=task.origin, visit=42, type=task.visit_type, status="created", date=env.time, snapshot=None, ) - origin_model = OriginModel(task.visit_type, task.origin) - (run_time, eventful, end_status) = origin_model.load_task_characteristics( - env, stats - ) logger.debug("%s task %s origin=%s: Start", env.time, task.visit_type, task.origin) yield status_queue.put(TaskEvent(task=task, status=status)) + + origin_model = OriginModel(task.visit_type, task.origin) + (run_time, end_status, snapshot) = origin_model.load_task_characteristics(env.time) yield env.timeout(run_time) + logger.debug("%s task %s origin=%s: End", env.time, task.visit_type, task.origin) - new_snapshot = os.urandom(20) if eventful else last_snapshot yield status_queue.put( TaskEvent( task=task, status=attr.evolve( - status, status=end_status, date=env.time, snapshot=new_snapshot + status, status=end_status, date=env.time, snapshot=snapshot ), - eventful=eventful, ) ) - env.report.record_visit(run_time, eventful, end_status) + env.report.record_visit( + (task.visit_type, task.origin), run_time, end_status, snapshot + ) diff --git a/swh/scheduler/tests/test_scheduler.py b/swh/scheduler/tests/test_scheduler.py index 8a03624..c7c27c5 100644 --- a/swh/scheduler/tests/test_scheduler.py +++ b/swh/scheduler/tests/test_scheduler.py @@ -1,1216 +1,1235 @@ # Copyright (C) 2017-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import defaultdict import copy import datetime import inspect import random from typing import Any, Dict, List, Optional, Tuple import uuid import attr import pytest from swh.model.hashutil import hash_to_bytes from swh.scheduler.exc import SchedulerException, StaleData, UnknownPolicy from swh.scheduler.interface import ListedOriginPageToken, SchedulerInterface from swh.scheduler.model import ListedOrigin, OriginVisitStats, SchedulerMetrics from swh.scheduler.utils import utcnow from .common import LISTERS, TASK_TYPES, TEMPLATES, tasks_from_template ONEDAY = datetime.timedelta(days=1) def subdict(d, keys=None, excl=()): if keys is None: keys = [k for k in d.keys()] return {k: d[k] for k in keys if k not in excl} def metrics_sort_key(m: SchedulerMetrics) -> Tuple[uuid.UUID, str]: return (m.lister_id, m.visit_type) def assert_metrics_equal(left, right): assert sorted(left, key=metrics_sort_key) == sorted(right, key=metrics_sort_key) class TestScheduler: def test_interface(self, swh_scheduler): """Checks all methods of SchedulerInterface are implemented by this backend, and that they have the same signature.""" # Create an instance of the protocol (which cannot be instantiated # directly, so this creates a subclass, then instantiates it) interface = type("_", (SchedulerInterface,), {})() assert "create_task_type" in dir(interface) missing_methods = [] for meth_name in dir(interface): if meth_name.startswith("_"): continue interface_meth = getattr(interface, meth_name) try: concrete_meth = getattr(swh_scheduler, meth_name) except AttributeError: if not getattr(interface_meth, "deprecated_endpoint", False): # The backend is missing a (non-deprecated) endpoint missing_methods.append(meth_name) continue expected_signature = inspect.signature(interface_meth) actual_signature = inspect.signature(concrete_meth) assert expected_signature == actual_signature, meth_name assert missing_methods == [] def test_get_priority_ratios(self, swh_scheduler): assert swh_scheduler.get_priority_ratios() == { "high": 0.5, "normal": 0.3, "low": 0.2, } def test_add_task_type(self, swh_scheduler): tt = TASK_TYPES["git"] swh_scheduler.create_task_type(tt) assert tt == swh_scheduler.get_task_type(tt["type"]) tt2 = TASK_TYPES["hg"] swh_scheduler.create_task_type(tt2) assert tt == swh_scheduler.get_task_type(tt["type"]) assert tt2 == swh_scheduler.get_task_type(tt2["type"]) def test_create_task_type_idempotence(self, swh_scheduler): tt = TASK_TYPES["git"] swh_scheduler.create_task_type(tt) swh_scheduler.create_task_type(tt) assert tt == swh_scheduler.get_task_type(tt["type"]) def test_get_task_types(self, swh_scheduler): tt, tt2 = TASK_TYPES["git"], TASK_TYPES["hg"] swh_scheduler.create_task_type(tt) swh_scheduler.create_task_type(tt2) actual_task_types = swh_scheduler.get_task_types() assert tt in actual_task_types assert tt2 in actual_task_types def test_create_tasks(self, swh_scheduler): priority_ratio = self._priority_ratio(swh_scheduler) self._create_task_types(swh_scheduler) num_tasks_priority = 100 tasks_1 = tasks_from_template(TEMPLATES["git"], utcnow(), 100) tasks_2 = tasks_from_template( TEMPLATES["hg"], utcnow(), 100, num_tasks_priority, priorities=priority_ratio, ) tasks = tasks_1 + tasks_2 # tasks are returned only once with their ids ret1 = swh_scheduler.create_tasks(tasks + tasks_1 + tasks_2) set_ret1 = set([t["id"] for t in ret1]) # creating the same set result in the same ids ret = swh_scheduler.create_tasks(tasks) set_ret = set([t["id"] for t in ret]) # Idempotence results assert set_ret == set_ret1 assert len(ret) == len(ret1) ids = set() actual_priorities = defaultdict(int) for task, orig_task in zip(ret, tasks): task = copy.deepcopy(task) task_type = TASK_TYPES[orig_task["type"].split("-")[-1]] assert task["id"] not in ids assert task["status"] == "next_run_not_scheduled" assert task["current_interval"] == task_type["default_interval"] assert task["policy"] == orig_task.get("policy", "recurring") priority = task.get("priority") if priority: actual_priorities[priority] += 1 assert task["retries_left"] == (task_type["num_retries"] or 0) ids.add(task["id"]) del task["id"] del task["status"] del task["current_interval"] del task["retries_left"] if "policy" not in orig_task: del task["policy"] if "priority" not in orig_task: del task["priority"] assert task == orig_task assert dict(actual_priorities) == { priority: int(ratio * num_tasks_priority) for priority, ratio in priority_ratio.items() } def test_peek_ready_tasks_no_priority(self, swh_scheduler): self._create_task_types(swh_scheduler) t = utcnow() task_type = TEMPLATES["git"]["type"] tasks = tasks_from_template(TEMPLATES["git"], t, 100) random.shuffle(tasks) swh_scheduler.create_tasks(tasks) ready_tasks = swh_scheduler.peek_ready_tasks(task_type) assert len(ready_tasks) == len(tasks) for i in range(len(ready_tasks) - 1): assert ready_tasks[i]["next_run"] <= ready_tasks[i + 1]["next_run"] # Only get the first few ready tasks limit = random.randrange(5, 5 + len(tasks) // 2) ready_tasks_limited = swh_scheduler.peek_ready_tasks(task_type, num_tasks=limit) assert len(ready_tasks_limited) == limit assert ready_tasks_limited == ready_tasks[:limit] # Limit by timestamp max_ts = tasks[limit - 1]["next_run"] ready_tasks_timestamped = swh_scheduler.peek_ready_tasks( task_type, timestamp=max_ts ) for ready_task in ready_tasks_timestamped: assert ready_task["next_run"] <= max_ts # Make sure we get proper behavior for the first ready tasks assert ready_tasks[: len(ready_tasks_timestamped)] == ready_tasks_timestamped # Limit by both ready_tasks_both = swh_scheduler.peek_ready_tasks( task_type, timestamp=max_ts, num_tasks=limit // 3 ) assert len(ready_tasks_both) <= limit // 3 for ready_task in ready_tasks_both: assert ready_task["next_run"] <= max_ts assert ready_task in ready_tasks[: limit // 3] def _priority_ratio(self, swh_scheduler): return swh_scheduler.get_priority_ratios() def test_peek_ready_tasks_mixed_priorities(self, swh_scheduler): priority_ratio = self._priority_ratio(swh_scheduler) self._create_task_types(swh_scheduler) t = utcnow() task_type = TEMPLATES["git"]["type"] num_tasks_priority = 100 num_tasks_no_priority = 100 # Create tasks with and without priorities tasks = tasks_from_template( TEMPLATES["git"], t, num=num_tasks_no_priority, num_priority=num_tasks_priority, priorities=priority_ratio, ) random.shuffle(tasks) swh_scheduler.create_tasks(tasks) # take all available tasks ready_tasks = swh_scheduler.peek_ready_tasks(task_type) assert len(ready_tasks) == len(tasks) assert num_tasks_priority + num_tasks_no_priority == len(ready_tasks) count_tasks_per_priority = defaultdict(int) for task in ready_tasks: priority = task.get("priority") if priority: count_tasks_per_priority[priority] += 1 assert dict(count_tasks_per_priority) == { priority: int(ratio * num_tasks_priority) for priority, ratio in priority_ratio.items() } # Only get some ready tasks num_tasks = random.randrange(5, 5 + num_tasks_no_priority // 2) num_tasks_priority = random.randrange(5, num_tasks_priority // 2) ready_tasks_limited = swh_scheduler.peek_ready_tasks( task_type, num_tasks=num_tasks, num_tasks_priority=num_tasks_priority ) count_tasks_per_priority = defaultdict(int) for task in ready_tasks_limited: priority = task.get("priority") count_tasks_per_priority[priority] += 1 import math for priority, ratio in priority_ratio.items(): expected_count = math.ceil(ratio * num_tasks_priority) actual_prio = count_tasks_per_priority[priority] assert actual_prio == expected_count or actual_prio == expected_count + 1 assert count_tasks_per_priority[None] == num_tasks def test_grab_ready_tasks(self, swh_scheduler): priority_ratio = self._priority_ratio(swh_scheduler) self._create_task_types(swh_scheduler) t = utcnow() task_type = TEMPLATES["git"]["type"] num_tasks_priority = 100 num_tasks_no_priority = 100 # Create tasks with and without priorities tasks = tasks_from_template( TEMPLATES["git"], t, num=num_tasks_no_priority, num_priority=num_tasks_priority, priorities=priority_ratio, ) random.shuffle(tasks) swh_scheduler.create_tasks(tasks) first_ready_tasks = swh_scheduler.peek_ready_tasks( task_type, num_tasks=10, num_tasks_priority=10 ) grabbed_tasks = swh_scheduler.grab_ready_tasks( task_type, num_tasks=10, num_tasks_priority=10 ) for peeked, grabbed in zip(first_ready_tasks, grabbed_tasks): assert peeked["status"] == "next_run_not_scheduled" del peeked["status"] assert grabbed["status"] == "next_run_scheduled" del grabbed["status"] assert peeked == grabbed assert peeked["priority"] == grabbed["priority"] def test_get_tasks(self, swh_scheduler): self._create_task_types(swh_scheduler) t = utcnow() tasks = tasks_from_template(TEMPLATES["git"], t, 100) tasks = swh_scheduler.create_tasks(tasks) random.shuffle(tasks) while len(tasks) > 1: length = random.randrange(1, len(tasks)) cur_tasks = sorted(tasks[:length], key=lambda x: x["id"]) tasks[:length] = [] ret = swh_scheduler.get_tasks(task["id"] for task in cur_tasks) # result is not guaranteed to be sorted ret.sort(key=lambda x: x["id"]) assert ret == cur_tasks def test_search_tasks(self, swh_scheduler): def make_real_dicts(lst): """RealDictRow is not a real dict.""" return [dict(d.items()) for d in lst] self._create_task_types(swh_scheduler) t = utcnow() tasks = tasks_from_template(TEMPLATES["git"], t, 100) tasks = swh_scheduler.create_tasks(tasks) assert make_real_dicts(swh_scheduler.search_tasks()) == make_real_dicts(tasks) def assert_filtered_task_ok( self, task: Dict[str, Any], after: datetime.datetime, before: datetime.datetime ) -> None: """Ensure filtered tasks have the right expected properties (within the range, recurring disabled, etc..) """ started = task["started"] date = started if started is not None else task["scheduled"] assert after <= date and date <= before if task["task_policy"] == "oneshot": assert task["task_status"] in ["completed", "disabled"] if task["task_policy"] == "recurring": assert task["task_status"] in ["disabled"] def test_filter_task_to_archive(self, swh_scheduler): """Filtering only list disabled recurring or completed oneshot tasks """ self._create_task_types(swh_scheduler) _time = utcnow() recurring = tasks_from_template(TEMPLATES["git"], _time, 12) oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12) total_tasks = len(recurring) + len(oneshots) # simulate scheduling tasks pending_tasks = swh_scheduler.create_tasks(recurring + oneshots) backend_tasks = [ { "task": task["id"], "backend_id": str(uuid.uuid4()), "scheduled": utcnow(), } for task in pending_tasks ] swh_scheduler.mass_schedule_task_runs(backend_tasks) # we simulate the task are being done _tasks = [] for task in backend_tasks: t = swh_scheduler.end_task_run(task["backend_id"], status="eventful") _tasks.append(t) # Randomly update task's status per policy status_per_policy = {"recurring": 0, "oneshot": 0} status_choice = { # policy: [tuple (1-for-filtering, 'associated-status')] "recurring": [ (1, "disabled"), (0, "completed"), (0, "next_run_not_scheduled"), ], "oneshot": [ (0, "next_run_not_scheduled"), (1, "disabled"), (1, "completed"), ], } tasks_to_update = defaultdict(list) _task_ids = defaultdict(list) # randomize 'disabling' recurring task or 'complete' oneshot task for task in pending_tasks: policy = task["policy"] _task_ids[policy].append(task["id"]) status = random.choice(status_choice[policy]) if status[0] != 1: continue # elected for filtering status_per_policy[policy] += status[0] tasks_to_update[policy].append(task["id"]) swh_scheduler.disable_tasks(tasks_to_update["recurring"]) # hack: change the status to something else than completed/disabled swh_scheduler.set_status_tasks( _task_ids["oneshot"], status="next_run_not_scheduled" ) # complete the tasks to update swh_scheduler.set_status_tasks(tasks_to_update["oneshot"], status="completed") total_tasks_filtered = ( status_per_policy["recurring"] + status_per_policy["oneshot"] ) # no pagination scenario # retrieve tasks to archive after = _time - ONEDAY after_ts = after.strftime("%Y-%m-%d") before = utcnow() + ONEDAY before_ts = before.strftime("%Y-%m-%d") tasks_result = swh_scheduler.filter_task_to_archive( after_ts=after_ts, before_ts=before_ts, limit=total_tasks ) tasks_to_archive = tasks_result["tasks"] assert len(tasks_to_archive) == total_tasks_filtered assert tasks_result.get("next_page_token") is None actual_filtered_per_status = {"recurring": 0, "oneshot": 0} for task in tasks_to_archive: self.assert_filtered_task_ok(task, after, before) actual_filtered_per_status[task["task_policy"]] += 1 assert actual_filtered_per_status == status_per_policy # pagination scenario nb_tasks = 3 tasks_result = swh_scheduler.filter_task_to_archive( after_ts=after_ts, before_ts=before_ts, limit=nb_tasks ) tasks_to_archive2 = tasks_result["tasks"] assert len(tasks_to_archive2) == nb_tasks next_page_token = tasks_result["next_page_token"] assert next_page_token is not None all_tasks = tasks_to_archive2 while next_page_token is not None: # Retrieve paginated results tasks_result = swh_scheduler.filter_task_to_archive( after_ts=after_ts, before_ts=before_ts, limit=nb_tasks, page_token=next_page_token, ) tasks_to_archive2 = tasks_result["tasks"] assert len(tasks_to_archive2) <= nb_tasks all_tasks.extend(tasks_to_archive2) next_page_token = tasks_result.get("next_page_token") actual_filtered_per_status = {"recurring": 0, "oneshot": 0} for task in all_tasks: self.assert_filtered_task_ok(task, after, before) actual_filtered_per_status[task["task_policy"]] += 1 assert actual_filtered_per_status == status_per_policy def test_delete_archived_tasks(self, swh_scheduler): self._create_task_types(swh_scheduler) _time = utcnow() recurring = tasks_from_template(TEMPLATES["git"], _time, 12) oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12) total_tasks = len(recurring) + len(oneshots) pending_tasks = swh_scheduler.create_tasks(recurring + oneshots) backend_tasks = [ { "task": task["id"], "backend_id": str(uuid.uuid4()), "scheduled": utcnow(), } for task in pending_tasks ] swh_scheduler.mass_schedule_task_runs(backend_tasks) _tasks = [] percent = random.randint(0, 100) # random election removal boundary for task in backend_tasks: t = swh_scheduler.end_task_run(task["backend_id"], status="eventful") c = random.randint(0, 100) if c <= percent: _tasks.append({"task_id": t["task"], "task_run_id": t["id"]}) swh_scheduler.delete_archived_tasks(_tasks) all_tasks = [task["id"] for task in swh_scheduler.search_tasks()] tasks_count = len(all_tasks) tasks_run_count = len(swh_scheduler.get_task_runs(all_tasks)) assert tasks_count == total_tasks - len(_tasks) assert tasks_run_count == total_tasks - len(_tasks) def test_get_task_runs_no_task(self, swh_scheduler): """No task exist in the scheduler's db, get_task_runs() should always return an empty list. """ assert not swh_scheduler.get_task_runs(task_ids=()) assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3)) assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3), limit=10) def test_get_task_runs_no_task_executed(self, swh_scheduler): """No task has been executed yet, get_task_runs() should always return an empty list. """ self._create_task_types(swh_scheduler) _time = utcnow() recurring = tasks_from_template(TEMPLATES["git"], _time, 12) oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12) swh_scheduler.create_tasks(recurring + oneshots) assert not swh_scheduler.get_task_runs(task_ids=()) assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3)) assert not swh_scheduler.get_task_runs(task_ids=(1, 2, 3), limit=10) def test_get_task_runs_with_scheduled(self, swh_scheduler): """Some tasks have been scheduled but not executed yet, get_task_runs() should not return an empty list. limit should behave as expected. """ self._create_task_types(swh_scheduler) _time = utcnow() recurring = tasks_from_template(TEMPLATES["git"], _time, 12) oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12) total_tasks = len(recurring) + len(oneshots) pending_tasks = swh_scheduler.create_tasks(recurring + oneshots) backend_tasks = [ { "task": task["id"], "backend_id": str(uuid.uuid4()), "scheduled": utcnow(), } for task in pending_tasks ] swh_scheduler.mass_schedule_task_runs(backend_tasks) assert not swh_scheduler.get_task_runs(task_ids=[total_tasks + 1]) btask = backend_tasks[0] runs = swh_scheduler.get_task_runs(task_ids=[btask["task"]]) assert len(runs) == 1 run = runs[0] assert subdict(run, excl=("id",)) == { "task": btask["task"], "backend_id": btask["backend_id"], "scheduled": btask["scheduled"], "started": None, "ended": None, "metadata": None, "status": "scheduled", } runs = swh_scheduler.get_task_runs( task_ids=[bt["task"] for bt in backend_tasks], limit=2 ) assert len(runs) == 2 runs = swh_scheduler.get_task_runs( task_ids=[bt["task"] for bt in backend_tasks] ) assert len(runs) == total_tasks keys = ("task", "backend_id", "scheduled") assert ( sorted([subdict(x, keys) for x in runs], key=lambda x: x["task"]) == backend_tasks ) def test_get_task_runs_with_executed(self, swh_scheduler): """Some tasks have been executed, get_task_runs() should not return an empty list. limit should behave as expected. """ self._create_task_types(swh_scheduler) _time = utcnow() recurring = tasks_from_template(TEMPLATES["git"], _time, 12) oneshots = tasks_from_template(TEMPLATES["hg"], _time, 12) pending_tasks = swh_scheduler.create_tasks(recurring + oneshots) backend_tasks = [ { "task": task["id"], "backend_id": str(uuid.uuid4()), "scheduled": utcnow(), } for task in pending_tasks ] swh_scheduler.mass_schedule_task_runs(backend_tasks) btask = backend_tasks[0] ts = utcnow() swh_scheduler.start_task_run( btask["backend_id"], metadata={"something": "stupid"}, timestamp=ts ) runs = swh_scheduler.get_task_runs(task_ids=[btask["task"]]) assert len(runs) == 1 assert subdict(runs[0], excl=("id")) == { "task": btask["task"], "backend_id": btask["backend_id"], "scheduled": btask["scheduled"], "started": ts, "ended": None, "metadata": {"something": "stupid"}, "status": "started", } ts2 = utcnow() swh_scheduler.end_task_run( btask["backend_id"], metadata={"other": "stuff"}, timestamp=ts2, status="eventful", ) runs = swh_scheduler.get_task_runs(task_ids=[btask["task"]]) assert len(runs) == 1 assert subdict(runs[0], excl=("id")) == { "task": btask["task"], "backend_id": btask["backend_id"], "scheduled": btask["scheduled"], "started": ts, "ended": ts2, "metadata": {"something": "stupid", "other": "stuff"}, "status": "eventful", } def test_get_or_create_lister(self, swh_scheduler): db_listers = [] for lister_args in LISTERS: db_listers.append(swh_scheduler.get_or_create_lister(**lister_args)) for lister, lister_args in zip(db_listers, LISTERS): assert lister.name == lister_args["name"] assert lister.instance_name == lister_args.get("instance_name", "") lister_get_again = swh_scheduler.get_or_create_lister( lister.name, lister.instance_name ) assert lister == lister_get_again def test_get_lister(self, swh_scheduler): for lister_args in LISTERS: assert swh_scheduler.get_lister(**lister_args) is None db_listers = [] for lister_args in LISTERS: db_listers.append(swh_scheduler.get_or_create_lister(**lister_args)) for lister, lister_args in zip(db_listers, LISTERS): lister_get_again = swh_scheduler.get_lister( lister.name, lister.instance_name ) assert lister == lister_get_again def test_update_lister(self, swh_scheduler, stored_lister): lister = attr.evolve(stored_lister, current_state={"updated": "now"}) updated_lister = swh_scheduler.update_lister(lister) assert updated_lister.updated > lister.updated assert updated_lister == attr.evolve(lister, updated=updated_lister.updated) def test_update_lister_stale(self, swh_scheduler, stored_lister): swh_scheduler.update_lister(stored_lister) with pytest.raises(StaleData) as exc: swh_scheduler.update_lister(stored_lister) assert "state not updated" in exc.value.args[0] def test_record_listed_origins(self, swh_scheduler, listed_origins): ret = swh_scheduler.record_listed_origins(listed_origins) assert set(returned.url for returned in ret) == set( origin.url for origin in listed_origins ) assert all(origin.first_seen == origin.last_seen for origin in ret) def test_record_listed_origins_upsert(self, swh_scheduler, listed_origins): # First, insert `cutoff` origins cutoff = 100 assert cutoff < len(listed_origins) ret = swh_scheduler.record_listed_origins(listed_origins[:cutoff]) assert len(ret) == cutoff # Then, insert all origins, including the `cutoff` first. ret = swh_scheduler.record_listed_origins(listed_origins) assert len(ret) == len(listed_origins) # Two different "first seen" values assert len(set(origin.first_seen for origin in ret)) == 2 # But a single "last seen" value assert len(set(origin.last_seen for origin in ret)) == 1 def test_get_listed_origins_exact(self, swh_scheduler, listed_origins): swh_scheduler.record_listed_origins(listed_origins) for i, origin in enumerate(listed_origins): ret = swh_scheduler.get_listed_origins( lister_id=origin.lister_id, url=origin.url ) assert ret.next_page_token is None assert len(ret.results) == 1 assert ret.results[0].lister_id == origin.lister_id assert ret.results[0].url == origin.url @pytest.mark.parametrize("num_origins,limit", [(20, 6), (5, 42), (20, 20)]) def test_get_listed_origins_limit( self, swh_scheduler, listed_origins, num_origins, limit ) -> None: added_origins = sorted( listed_origins[:num_origins], key=lambda o: (o.lister_id, o.url) ) swh_scheduler.record_listed_origins(added_origins) returned_origins: List[ListedOrigin] = [] call_count = 0 next_page_token: Optional[ListedOriginPageToken] = None while True: call_count += 1 ret = swh_scheduler.get_listed_origins( lister_id=listed_origins[0].lister_id, limit=limit, page_token=next_page_token, ) returned_origins.extend(ret.results) next_page_token = ret.next_page_token if next_page_token is None: break assert call_count == (num_origins // limit) + 1 assert len(returned_origins) == num_origins assert [(origin.lister_id, origin.url) for origin in returned_origins] == [ (origin.lister_id, origin.url) for origin in added_origins ] def test_get_listed_origins_all(self, swh_scheduler, listed_origins) -> None: swh_scheduler.record_listed_origins(listed_origins) ret = swh_scheduler.get_listed_origins(limit=len(listed_origins) + 1) assert ret.next_page_token is None assert len(ret.results) == len(listed_origins) def _grab_next_visits_setup(self, swh_scheduler, listed_origins_by_type): """Basic origins setup for scheduling policy tests""" visit_type = next(iter(listed_origins_by_type)) origins = listed_origins_by_type[visit_type][:100] assert len(origins) > 0 recorded_origins = swh_scheduler.record_listed_origins(origins) return visit_type, recorded_origins def _check_grab_next_visit(self, swh_scheduler, visit_type, policy, expected): """Calls grab_next_visits with the passed policy, and check that all the origins returned are the expected ones (in the same order), and that no extra origins are returned. Also checks the origin visits have - been marked as scheduled.""" + been marked as scheduled, and are only re-scheduled a week later""" assert len(expected) != 0 before = utcnow() ret = swh_scheduler.grab_next_visits( visit_type=visit_type, # Request one more than expected to check that no extra origin is returned count=len(expected) + 1, policy=policy, ) after = utcnow() assert ret == expected visit_stats_list = swh_scheduler.origin_visit_stats_get( [(origin.url, origin.visit_type) for origin in expected] ) assert len(visit_stats_list) == len(expected) for visit_stats in visit_stats_list: # Check that last_scheduled got updated assert before <= visit_stats.last_scheduled <= after + # They should not be scheduled again + ret = swh_scheduler.grab_next_visits( + visit_type=visit_type, count=len(expected) + 1, policy=policy, + ) + assert ret == [], "grab_next_visits returned already-scheduled origins" + + # But a week later, they should + ret = swh_scheduler.grab_next_visits( + visit_type=visit_type, + count=len(expected) + 1, + policy=policy, + timestamp=after + datetime.timedelta(days=7), + ) + # We need to sort them because their 'last_scheduled' field is updated to + # exactly the same value, so the order is not deterministic + assert sorted(ret) == sorted( + expected + ), "grab_next_visits didn't reschedule visits after a week" + def test_grab_next_visits_oldest_scheduled_first( self, swh_scheduler, listed_origins_by_type, ): visit_type, origins = self._grab_next_visits_setup( swh_scheduler, listed_origins_by_type ) # Give all origins but one a last_scheduled date base_date = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) visit_stats = [ OriginVisitStats( url=origin.url, visit_type=origin.visit_type, last_snapshot=None, last_eventful=None, last_uneventful=None, last_failed=None, last_notfound=None, last_scheduled=base_date - datetime.timedelta(seconds=i), ) for i, origin in enumerate(origins[1:]) ] swh_scheduler.origin_visit_stats_upsert(visit_stats) # We expect to retrieve the origin with a NULL last_scheduled # as well as those with the oldest values (i.e. the last ones), in order. expected = [origins[0]] + origins[1:][::-1] self._check_grab_next_visit( swh_scheduler, visit_type=visit_type, policy="oldest_scheduled_first", expected=expected, ) def test_grab_next_visits_never_visited_oldest_update_first( self, swh_scheduler, listed_origins_by_type, ): visit_type, origins = self._grab_next_visits_setup( swh_scheduler, listed_origins_by_type ) # Update known origins with a `last_update` field that we control base_date = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) updated_origins = [ attr.evolve(origin, last_update=base_date - datetime.timedelta(seconds=i)) for i, origin in enumerate(origins) ] updated_origins = swh_scheduler.record_listed_origins(updated_origins) # We expect to retrieve origins with the oldest update date, that is # origins at the end of our updated_origins list. expected_origins = sorted(updated_origins, key=lambda o: o.last_update) self._check_grab_next_visit( swh_scheduler, visit_type=visit_type, policy="never_visited_oldest_update_first", expected=expected_origins, ) def test_grab_next_visits_already_visited_order_by_lag( self, swh_scheduler, listed_origins_by_type, ): visit_type, origins = self._grab_next_visits_setup( swh_scheduler, listed_origins_by_type ) # Update known origins with a `last_update` field that we control base_date = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) updated_origins = [ attr.evolve(origin, last_update=base_date - datetime.timedelta(seconds=i)) for i, origin in enumerate(origins) ] updated_origins = swh_scheduler.record_listed_origins(updated_origins) # Update the visit stats with a known visit at a controlled date for # half the origins. Pick the date in the middle of the # updated_origins' `last_update` range visit_date = updated_origins[len(updated_origins) // 2].last_update visited_origins = updated_origins[::2] visit_stats = [ OriginVisitStats( url=origin.url, visit_type=origin.visit_type, last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"), last_eventful=visit_date, last_uneventful=None, last_failed=None, last_notfound=None, ) for origin in visited_origins ] swh_scheduler.origin_visit_stats_upsert(visit_stats) # We expect to retrieve visited origins with the largest lag, but only # those which haven't been visited since their last update expected_origins = sorted( [origin for origin in visited_origins if origin.last_update > visit_date], key=lambda o: visit_date - o.last_update, ) self._check_grab_next_visit( swh_scheduler, visit_type=visit_type, policy="already_visited_order_by_lag", expected=expected_origins, ) def test_grab_next_visits_underflow(self, swh_scheduler, listed_origins_by_type): """Check that grab_next_visits works when there not enough origins in the database""" visit_type = next(iter(listed_origins_by_type)) # Only add 5 origins to the database origins = listed_origins_by_type[visit_type][:5] assert origins swh_scheduler.record_listed_origins(origins) ret = swh_scheduler.grab_next_visits( visit_type, len(origins) + 2, policy="oldest_scheduled_first" ) assert len(ret) == 5 def test_grab_next_visits_unknown_policy(self, swh_scheduler): unknown_policy = "non_existing_policy" NUM_RESULTS = 5 with pytest.raises(UnknownPolicy, match=unknown_policy): swh_scheduler.grab_next_visits("type", NUM_RESULTS, policy=unknown_policy) def _create_task_types(self, scheduler): for tt in TASK_TYPES.values(): scheduler.create_task_type(tt) def test_origin_visit_stats_get_empty(self, swh_scheduler) -> None: assert swh_scheduler.origin_visit_stats_get([]) == [] def test_origin_visit_stats_upsert(self, swh_scheduler) -> None: eventful_date = utcnow() url = "https://github.com/test" visit_stats = OriginVisitStats( url=url, visit_type="git", last_eventful=eventful_date, last_uneventful=None, last_failed=None, last_notfound=None, ) swh_scheduler.origin_visit_stats_upsert([visit_stats]) swh_scheduler.origin_visit_stats_upsert([visit_stats]) assert swh_scheduler.origin_visit_stats_get([(url, "git")]) == [visit_stats] assert swh_scheduler.origin_visit_stats_get([(url, "svn")]) == [] uneventful_date = utcnow() visit_stats = OriginVisitStats( url=url, visit_type="git", last_eventful=None, last_uneventful=uneventful_date, last_failed=None, last_notfound=None, ) swh_scheduler.origin_visit_stats_upsert([visit_stats]) uneventful_visits = swh_scheduler.origin_visit_stats_get([(url, "git")]) expected_visit_stats = OriginVisitStats( url=url, visit_type="git", last_eventful=eventful_date, last_uneventful=uneventful_date, last_failed=None, last_notfound=None, ) assert uneventful_visits == [expected_visit_stats] failed_date = utcnow() visit_stats = OriginVisitStats( url=url, visit_type="git", last_eventful=None, last_uneventful=None, last_failed=failed_date, last_notfound=None, ) swh_scheduler.origin_visit_stats_upsert([visit_stats]) failed_visits = swh_scheduler.origin_visit_stats_get([(url, "git")]) expected_visit_stats = OriginVisitStats( url=url, visit_type="git", last_eventful=eventful_date, last_uneventful=uneventful_date, last_failed=failed_date, last_notfound=None, ) assert failed_visits == [expected_visit_stats] def test_origin_visit_stats_upsert_with_snapshot(self, swh_scheduler) -> None: eventful_date = utcnow() url = "https://github.com/666/test" visit_stats = OriginVisitStats( url=url, visit_type="git", last_eventful=eventful_date, last_uneventful=None, last_failed=None, last_notfound=None, last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"), ) swh_scheduler.origin_visit_stats_upsert([visit_stats]) assert swh_scheduler.origin_visit_stats_get([(url, "git")]) == [visit_stats] assert swh_scheduler.origin_visit_stats_get([(url, "svn")]) == [] def test_origin_visit_stats_upsert_batch(self, swh_scheduler) -> None: """Batch upsert is ok""" visit_stats = [ OriginVisitStats( url="foo", visit_type="git", last_eventful=utcnow(), last_uneventful=None, last_failed=None, last_notfound=None, last_snapshot=hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"), ), OriginVisitStats( url="bar", visit_type="git", last_eventful=None, last_uneventful=utcnow(), last_notfound=None, last_failed=None, last_snapshot=hash_to_bytes("fffcc0710eb6cf9efd5b920a8453e1e07157bfff"), ), ] swh_scheduler.origin_visit_stats_upsert(visit_stats) for visit_stat in swh_scheduler.origin_visit_stats_get( [(vs.url, vs.visit_type) for vs in visit_stats] ): assert visit_stat is not None def test_origin_visit_stats_upsert_cardinality_failing(self, swh_scheduler) -> None: """Batch upsert does not support altering multiple times the same origin-visit-status """ with pytest.raises(SchedulerException, match="CardinalityViolation"): swh_scheduler.origin_visit_stats_upsert( [ OriginVisitStats( url="foo", visit_type="git", last_eventful=None, last_uneventful=utcnow(), last_notfound=None, last_failed=None, last_snapshot=None, ), OriginVisitStats( url="foo", visit_type="git", last_eventful=None, last_uneventful=utcnow(), last_notfound=None, last_failed=None, last_snapshot=None, ), ] ) def test_metrics_origins_known(self, swh_scheduler, listed_origins): swh_scheduler.record_listed_origins(listed_origins) ret = swh_scheduler.update_metrics() assert sum(metric.origins_known for metric in ret) == len(listed_origins) def test_metrics_origins_enabled(self, swh_scheduler, listed_origins): swh_scheduler.record_listed_origins(listed_origins) disabled_origin = attr.evolve(listed_origins[0], enabled=False) swh_scheduler.record_listed_origins([disabled_origin]) ret = swh_scheduler.update_metrics(lister_id=disabled_origin.lister_id) for metric in ret: if metric.visit_type == disabled_origin.visit_type: # We disabled one of these origins assert metric.origins_known - metric.origins_enabled == 1 else: # But these are still all enabled assert metric.origins_known == metric.origins_enabled def test_metrics_origins_never_visited(self, swh_scheduler, listed_origins): swh_scheduler.record_listed_origins(listed_origins) # Pretend that we've recorded a visit on one origin visited_origin = listed_origins[0] swh_scheduler.origin_visit_stats_upsert( [ OriginVisitStats( url=visited_origin.url, visit_type=visited_origin.visit_type, last_eventful=utcnow(), last_uneventful=None, last_failed=None, last_notfound=None, last_snapshot=hash_to_bytes( "d81cc0710eb6cf9efd5b920a8453e1e07157b6cd" ), ), ] ) ret = swh_scheduler.update_metrics(lister_id=visited_origin.lister_id) for metric in ret: if metric.visit_type == visited_origin.visit_type: # We visited one of these origins assert metric.origins_known - metric.origins_never_visited == 1 else: # But none of these have been visited assert metric.origins_known == metric.origins_never_visited def test_metrics_origins_with_pending_changes(self, swh_scheduler, listed_origins): swh_scheduler.record_listed_origins(listed_origins) # Pretend that we've recorded a visit on one origin, in the past with # respect to the "last update" time for the origin visited_origin = listed_origins[0] assert visited_origin.last_update is not None swh_scheduler.origin_visit_stats_upsert( [ OriginVisitStats( url=visited_origin.url, visit_type=visited_origin.visit_type, last_eventful=visited_origin.last_update - datetime.timedelta(days=1), last_uneventful=None, last_failed=None, last_notfound=None, last_snapshot=hash_to_bytes( "d81cc0710eb6cf9efd5b920a8453e1e07157b6cd" ), ), ] ) ret = swh_scheduler.update_metrics(lister_id=visited_origin.lister_id) for metric in ret: if metric.visit_type == visited_origin.visit_type: # We visited one of these origins, in the past assert metric.origins_with_pending_changes == 1 else: # But none of these have been visited assert metric.origins_with_pending_changes == 0 def test_update_metrics_explicit_lister(self, swh_scheduler, listed_origins): swh_scheduler.record_listed_origins(listed_origins) fake_uuid = uuid.uuid4() assert all(fake_uuid != origin.lister_id for origin in listed_origins) ret = swh_scheduler.update_metrics(lister_id=fake_uuid) assert len(ret) == 0 def test_update_metrics_explicit_timestamp(self, swh_scheduler, listed_origins): swh_scheduler.record_listed_origins(listed_origins) ts = datetime.datetime(2020, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) ret = swh_scheduler.update_metrics(timestamp=ts) assert all(metric.last_update == ts for metric in ret) def test_update_metrics_twice(self, swh_scheduler, listed_origins): swh_scheduler.record_listed_origins(listed_origins) ts = utcnow() ret = swh_scheduler.update_metrics(timestamp=ts) assert all(metric.last_update == ts for metric in ret) second_ts = ts + datetime.timedelta(seconds=1) ret = swh_scheduler.update_metrics(timestamp=second_ts) assert all(metric.last_update == second_ts for metric in ret) def test_get_metrics(self, swh_scheduler, listed_origins): swh_scheduler.record_listed_origins(listed_origins) updated = swh_scheduler.update_metrics() retrieved = swh_scheduler.get_metrics() assert_metrics_equal(updated, retrieved) def test_get_metrics_by_lister(self, swh_scheduler, listed_origins): lister_id = listed_origins[0].lister_id assert lister_id is not None swh_scheduler.record_listed_origins(listed_origins) updated = swh_scheduler.update_metrics() retrieved = swh_scheduler.get_metrics(lister_id=lister_id) assert len(retrieved) > 0 assert_metrics_equal( [metric for metric in updated if metric.lister_id == lister_id], retrieved ) def test_get_metrics_by_visit_type(self, swh_scheduler, listed_origins): visit_type = listed_origins[0].visit_type assert visit_type is not None swh_scheduler.record_listed_origins(listed_origins) updated = swh_scheduler.update_metrics() retrieved = swh_scheduler.get_metrics(visit_type=visit_type) assert len(retrieved) > 0 assert_metrics_equal( [metric for metric in updated if metric.visit_type == visit_type], retrieved ) diff --git a/swh/scheduler/tests/test_simulator.py b/swh/scheduler/tests/test_simulator.py index 7c7aaca..95085a5 100644 --- a/swh/scheduler/tests/test_simulator.py +++ b/swh/scheduler/tests/test_simulator.py @@ -1,53 +1,69 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import io + import pytest from swh.core.api.classes import stream_results import swh.scheduler.simulator as simulator from swh.scheduler.tests.common import TASK_TYPES NUM_ORIGINS = 42 TEST_RUNTIME = 1000 def test_fill_test_data(swh_scheduler): for task_type in TASK_TYPES.values(): swh_scheduler.create_task_type(task_type) simulator.fill_test_data(swh_scheduler, num_origins=NUM_ORIGINS) origins = list(stream_results(swh_scheduler.get_listed_origins)) assert len(origins) == NUM_ORIGINS res = swh_scheduler.search_tasks() assert len(res) == NUM_ORIGINS -@pytest.mark.parametrize("policy", ("oldest_scheduled_first",)) +@pytest.mark.parametrize( + "policy", + ( + "oldest_scheduled_first", + "never_visited_oldest_update_first", + "already_visited_order_by_lag", + ), +) def test_run_origin_scheduler(swh_scheduler, policy): for task_type in TASK_TYPES.values(): swh_scheduler.create_task_type(task_type) simulator.fill_test_data(swh_scheduler, num_origins=NUM_ORIGINS) simulator.run( swh_scheduler, scheduler_type="origin_scheduler", policy=policy, runtime=TEST_RUNTIME, ) def test_run_task_scheduler(swh_scheduler): for task_type in TASK_TYPES.values(): swh_scheduler.create_task_type(task_type) simulator.fill_test_data(swh_scheduler, num_origins=NUM_ORIGINS) - simulator.run( + report = simulator.run( swh_scheduler, scheduler_type="task_scheduler", policy=None, runtime=TEST_RUNTIME, ) + + # just check these SimulationReport methods do not crash + assert report.format(with_plots=True) + assert report.format(with_plots=False) + fobj = io.StringIO() + report.metrics_csv(fobj=fobj) + assert fobj.getvalue()