diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py
--- a/swh/scheduler/backend.py
+++ b/swh/scheduler/backend.py
@@ -3,10 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import binascii
-import datetime
 import json
-import tempfile
 import logging
 
 from arrow import Arrow, utcnow
@@ -29,70 +26,13 @@
 psycopg2.extensions.register_adapter(Arrow, adapt_arrow)
 
 
-class DbBackend(BaseDb):
-    """Base class intended to be used for scheduling db backend classes
+def format_query(query, keys):
+    """Format a query with the given keys"""
 
-    cf. swh.scheduler.backend.SchedulerBackend, and
-        swh.scheduler.updater.backend.SchedulerUpdaterBackend
+    query_keys = ', '.join(keys)
+    placeholders = ', '.join(['%s'] * len(keys))
 
-    """
-    cursor = BaseDb._cursor
-
-    def _format_query(self, query, keys):
-        """Format a query with the given keys"""
-
-        query_keys = ', '.join(keys)
-        placeholders = ', '.join(['%s'] * len(keys))
-
-        return query.format(keys=query_keys, placeholders=placeholders)
-
-    def copy_to(self, items, tblname, columns, default_columns={},
-                cursor=None, item_cb=None):
-        def escape(data):
-            if data is None:
-                return ''
-            if isinstance(data, bytes):
-                return '\\x%s' % binascii.hexlify(data).decode('ascii')
-            elif isinstance(data, str):
-                return '"%s"' % data.replace('"', '""')
-            elif isinstance(data, (datetime.datetime, Arrow)):
-                # We escape twice to make sure the string generated by
-                # isoformat gets escaped
-                return escape(data.isoformat())
-            elif isinstance(data, dict):
-                return escape(json.dumps(data))
-            elif isinstance(data, list):
-                return escape("{%s}" % ','.join(escape(d) for d in data))
-            elif isinstance(data, psycopg2.extras.Range):
-                # We escape twice here too, so that we make sure
-                # everything gets passed to copy properly
-                return escape(
-                    '%s%s,%s%s' % (
-                        '[' if data.lower_inc else '(',
-                        '-infinity' if data.lower_inf else escape(data.lower),
-                        'infinity' if data.upper_inf else escape(data.upper),
-                        ']' if data.upper_inc else ')',
-                    )
-                )
-            else:
-                # We don't escape here to make sure we pass literals properly
-                return str(data)
-        with tempfile.TemporaryFile('w+') as f:
-            for d in items:
-                if item_cb is not None:
-                    item_cb(d)
-                line = []
-                for k in columns:
-                    v = d.get(k)
-                    if not v:
-                        v = default_columns.get(k)
-                    v = escape(v)
-                    line.append(v)
-                f.write(','.join(line))
-                f.write('\n')
-            f.seek(0)
-            cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % (
-                tblname, ', '.join(columns)), f)
+    return query.format(keys=query_keys, placeholders=placeholders)
 
 
 class SchedulerBackend:
@@ -108,7 +48,7 @@
         """
         if isinstance(db, psycopg2.extensions.connection):
             self._pool = None
-            self._db = DbBackend(db)
+            self._db = BaseDb(db)
         else:
             self._pool = psycopg2.pool.ThreadedConnectionPool(
                 min_pool_conns, max_pool_conns, db,
@@ -119,7 +59,7 @@
     def get_db(self):
         if self._db:
             return self._db
-        return DbBackend.from_pool(self._pool)
+        return BaseDb.from_pool(self._pool)
 
     task_type_keys = [
         'type', 'description', 'backend_name', 'default_interval',
@@ -152,7 +92,7 @@
         """
         keys = [key for key in self.task_type_keys if key in task_type]
 
-        query = db._format_query(
+        query = format_query(
            """insert into task_type ({keys}) values ({placeholders})""",
            keys)
        cur.execute(query, [task_type[key] for key in keys])
@@ -160,7 +100,7 @@
     @db_transaction()
     def get_task_type(self, task_type_name, db=None, cur=None):
         """Retrieve the task type with id task_type_name"""
-        query = db._format_query(
+        query = format_query(
             "select {keys} from task_type where type=%s",
             self.task_type_keys,
         )
@@ -170,7 +110,7 @@
     @db_transaction()
     def get_task_types(self, db=None, cur=None):
         """Retrieve all registered task types"""
-        query = db._format_query(
+        query = format_query(
             "select {keys} from task_type",
             self.task_type_keys,
         )
@@ -209,8 +149,8 @@
                        'policy': policy,
                        'status': 'next_run_not_scheduled'
                    },
-                   cursor=cur)
-        query = db._format_query(
+                   cur=cur)
+        query = format_query(
             'select {keys} from swh_scheduler_create_tasks_from_temp()',
             self.task_keys,
         )
@@ -297,8 +237,8 @@
     @db_transaction()
     def get_tasks(self, task_ids, db=None, cur=None):
         """Retrieve the info of tasks whose ids are listed."""
-        query = db._format_query('select {keys} from task where id in %s',
-                                 self.task_keys)
+        query = format_query('select {keys} from task where id in %s',
+                             self.task_keys)
         cur.execute(query, (tuple(task_ids),))
         return cur.fetchall()
 
@@ -406,7 +346,7 @@
         """
         cur.execute('select swh_scheduler_mktemp_task_run()')
         db.copy_to(task_runs, 'tmp_task_run', self.task_run_create_keys,
-                   cursor=cur)
+                   cur=cur)
         cur.execute('select swh_scheduler_schedule_task_run_from_temp()')
 
     @db_transaction()
diff --git a/swh/scheduler/updater/backend.py b/swh/scheduler/updater/backend.py
--- a/swh/scheduler/updater/backend.py
+++ b/swh/scheduler/updater/backend.py
@@ -7,8 +7,10 @@
 from arrow import utcnow
 import psycopg2.pool
 import psycopg2.extras
-from swh.scheduler.backend import DbBackend
+
+from swh.core.db import BaseDb
 from swh.core.db.common import db_transaction, db_transaction_generator
+from swh.scheduler.backend import format_query
 
 
 class SchedulerUpdaterBackend:
@@ -24,7 +26,7 @@
         """
         if isinstance(db, psycopg2.extensions.connection):
             self._pool = None
-            self._db = DbBackend(db)
+            self._db = BaseDb(db)
         else:
             self._pool = psycopg2.pool.ThreadedConnectionPool(
                 min_pool_conns, max_pool_conns, db,
@@ -36,7 +38,7 @@
     def get_db(self):
         if self._db:
             return self._db
-        return DbBackend.from_pool(self._pool)
+        return BaseDb.from_pool(self._pool)
 
     cache_put_keys = ['url', 'cnt', 'last_seen', 'origin_type']
 
@@ -56,8 +58,8 @@
                 event['last_seen'] = timestamp
                 yield event
         cur.execute('select swh_mktemp_cache()')
-        db.copy_to(prepare_events(events),
-                   'tmp_cache', self.cache_put_keys, cursor=cur)
+        db.copy_to(prepare_events(events, timestamp),
+                   'tmp_cache', self.cache_put_keys, cur=cur)
         cur.execute('select swh_cache_put()')
 
     cache_read_keys = ['id', 'url', 'origin_type', 'cnt', 'first_seen',
@@ -74,8 +76,8 @@
         if not limit:
             limit = self.limit
 
-        q = db._format_query('select {keys} from swh_cache_read(%s, %s)',
-                             self.cache_read_keys)
+        q = format_query('select {keys} from swh_cache_read(%s, %s)',
+                         self.cache_read_keys)
         cur.execute(q, (timestamp, limit))
         for r in cur.fetchall():
             r['id'] = r['id'].tobytes()
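
Note for reviewers: the extracted `format_query` helper only builds the column list and the matching `%s` placeholders; the actual values are still passed separately to `cur.execute`, so parameter quoting remains psycopg2's job. A minimal, self-contained sketch of the behaviour, copied from the hunk above (the key list in the example is made up for illustration and is not part of the patch):

    def format_query(query, keys):
        """Format a query with the given keys"""
        # 'type, description, backend_name' for the example keys below
        query_keys = ', '.join(keys)
        # one '%s' placeholder per key: '%s, %s, %s'
        placeholders = ', '.join(['%s'] * len(keys))
        return query.format(keys=query_keys, placeholders=placeholders)

    # Hypothetical usage mirroring create_task_type:
    print(format_query(
        'insert into task_type ({keys}) values ({placeholders})',
        ['type', 'description', 'backend_name']))
    # -> insert into task_type (type, description, backend_name) values (%s, %s, %s)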