diff --git a/swh/scheduler/api/server.py b/swh/scheduler/api/server.py --- a/swh/scheduler/api/server.py +++ b/swh/scheduler/api/server.py @@ -18,7 +18,7 @@ 'scheduler': ('dict', { 'cls': 'local', 'args': { - 'scheduling_db': 'dbname=softwareheritage-scheduler-dev', + 'db': 'dbname=softwareheritage-scheduler-dev', }, }) } diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py --- a/swh/scheduler/backend.py +++ b/swh/scheduler/backend.py @@ -5,17 +5,17 @@ import binascii import datetime -from functools import wraps import json import tempfile import logging from arrow import Arrow, utcnow -import psycopg2 +import psycopg2.pool import psycopg2.extras from psycopg2.extensions import AsIs -from swh.core.config import SWHConfig +from swh.core.db import BaseDb +from swh.core.db.common import db_transaction, db_transaction_generator logger = logging.getLogger(__name__) @@ -29,78 +29,14 @@ psycopg2.extensions.register_adapter(Arrow, adapt_arrow) -def autocommit(fn): - @wraps(fn) - def wrapped(self, *args, **kwargs): - autocommit = False - if 'cursor' not in kwargs or not kwargs['cursor']: - autocommit = True - kwargs['cursor'] = self.cursor() - - try: - ret = fn(self, *args, **kwargs) - except Exception: - if autocommit: - self.rollback() - raise - - if autocommit: - self.commit() - - return ret - - return wrapped - - -class DbBackend: - """Mixin intended to be used within scheduling db backend classes +class DbBackend(BaseDb): + """Base class intended to be used for scheduling db backend classes cf. 
swh.scheduler.backend.SchedulerBackend, and swh.scheduler.updater.backend.SchedulerUpdaterBackend """ - def reconnect(self): - if not self.db or self.db.closed: - self.db = psycopg2.connect( - dsn=self.db_conn_dsn, - cursor_factory=psycopg2.extras.RealDictCursor, - ) - - def cursor(self): - """Return a fresh cursor on the database, with auto-reconnection in - case of failure - - """ - cur = None - - # Get a fresh cursor and reconnect at most three times - tries = 0 - while True: - tries += 1 - try: - cur = self.db.cursor() - cur.execute('select 1') - break - except psycopg2.OperationalError: - if tries < 3: - self.reconnect() - else: - raise - - return cur - - def commit(self): - """Commit a transaction""" - self.db.commit() - - def rollback(self): - """Rollback a transaction""" - self.db.rollback() - - def close_connection(self): - """Close db connection""" - if self.db and not self.db.closed: - self.db.close() + cursor = BaseDb._cursor def _format_query(self, query, keys): """Format a query with the given keys""" @@ -110,20 +46,6 @@ return query.format(keys=query_keys, placeholders=placeholders) - def _format_multiquery(self, query, keys, values): - """Format a query with placeholders generated for multiple values""" - query_keys = ', '.join(keys) - placeholders = '), ('.join( - [', '.join(['%s'] * len(keys))] * len(values) - ) - ret_values = sum([[value[key] for key in keys] - for value in values], []) - - return ( - query.format(keys=query_keys, placeholders=placeholders), - ret_values, - ) - def copy_to(self, items, tblname, columns, default_columns={}, cursor=None, item_cb=None): def escape(data): @@ -173,23 +95,31 @@ tblname, ', '.join(columns)), f) -class SchedulerBackend(SWHConfig, DbBackend): +class SchedulerBackend: """Backend for the Software Heritage scheduling database. 
""" - CONFIG_BASE_FILENAME = 'scheduler' - DEFAULT_CONFIG = { - 'scheduling_db': ('str', 'dbname=softwareheritage-scheduler-dev'), - } - - def __init__(self, **override_config): - super().__init__() - self.config = self.parse_config_file(global_config=False) - self.config.update(override_config) - self.db = None - self.db_conn_dsn = self.config['scheduling_db'] - self.reconnect() - logger.debug('SchedulerBackend config=%s' % self.config) + + def __init__(self, db, min_pool_conns=1, max_pool_conns=10): + """ + Args: + db: either a libpq connection string, or a psycopg2 connection + + """ + if isinstance(db, psycopg2.extensions.connection): + self._pool = None + self._db = DbBackend(db) + else: + self._pool = psycopg2.pool.ThreadedConnectionPool( + min_pool_conns, max_pool_conns, db, + cursor_factory=psycopg2.extras.RealDictCursor, + ) + self._db = None + + def get_db(self): + if self._db: + return self._db + return DbBackend.from_pool(self._pool) task_type_keys = [ 'type', 'description', 'backend_name', 'default_interval', @@ -197,8 +127,8 @@ 'num_retries', 'retry_delay', ] - @autocommit - def create_task_type(self, task_type, cursor=None): + @db_transaction() + def create_task_type(self, task_type, db=None, cur=None): """Create a new task type ready for scheduling. 
Args: @@ -222,33 +152,30 @@ """ keys = [key for key in self.task_type_keys if key in task_type] - query = self._format_query( + query = db._format_query( """insert into task_type ({keys}) values ({placeholders})""", keys) - cursor.execute(query, [task_type[key] for key in keys]) + cur.execute(query, [task_type[key] for key in keys]) - @autocommit - def get_task_type(self, task_type_name, cursor=None): + @db_transaction() + def get_task_type(self, task_type_name, db=None, cur=None): """Retrieve the task type with id task_type_name""" - query = self._format_query( + query = db._format_query( "select {keys} from task_type where type=%s", self.task_type_keys, ) - cursor.execute(query, (task_type_name,)) - - ret = cursor.fetchone() + cur.execute(query, (task_type_name,)) + return cur.fetchone() - return ret - - @autocommit - def get_task_types(self, cursor=None): - query = self._format_query( + @db_transaction() + def get_task_types(self, db=None, cur=None): + """Retrieve all registered task types""" + query = db._format_query( "select {keys} from task_type", self.task_type_keys, ) - cursor.execute(query) - ret = cursor.fetchall() - return ret + cur.execute(query) + return cur.fetchall() task_create_keys = [ 'type', 'arguments', 'next_run', 'policy', 'status', 'retries_left', @@ -256,8 +183,8 @@ ] task_keys = task_create_keys + ['id', 'current_interval', 'status'] - @autocommit - def create_tasks(self, tasks, policy='recurring', cursor=None): + @db_transaction() + def create_tasks(self, tasks, policy='recurring', db=None, cur=None): """Create new tasks. Args: @@ -276,23 +203,23 @@ a list of created tasks. 
""" - cursor.execute('select swh_scheduler_mktemp_task()') - self.copy_to(tasks, 'tmp_task', self.task_create_keys, - default_columns={ - 'policy': policy, - 'status': 'next_run_not_scheduled' - }, - cursor=cursor) - query = self._format_query( + cur.execute('select swh_scheduler_mktemp_task()') + db.copy_to(tasks, 'tmp_task', self.task_create_keys, + default_columns={ + 'policy': policy, + 'status': 'next_run_not_scheduled' + }, + cursor=cur) + query = db._format_query( 'select {keys} from swh_scheduler_create_tasks_from_temp()', self.task_keys, ) - cursor.execute(query) - return cursor.fetchall() + cur.execute(query) + return cur.fetchall() - @autocommit - def set_status_tasks(self, task_ids, - status='disabled', next_run=None, cursor=None): + @db_transaction() + def set_status_tasks(self, task_ids, status='disabled', next_run=None, + db=None, cur=None): """Set the tasks' status whose ids are listed. If given, also set the next_run date. @@ -307,17 +234,17 @@ query.append(" WHERE id IN %s") args.append(tuple(task_ids)) - cursor.execute(''.join(query), args) + cur.execute(''.join(query), args) - @autocommit - def disable_tasks(self, task_ids, cursor=None): + @db_transaction() + def disable_tasks(self, task_ids, db=None, cur=None): """Disable the tasks whose ids are listed.""" - return self.set_status_tasks(task_ids) + return self.set_status_tasks(task_ids, db=db, cur=cur) - @autocommit + @db_transaction() def search_tasks(self, task_id=None, task_type=None, status=None, priority=None, policy=None, before=None, after=None, - limit=None, cursor=None): + limit=None, db=None, cur=None): """Search tasks from selected criterions""" where = [] args = [] @@ -364,21 +291,21 @@ if limit: query += ' limit %s :: bigint' args.append(limit) - cursor.execute(query, args) - return cursor.fetchall() + cur.execute(query, args) + return cur.fetchall() - @autocommit - def get_tasks(self, task_ids, cursor=None): + @db_transaction() + def get_tasks(self, task_ids, db=None, cur=None): 
"""Retrieve the info of tasks whose ids are listed.""" - query = self._format_query('select {keys} from task where id in %s', - self.task_keys) - cursor.execute(query, (tuple(task_ids),)) - return cursor.fetchall() + query = db._format_query('select {keys} from task where id in %s', + self.task_keys) + cur.execute(query, (tuple(task_ids),)) + return cur.fetchall() - @autocommit + @db_transaction() def peek_ready_tasks(self, task_type, timestamp=None, num_tasks=None, num_tasks_priority=None, - cursor=None): + db=None, cur=None): """Fetch the list of ready tasks Args: @@ -396,17 +323,17 @@ if timestamp is None: timestamp = utcnow() - cursor.execute( + cur.execute( '''select * from swh_scheduler_peek_ready_tasks( %s, %s, %s :: bigint, %s :: bigint)''', (task_type, timestamp, num_tasks, num_tasks_priority) ) - logger.debug('PEEK %s => %s' % (task_type, cursor.rowcount)) - return cursor.fetchall() + logger.debug('PEEK %s => %s' % (task_type, cur.rowcount)) + return cur.fetchall() - @autocommit + @db_transaction() def grab_ready_tasks(self, task_type, timestamp=None, num_tasks=None, - num_tasks_priority=None, cursor=None): + num_tasks_priority=None, db=None, cur=None): """Fetch the list of ready tasks, and mark them as scheduled Args: @@ -423,19 +350,19 @@ """ if timestamp is None: timestamp = utcnow() - cursor.execute( + cur.execute( '''select * from swh_scheduler_grab_ready_tasks( %s, %s, %s :: bigint, %s :: bigint)''', (task_type, timestamp, num_tasks, num_tasks_priority) ) - logger.debug('GRAB %s => %s' % (task_type, cursor.rowcount)) - return cursor.fetchall() + logger.debug('GRAB %s => %s' % (task_type, cur.rowcount)) + return cur.fetchall() task_run_create_keys = ['task', 'backend_id', 'scheduled', 'metadata'] - @autocommit + @db_transaction() def schedule_task_run(self, task_id, backend_id, metadata=None, - timestamp=None, cursor=None): + timestamp=None, db=None, cur=None): """Mark a given task as scheduled, adding a task_run entry in the database. 
Args: @@ -455,15 +382,15 @@ if timestamp is None: timestamp = utcnow() - cursor.execute( + cur.execute( 'select * from swh_scheduler_schedule_task_run(%s, %s, %s, %s)', (task_id, backend_id, metadata, timestamp) ) - return cursor.fetchone() + return cur.fetchone() - @autocommit - def mass_schedule_task_runs(self, task_runs, cursor=None): + @db_transaction() + def mass_schedule_task_runs(self, task_runs, db=None, cur=None): """Schedule a bunch of task runs. Args: @@ -477,14 +404,14 @@ Returns: None """ - cursor.execute('select swh_scheduler_mktemp_task_run()') - self.copy_to(task_runs, 'tmp_task_run', self.task_run_create_keys, - cursor=cursor) - cursor.execute('select swh_scheduler_schedule_task_run_from_temp()') + cur.execute('select swh_scheduler_mktemp_task_run()') + db.copy_to(task_runs, 'tmp_task_run', self.task_run_create_keys, + cursor=cur) + cur.execute('select swh_scheduler_schedule_task_run_from_temp()') - @autocommit + @db_transaction() def start_task_run(self, backend_id, metadata=None, timestamp=None, - cursor=None): + db=None, cur=None): """Mark a given task as started, updating the corresponding task_run entry in the database. @@ -504,16 +431,16 @@ if timestamp is None: timestamp = utcnow() - cursor.execute( + cur.execute( 'select * from swh_scheduler_start_task_run(%s, %s, %s)', (backend_id, metadata, timestamp) ) - return cursor.fetchone() + return cur.fetchone() - @autocommit + @db_transaction() def end_task_run(self, backend_id, status, metadata=None, timestamp=None, - result=None, cursor=None): + result=None, db=None, cur=None): """Mark a given task as ended, updating the corresponding task_run entry in the database. 
@@ -535,27 +462,26 @@ if timestamp is None: timestamp = utcnow() - cursor.execute( + cur.execute( 'select * from swh_scheduler_end_task_run(%s, %s, %s, %s)', (backend_id, status, metadata, timestamp) ) + return cur.fetchone() - return cursor.fetchone() - - @autocommit + @db_transaction_generator() def filter_task_to_archive(self, after_ts, before_ts, limit=10, last_id=-1, - cursor=None): + db=None, cur=None): """Returns the list of task/task_run prior to a given date to archive. """ last_task_run_id = None while True: row = None - cursor.execute( + cur.execute( "select * from swh_scheduler_task_to_archive(%s, %s, %s, %s)", (after_ts, before_ts, last_id, limit) ) - for row in cursor: + for row in cur: # nested type index does not accept bare values # transform it as a dict to comply with this row['arguments']['args'] = { @@ -574,8 +500,8 @@ last_id = _id last_task_run_id = _task_run_id - @autocommit - def delete_archived_tasks(self, task_ids, cursor=None): + @db_transaction() + def delete_archived_tasks(self, task_ids, db=None, cur=None): """Delete archived tasks as much as possible. Only the task_ids whose complete associated task_run have been cleaned up will be. 
@@ -585,6 +511,6 @@ _task_ids.append(task_id['task_id']) _task_run_ids.append(task_id['task_run_id']) - cursor.execute( + cur.execute( "select * from swh_scheduler_delete_archived_tasks(%s, %s)", (_task_ids, _task_run_ids)) diff --git a/swh/scheduler/celery_backend/runner.py b/swh/scheduler/celery_backend/runner.py --- a/swh/scheduler/celery_backend/runner.py +++ b/swh/scheduler/celery_backend/runner.py @@ -39,10 +39,9 @@ """ all_backend_tasks = [] while True: - cursor = backend.cursor() task_types = {} pending_tasks = [] - for task_type in backend.get_task_types(cursor=cursor): + for task_type in backend.get_task_types(): task_type_name = task_type['type'] task_types[task_type_name] = task_type max_queue_length = task_type['max_queue_length'] @@ -68,8 +67,7 @@ grabbed_tasks = backend.grab_ready_tasks( task_type_name, num_tasks=num_tasks, - num_tasks_priority=num_tasks_priority, - cursor=cursor) + num_tasks_priority=num_tasks_priority) if grabbed_tasks: pending_tasks.extend(grabbed_tasks) logger.info('Grabbed %s tasks %s', @@ -96,8 +94,7 @@ backend_tasks.append(data) logger.debug('Sent %s celery tasks', len(backend_tasks)) - backend.mass_schedule_task_runs(backend_tasks, cursor=cursor) - backend.commit() + backend.mass_schedule_task_runs(backend_tasks) all_backend_tasks.extend(backend_tasks)