Page Menu | Home | Software Heritage

D1034.id3290.diff
No One | Temporary

D1034.id3290.diff

diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py
--- a/swh/scheduler/backend.py
+++ b/swh/scheduler/backend.py
@@ -3,10 +3,7 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-import binascii
-import datetime
import json
-import tempfile
import logging
from arrow import Arrow, utcnow
@@ -29,70 +26,13 @@
psycopg2.extensions.register_adapter(Arrow, adapt_arrow)
-class DbBackend(BaseDb):
- """Base class intended to be used for scheduling db backend classes
+def format_query(query, keys):
+ """Format a query with the given keys"""
- cf. swh.scheduler.backend.SchedulerBackend, and
- swh.scheduler.updater.backend.SchedulerUpdaterBackend
+ query_keys = ', '.join(keys)
+ placeholders = ', '.join(['%s'] * len(keys))
- """
- cursor = BaseDb._cursor
-
- def _format_query(self, query, keys):
- """Format a query with the given keys"""
-
- query_keys = ', '.join(keys)
- placeholders = ', '.join(['%s'] * len(keys))
-
- return query.format(keys=query_keys, placeholders=placeholders)
-
- def copy_to(self, items, tblname, columns, default_columns={},
- cursor=None, item_cb=None):
- def escape(data):
- if data is None:
- return ''
- if isinstance(data, bytes):
- return '\\x%s' % binascii.hexlify(data).decode('ascii')
- elif isinstance(data, str):
- return '"%s"' % data.replace('"', '""')
- elif isinstance(data, (datetime.datetime, Arrow)):
- # We escape twice to make sure the string generated by
- # isoformat gets escaped
- return escape(data.isoformat())
- elif isinstance(data, dict):
- return escape(json.dumps(data))
- elif isinstance(data, list):
- return escape("{%s}" % ','.join(escape(d) for d in data))
- elif isinstance(data, psycopg2.extras.Range):
- # We escape twice here too, so that we make sure
- # everything gets passed to copy properly
- return escape(
- '%s%s,%s%s' % (
- '[' if data.lower_inc else '(',
- '-infinity' if data.lower_inf else escape(data.lower),
- 'infinity' if data.upper_inf else escape(data.upper),
- ']' if data.upper_inc else ')',
- )
- )
- else:
- # We don't escape here to make sure we pass literals properly
- return str(data)
- with tempfile.TemporaryFile('w+') as f:
- for d in items:
- if item_cb is not None:
- item_cb(d)
- line = []
- for k in columns:
- v = d.get(k)
- if not v:
- v = default_columns.get(k)
- v = escape(v)
- line.append(v)
- f.write(','.join(line))
- f.write('\n')
- f.seek(0)
- cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % (
- tblname, ', '.join(columns)), f)
+ return query.format(keys=query_keys, placeholders=placeholders)
class SchedulerBackend:
@@ -108,7 +48,7 @@
"""
if isinstance(db, psycopg2.extensions.connection):
self._pool = None
- self._db = DbBackend(db)
+ self._db = BaseDb(db)
else:
self._pool = psycopg2.pool.ThreadedConnectionPool(
min_pool_conns, max_pool_conns, db,
@@ -119,7 +59,7 @@
def get_db(self):
if self._db:
return self._db
- return DbBackend.from_pool(self._pool)
+ return BaseDb.from_pool(self._pool)
task_type_keys = [
'type', 'description', 'backend_name', 'default_interval',
@@ -152,7 +92,7 @@
"""
keys = [key for key in self.task_type_keys if key in task_type]
- query = db._format_query(
+ query = format_query(
"""insert into task_type ({keys}) values ({placeholders})""",
keys)
cur.execute(query, [task_type[key] for key in keys])
@@ -160,7 +100,7 @@
@db_transaction()
def get_task_type(self, task_type_name, db=None, cur=None):
"""Retrieve the task type with id task_type_name"""
- query = db._format_query(
+ query = format_query(
"select {keys} from task_type where type=%s",
self.task_type_keys,
)
@@ -170,7 +110,7 @@
@db_transaction()
def get_task_types(self, db=None, cur=None):
"""Retrieve all registered task types"""
- query = db._format_query(
+ query = format_query(
"select {keys} from task_type",
self.task_type_keys,
)
@@ -209,8 +149,8 @@
'policy': policy,
'status': 'next_run_not_scheduled'
},
- cursor=cur)
- query = db._format_query(
+ cur=cur)
+ query = format_query(
'select {keys} from swh_scheduler_create_tasks_from_temp()',
self.task_keys,
)
@@ -297,8 +237,8 @@
@db_transaction()
def get_tasks(self, task_ids, db=None, cur=None):
"""Retrieve the info of tasks whose ids are listed."""
- query = db._format_query('select {keys} from task where id in %s',
- self.task_keys)
+ query = format_query('select {keys} from task where id in %s',
+ self.task_keys)
cur.execute(query, (tuple(task_ids),))
return cur.fetchall()
@@ -406,7 +346,7 @@
"""
cur.execute('select swh_scheduler_mktemp_task_run()')
db.copy_to(task_runs, 'tmp_task_run', self.task_run_create_keys,
- cursor=cur)
+ cur=cur)
cur.execute('select swh_scheduler_schedule_task_run_from_temp()')
@db_transaction()
diff --git a/swh/scheduler/updater/backend.py b/swh/scheduler/updater/backend.py
--- a/swh/scheduler/updater/backend.py
+++ b/swh/scheduler/updater/backend.py
@@ -7,8 +7,10 @@
from arrow import utcnow
import psycopg2.pool
import psycopg2.extras
-from swh.scheduler.backend import DbBackend
+
+from swh.core.db import BaseDb
from swh.core.db.common import db_transaction, db_transaction_generator
+from swh.scheduler.backend import format_query
class SchedulerUpdaterBackend:
@@ -24,7 +26,7 @@
"""
if isinstance(db, psycopg2.extensions.connection):
self._pool = None
- self._db = DbBackend(db)
+ self._db = BaseDb(db)
else:
self._pool = psycopg2.pool.ThreadedConnectionPool(
min_pool_conns, max_pool_conns, db,
@@ -36,7 +38,7 @@
def get_db(self):
if self._db:
return self._db
- return DbBackend.from_pool(self._pool)
+ return BaseDb.from_pool(self._pool)
cache_put_keys = ['url', 'cnt', 'last_seen', 'origin_type']
@@ -56,8 +58,8 @@
event['last_seen'] = timestamp
yield event
cur.execute('select swh_mktemp_cache()')
- db.copy_to(prepare_events(events),
- 'tmp_cache', self.cache_put_keys, cursor=cur)
+ db.copy_to(prepare_events(events, timestamp),
+ 'tmp_cache', self.cache_put_keys, cur=cur)
cur.execute('select swh_cache_put()')
cache_read_keys = ['id', 'url', 'origin_type', 'cnt', 'first_seen',
@@ -74,8 +76,8 @@
if not limit:
limit = self.limit
- q = db._format_query('select {keys} from swh_cache_read(%s, %s)',
- self.cache_read_keys)
+ q = format_query('select {keys} from swh_cache_read(%s, %s)',
+ self.cache_read_keys)
cur.execute(q, (timestamp, limit))
for r in cur.fetchall():
r['id'] = r['id'].tobytes()

File Metadata

Mime Type
text/plain
Expires
Thu, Dec 19, 7:03 AM (9 h, 42 m ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3217373

Event Timeline