swh/scheduler/backend.py
# Copyright (C) 2015-2018  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-import binascii
-import datetime
import json
-import tempfile
import logging

from arrow import Arrow, utcnow
import psycopg2.pool
import psycopg2.extras
from psycopg2.extensions import AsIs

from swh.core.db import BaseDb
from swh.core.db.common import db_transaction, db_transaction_generator


logger = logging.getLogger(__name__)
def adapt_arrow(arrow):
    return AsIs("'%s'::timestamptz" % arrow.isoformat())


psycopg2.extensions.register_adapter(dict, psycopg2.extras.Json)
psycopg2.extensions.register_adapter(Arrow, adapt_arrow)
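# Illustration (editor's sketch, not part of the changeset): with the two
# adapters registered above, plain dicts and Arrow timestamps can be passed
# directly as query parameters; psycopg2 sends the dict as a json value and
# the Arrow object as a timestamptz literal. Column names below are assumed
# for illustration only:
#
#   cur.execute("insert into task (arguments, next_run) values (%s, %s)",
#               ({'args': [], 'kwargs': {}}, utcnow()))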
-class DbBackend(BaseDb):
-    """Base class intended to be used for scheduling db backend classes
-
-    cf. swh.scheduler.backend.SchedulerBackend, and
-    swh.scheduler.updater.backend.SchedulerUpdaterBackend
-
-    """
-    cursor = BaseDb._cursor
-
-    def _format_query(self, query, keys):
+def format_query(query, keys):
    """Format a query with the given keys"""
    query_keys = ', '.join(keys)
    placeholders = ', '.join(['%s'] * len(keys))

    return query.format(keys=query_keys, placeholders=placeholders)
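# Example (editor's sketch): format_query expands {keys} to the joined column
# names and {placeholders} to a matching number of %s markers, so
#
#   format_query("insert into task_type ({keys}) values ({placeholders})",
#                ['type', 'description'])
#
# returns "insert into task_type (type, description) values (%s, %s)".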
-    def copy_to(self, items, tblname, columns, default_columns={},
-                cursor=None, item_cb=None):
-        def escape(data):
-            if data is None:
-                return ''
-            if isinstance(data, bytes):
-                return '\\x%s' % binascii.hexlify(data).decode('ascii')
-            elif isinstance(data, str):
-                return '"%s"' % data.replace('"', '""')
-            elif isinstance(data, (datetime.datetime, Arrow)):
-                # We escape twice to make sure the string generated by
-                # isoformat gets escaped
-                return escape(data.isoformat())
-            elif isinstance(data, dict):
-                return escape(json.dumps(data))
-            elif isinstance(data, list):
-                return escape("{%s}" % ','.join(escape(d) for d in data))
-            elif isinstance(data, psycopg2.extras.Range):
-                # We escape twice here too, so that we make sure
-                # everything gets passed to copy properly
-                return escape(
-                    '%s%s,%s%s' % (
-                        '[' if data.lower_inc else '(',
-                        '-infinity' if data.lower_inf else escape(data.lower),
-                        'infinity' if data.upper_inf else escape(data.upper),
-                        ']' if data.upper_inc else ')',
-                    )
-                )
-            else:
-                # We don't escape here to make sure we pass literals properly
-                return str(data)
-
-        with tempfile.TemporaryFile('w+') as f:
-            for d in items:
-                if item_cb is not None:
-                    item_cb(d)
-                line = []
-                for k in columns:
-                    v = d.get(k)
-                    if not v:
-                        v = default_columns.get(k)
-                    v = escape(v)
-                    line.append(v)
-                f.write(','.join(line))
-                f.write('\n')
-            f.seek(0)
-
-            cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % (
-                tblname, ', '.join(columns)), f)
class SchedulerBackend:
    """Backend for the Software Heritage scheduling database.

    """

    def __init__(self, db, min_pool_conns=1, max_pool_conns=10):
        """
        Args:
            db: either a libpq connection string, or a psycopg2 connection

        """
        if isinstance(db, psycopg2.extensions.connection):
            self._pool = None
-            self._db = DbBackend(db)
+            self._db = BaseDb(db)
        else:
            self._pool = psycopg2.pool.ThreadedConnectionPool(
                min_pool_conns, max_pool_conns, db,
                cursor_factory=psycopg2.extras.RealDictCursor,
            )
            self._db = None

    def get_db(self):
        if self._db:
            return self._db
-        return DbBackend.from_pool(self._pool)
+        return BaseDb.from_pool(self._pool)

    task_type_keys = [
        'type', 'description', 'backend_name', 'default_interval',
        'min_interval', 'max_interval', 'backoff_factor', 'max_queue_length',
        'num_retries', 'retry_delay',
    ]

    @db_transaction()
    # [16 lines not shown, in def create_task_type(self, task_type, db=None, cur=None):]
                between two task runs
              - backoff_factor (float): the factor by which the interval
                changes at each run
              - max_queue_length (int): the maximum length of the task queue
                for this task type

        """
        keys = [key for key in self.task_type_keys if key in task_type]

-        query = db._format_query(
+        query = format_query(
            """insert into task_type ({keys}) values ({placeholders})""",
            keys)
        cur.execute(query, [task_type[key] for key in keys])
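    # Usage sketch (editor's example; the concrete values are assumptions, not
    # part of the changeset). A task type is a dict whose keys are a subset of
    # task_type_keys:
    #
    #   backend = SchedulerBackend(db='dbname=softwareheritage-scheduler')
    #   backend.create_task_type({
    #       'type': 'load-git',
    #       'description': 'Load a git repository',
    #       'backend_name': 'swh.loader.git.tasks.UpdateGitRepository',
    #       'default_interval': datetime.timedelta(days=64),
    #       'backoff_factor': 2.0,
    #   })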
    @db_transaction()
    def get_task_type(self, task_type_name, db=None, cur=None):
        """Retrieve the task type with id task_type_name"""
-        query = db._format_query(
+        query = format_query(
            "select {keys} from task_type where type=%s",
            self.task_type_keys,
        )
        cur.execute(query, (task_type_name,))
        return cur.fetchone()

    @db_transaction()
    def get_task_types(self, db=None, cur=None):
        """Retrieve all registered task types"""
-        query = db._format_query(
+        query = format_query(
            "select {keys} from task_type",
            self.task_type_keys,
        )
        cur.execute(query)
        return cur.fetchall()

    task_create_keys = [
        'type', 'arguments', 'next_run', 'policy', 'status', 'retries_left',
    # [22 lines not shown, in def create_tasks(self, tasks, policy='recurring', db=None, cur=None):]
""" | """ | ||||
cur.execute('select swh_scheduler_mktemp_task()') | cur.execute('select swh_scheduler_mktemp_task()') | ||||
db.copy_to(tasks, 'tmp_task', self.task_create_keys, | db.copy_to(tasks, 'tmp_task', self.task_create_keys, | ||||
default_columns={ | default_columns={ | ||||
'policy': policy, | 'policy': policy, | ||||
'status': 'next_run_not_scheduled' | 'status': 'next_run_not_scheduled' | ||||
}, | }, | ||||
cursor=cur) | cur=cur) | ||||
query = db._format_query( | query = format_query( | ||||
'select {keys} from swh_scheduler_create_tasks_from_temp()', | 'select {keys} from swh_scheduler_create_tasks_from_temp()', | ||||
self.task_keys, | self.task_keys, | ||||
) | ) | ||||
cur.execute(query) | cur.execute(query) | ||||
return cur.fetchall() | return cur.fetchall() | ||||
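    # Usage sketch (editor's example; the argument values are assumptions):
    # each task is a dict using the task_create_keys columns; 'policy' and
    # 'status' fall back to the default_columns passed to copy_to above when
    # absent:
    #
    #   backend.create_tasks([{
    #       'type': 'load-git',
    #       'arguments': {'args': [],
    #                     'kwargs': {'url': 'https://example.org/repo.git'}},
    #       'next_run': utcnow(),
    #   }])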
    @db_transaction()
    def set_status_tasks(self, task_ids, status='disabled', next_run=None,
    # [70 lines not shown, in def search_tasks(self, task_id=None, task_type=None, status=None, ...):]
            query += ' limit %s :: bigint'
            args.append(limit)

        cur.execute(query, args)
        return cur.fetchall()

    @db_transaction()
    def get_tasks(self, task_ids, db=None, cur=None):
        """Retrieve the info of tasks whose ids are listed."""
-        query = db._format_query('select {keys} from task where id in %s',
-                                 self.task_keys)
+        query = format_query('select {keys} from task where id in %s',
+                             self.task_keys)
        cur.execute(query, (tuple(task_ids),))
        return cur.fetchall()

    @db_transaction()
    def peek_ready_tasks(self, task_type, timestamp=None, num_tasks=None,
                         num_tasks_priority=None,
                         db=None, cur=None):
        """Fetch the list of ready tasks
    # [91 lines not shown, in def mass_schedule_task_runs(self, task_runs, db=None, cur=None):]
              - metadata (dict): metadata to add to the task_run entry
              - scheduled (datetime.datetime): the instant the event occurred

        Returns:
            None

        """
        cur.execute('select swh_scheduler_mktemp_task_run()')
        db.copy_to(task_runs, 'tmp_task_run', self.task_run_create_keys,
-                   cursor=cur)
+                   cur=cur)
        cur.execute('select swh_scheduler_schedule_task_run_from_temp()')

    @db_transaction()
    def start_task_run(self, backend_id, metadata=None, timestamp=None,
                       db=None, cur=None):
        """Mark a given task as started, updating the corresponding task_run
        entry in the database.
    # [remaining 99 lines not shown]