Changeset View
Changeset View
Standalone View
Standalone View
swh/vault/backend.py
Show First 20 Lines • Show All 76 Lines • ▼ Show 20 Lines | class VaultBackend: | ||||
def __init__(self, config):
    """Set up the backend from its configuration.

    Args:
        config: mapping whose 'cache' entry is expanded as keyword
            arguments to VaultCache; it is kept on the instance and
            read again by reconnect() for the 'vault_db' entry.
    """
    self.config = config
    cache_kwargs = self.config['cache']
    self.cache = VaultCache(**cache_kwargs)
    # Start without a connection: reconnect() opens the first one.
    self.db = None
    self.reconnect()
    self.smtp_server = smtplib.SMTP('localhost')
def reconnect(self):
    """Open a database connection unless a usable one is already held.

    A connection is considered usable when ``self.db`` is truthy and its
    ``closed`` attribute is falsy; in that case nothing happens.
    """
    current = self.db
    if current and not current.closed:
        return
    self.db = psycopg2.connect(
        dsn=self.config['vault_db'],
        cursor_factory=psycopg2.extras.RealDictCursor,
    )
def close(self):
    """Close the database connection currently held in ``self.db``."""
    connection = self.db
    connection.close()
def cursor(self): | def cursor(self): | ||||
"""Return a fresh cursor on the database, with auto-reconnection in | """Return a fresh cursor on the database, with auto-reconnection in | ||||
case of failure""" | case of failure""" | ||||
cur = None | cur = None | ||||
# Get a fresh cursor and reconnect at most three times | # Get a fresh cursor and reconnect at most three times | ||||
Show All 16 Lines | def commit(self): | ||||
self.db.commit() | self.db.commit() | ||||
def rollback(self):
    """Roll back the current transaction on the database connection."""
    connection = self.db
    connection.rollback()
@autocommit | @autocommit | ||||
def task_info(self, obj_type, obj_id, cursor=None):
    """Fetch the vault_bundle row describing a bundle.

    Args:
        obj_type: value matched against the ``type`` column.
        obj_id: object identifier, converted with
            ``hashutil.hash_to_bytes`` before the lookup.
        cursor: database cursor the query is executed on.

    Returns:
        The first matching row with its 'object_id' entry converted to
        ``bytes``, or the falsy result of ``fetchone()`` when nothing
        matches.
    """
    binary_id = hashutil.hash_to_bytes(obj_id)
    query = '''
        SELECT id, type, object_id, task_uuid, task_status, sticky,
               ts_created, ts_done, ts_last_access, progress_msg
        FROM vault_bundle
        WHERE type = %s AND object_id = %s'''
    cursor.execute(query, (obj_type, binary_id))
    row = cursor.fetchone()
    if not row:
        return row
    row['object_id'] = bytes(row['object_id'])
    return row
@staticmethod
def _send_task(task_uuid, args):
    """Send a cooking task to the celery scheduler.

    Declared as a static method because the signature takes no ``self``
    while callers invoke it as ``self._send_task(task_uuid, args)``;
    without the decorator the instance would be bound to ``task_uuid``
    and the call would fail with a TypeError.

    Args:
        task_uuid: identifier passed to ``apply_async`` as ``task_id``.
        args: positional arguments handed to the task.
    """
    task = get_task(cooking_task_name)
    task.apply_async(args, task_id=task_uuid)
@autocommit | @autocommit | ||||
def create_task(self, obj_type, obj_id, sticky=False, cursor=None):
    """Create and send a cooking task.

    Args:
        obj_type: bundle type; selects the cooker class through
            ``get_cooker`` and is stored in the ``type`` column.
        obj_id: object identifier, converted with
            ``hashutil.hash_to_bytes``.
        sticky: value stored in the ``sticky`` column of the new row.
        cursor: database cursor the INSERT is executed on.
    """
    binary_id = hashutil.hash_to_bytes(obj_id)
    cooker_args = [self.config, obj_type, binary_id]
    cooker = get_cooker(obj_type)(*cooker_args)
    cooker.check_exists()

    new_uuid = celery.uuid()
    row_values = (obj_type, binary_id, new_uuid, sticky)
    cursor.execute('''
        INSERT INTO vault_bundle (type, object_id, task_uuid, sticky)
        VALUES (%s, %s, %s, %s)''', row_values)
    # The row is committed before the task is dispatched (presumably so
    # the worker can already see it -- confirm against the cooker code).
    self.commit()
    self._send_task(new_uuid, cooker_args)
@autocommit | @autocommit | ||||
def add_notif_email(self, obj_type, obj_id, email, cursor=None):
    """Add an e-mail address to notify when a given bundle is ready.

    The address is stored in ``vault_notif_email`` together with the id
    of the ``vault_bundle`` row matching (obj_type, obj_id).
    """
    binary_id = hashutil.hash_to_bytes(obj_id)
    query = '''
        INSERT INTO vault_notif_email (email, bundle_id)
        VALUES (%s, (SELECT id FROM vault_bundle
                     WHERE type = %s AND object_id = %s))'''
    cursor.execute(query, (email, obj_type, binary_id))
@autocommit | @autocommit | ||||
def cook_request(self, obj_type, obj_id, *, sticky=False,
                 email=None, cursor=None):
    """Main entry point for cooking requests.

    Starts a cooking task when no bundle information exists yet and,
    when an e-mail is given, either notifies it right away (bundle
    already 'done') or adds it to the notification list.

    Returns:
        The bundle information as re-read after the steps above.
    """
    existing = self.task_info(obj_type, obj_id)
    if existing is None:
        self.create_task(obj_type, obj_id, sticky)

    if email is not None:
        already_done = (existing is not None
                        and existing['task_status'] == 'done')
        if already_done:
            # No stored notification row to delete, hence the None id.
            self.send_notification(None, email, obj_type, obj_id)
        else:
            self.add_notif_email(obj_type, obj_id, email)

    return self.task_info(obj_type, obj_id)
@autocommit | @autocommit | ||||
def is_available(self, obj_type, obj_id, cursor=None):
    """Check whether a bundle is available for retrieval.

    A bundle is available when its information exists, its task status
    is 'done' and the cache reports it as cached.
    """
    info = self.task_info(obj_type, obj_id, cursor=cursor)
    if info is None:
        return False
    if info['task_status'] != 'done':
        return False
    return self.cache.is_cached(obj_type, obj_id)
@autocommit | @autocommit | ||||
def fetch(self, obj_type, obj_id, cursor=None):
    """Retrieve a bundle from the cache.

    Returns None when the bundle is not available; otherwise updates its
    last access timestamp and returns the cached content.
    """
    if self.is_available(obj_type, obj_id, cursor=cursor):
        self.update_access_ts(obj_type, obj_id, cursor=cursor)
        return self.cache.get(obj_type, obj_id)
    return None
@autocommit | @autocommit | ||||
def update_access_ts(self, obj_type, obj_id, cursor=None):
    """Set the last access timestamp of a bundle to the current time."""
    binary_id = hashutil.hash_to_bytes(obj_id)
    query = '''
        UPDATE vault_bundle
        SET ts_last_access = NOW()
        WHERE type = %s AND object_id = %s'''
    cursor.execute(query, (obj_type, binary_id))
@autocommit | @autocommit | ||||
def set_status(self, obj_type, obj_id, status, cursor=None):
    """Set the cooking status of a bundle.

    When the new status is 'done', the ``ts_done`` column is set to the
    current time in the same statement.
    """
    binary_id = hashutil.hash_to_bytes(obj_id)
    # Only a 'done' status also stamps the completion time.
    done_clause = ''
    if status == 'done':
        done_clause = ''', ts_done = NOW() '''
    query = ''.join((
        '''
        UPDATE vault_bundle
        SET task_status = %s ''',
        done_clause,
        '''WHERE type = %s AND object_id = %s''',
    ))
    cursor.execute(query, (status, obj_type, binary_id))
@autocommit | @autocommit | ||||
def set_progress(self, obj_type, obj_id, progress, cursor=None):
    """Store the cooking progress message of a bundle."""
    binary_id = hashutil.hash_to_bytes(obj_id)
    query = '''
        UPDATE vault_bundle
        SET progress_msg = %s
        WHERE type = %s AND object_id = %s'''
    cursor.execute(query, (progress, obj_type, binary_id))
@autocommit | @autocommit | ||||
def send_all_notifications(self, obj_type, obj_id, cursor=None):
    """Send all the e-mails in the notification list of a bundle.

    Each row of ``vault_notif_email`` attached to the bundle is passed
    to send_notification() along with its id.
    """
    binary_id = hashutil.hash_to_bytes(obj_id)
    query = '''
        SELECT vault_notif_email.id AS id, email
        FROM vault_notif_email
        INNER JOIN vault_bundle ON bundle_id = vault_bundle.id
        WHERE vault_bundle.type = %s AND vault_bundle.object_id = %s'''
    cursor.execute(query, (obj_type, binary_id))
    for row in cursor:
        self.send_notification(row['id'], row['email'], obj_type, binary_id)
@autocommit | @autocommit | ||||
def send_notification(self, n_id, email, obj_type, obj_id, cursor=None): | def send_notification(self, n_id, email, obj_type, obj_id, cursor=None): | ||||
"""Send the notification of a bundle to a specific e-mail""" | |||||
hex_id = hashutil.hash_to_hex(obj_id) | hex_id = hashutil.hash_to_hex(obj_id) | ||||
short_id = hex_id[:7] | short_id = hex_id[:7] | ||||
# TODO: instead of hardcoding this, we should probably: | # TODO: instead of hardcoding this, we should probably: | ||||
# * add a "fetch_url" field in the vault_notif_email table | # * add a "fetch_url" field in the vault_notif_email table | ||||
# * generate the url with flask.url_for() on the web-ui side | # * generate the url with flask.url_for() on the web-ui side | ||||
# * send this url as part of the cook request and store it in | # * send this url as part of the cook request and store it in | ||||
# the table | # the table | ||||
Show All 10 Lines | def send_notification(self, n_id, email, obj_type, obj_id, cursor=None): | ||||
msg['To'] = email | msg['To'] = email | ||||
self.smtp_server.send_message(msg) | self.smtp_server.send_message(msg) | ||||
if n_id is not None: | if n_id is not None: | ||||
cursor.execute(''' | cursor.execute(''' | ||||
DELETE FROM vault_notif_email | DELETE FROM vault_notif_email | ||||
WHERE id = %s''', (n_id,)) | WHERE id = %s''', (n_id,)) | ||||
@autocommit | |||||
zack: please add a short docstring here explaining this is the low-level expiry method, used by the other top-level ones
def _cache_expire(self, cond, *args, cursor=None): | |||||
"""Low-level expiration method, used by cache_expire_* methods""" | |||||
# Embedded SELECT query to be able to use ORDER BY and LIMIT | |||||
cursor.execute(''' | |||||
DELETE FROM vault_bundle | |||||
WHERE ctid IN ( | |||||
SELECT ctid | |||||
FROM vault_bundle | |||||
WHERE sticky = false | |||||
{} | |||||
) | |||||
RETURNING type, object_id | |||||
'''.format(cond), args) | |||||
for d in cursor: | |||||
self.cache.delete(d['type'], bytes(d['object_id'])) | |||||
zack (inline comment, not done): How about "cache_expire_recent" here? "count" doesn't really convey the policy for picking the… Also, please add a docstring here explaining the policy behind the cache expiry done by this method.
@autocommit | |||||
def cache_expire_oldest(self, n=1, by='last_access', cursor=None):
    """Expire the `n` oldest bundles.

    Policy: among the non-sticky bundles, the rows are ordered by the
    ``ts_<by>`` timestamp and the first `n` of them are expired.

    Args:
        n: maximum number of bundles to expire.
        by: timestamp used for ordering; one of 'created', 'done' or
            'last_access'.
        cursor: database cursor, forwarded to the low-level method.

    Raises:
        ValueError: if `by` is not one of the accepted names.
    """
    # `by` ends up in the SQL text, so it must be validated with a real
    # check: an assert would be stripped when running with -O.
    if by not in ('created', 'done', 'last_access'):
        raise ValueError('invalid expiration criterion: {!r}'.format(by))
    # `n` is bound as a query parameter instead of being formatted in.
    order_clause = '''ORDER BY ts_{} LIMIT %s'''.format(by)
    return self._cache_expire(order_clause, n, cursor=cursor)
zack (inline comment, not done): Ditto: a docstring is missing here explaining the policy implemented by this expiry method.
@autocommit | |||||
def cache_expire_until(self, date, by='last_access', cursor=None):
    """Expire all the bundles until a certain date.

    Policy: every non-sticky bundle whose ``ts_<by>`` timestamp is less
    than or equal to `date` is expired.

    Args:
        date: upper bound, bound as a query parameter.
        by: timestamp compared against `date`; one of 'created', 'done'
            or 'last_access'.
        cursor: database cursor, forwarded to the low-level method.

    Raises:
        ValueError: if `by` is not one of the accepted names.
    """
    # `by` ends up in the SQL text, so it must be validated with a real
    # check: an assert would be stripped when running with -O.
    if by not in ('created', 'done', 'last_access'):
        raise ValueError('invalid expiration criterion: {!r}'.format(by))
    date_clause = '''AND ts_{} <= %s'''.format(by)
    return self._cache_expire(date_clause, date, cursor=cursor)
zack: Please add a short docstring here explaining that this is the low-level expiry method, used by the other top-level ones.