diff --git a/swh/storage/common.py b/swh/storage/common.py index 11eddf8d..e32ba824 100644 --- a/swh/storage/common.py +++ b/swh/storage/common.py @@ -1,80 +1,6 @@ # Copyright (C) 2015-2016 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import inspect -import functools - - -def apply_options(cursor, options): - """Applies the given postgresql client options to the given cursor. - - Returns a dictionary with the old values if they changed.""" - old_options = {} - for option, value in options.items(): - cursor.execute('SHOW %s' % option) - old_value = cursor.fetchall()[0][0] - if old_value != value: - cursor.execute('SET LOCAL %s TO %%s' % option, (value,)) - old_options[option] = old_value - return old_options - - -def db_transaction(**client_options): - """decorator to execute Storage methods within DB transactions - - The decorated method must accept a `cur` and `db` keyword argument - - Client options are passed as `set` options to the postgresql server - """ - def decorator(meth, __client_options=client_options): - if inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction_generator for generator functions.') - - @functools.wraps(meth) - def _meth(self, *args, **kwargs): - if 'cur' in kwargs and kwargs['cur']: - cur = kwargs['cur'] - old_options = apply_options(cur, __client_options) - ret = meth(self, *args, **kwargs) - apply_options(cur, old_options) - return ret - else: - db = self.get_db() - with db.transaction() as cur: - apply_options(cur, __client_options) - return meth(self, *args, db=db, cur=cur, **kwargs) - return _meth - - return decorator - - -def db_transaction_generator(**client_options): - """decorator to execute Storage methods within DB transactions, while - returning a generator - - The decorated method must accept a `cur` and `db` keyword argument - - Client options are passed as `set` options to the postgresql server - """ - def decorator(meth, __client_options=client_options): - if not inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction for non-generator functions.') - - @functools.wraps(meth) - def _meth(self, *args, **kwargs): - if 'cur' in kwargs and kwargs['cur']: - cur = kwargs['cur'] - old_options = apply_options(cur, __client_options) - yield from meth(self, *args, **kwargs) - apply_options(cur, old_options) - else: - db = self.get_db() - with db.transaction() as cur: - apply_options(cur, __client_options) - yield from meth(self, *args, db=db, cur=cur, **kwargs) - return _meth - return decorator +from swh.core.db.common import * # noqa diff --git a/swh/storage/db.py b/swh/storage/db.py index f24b0a0a..8022ad23 100644 --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -1,1059 +1,848 @@ -# Copyright (C) 2015-2018 The Software Heritage developers +# Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import binascii -import datetime -import enum -import functools -import json -import os import select -import threading -from contextlib import contextmanager - -import psycopg2 -import psycopg2.extras - -from .db_utils import execute_values_generator - -TMP_CONTENT_TABLE = 'tmp_content' - - -psycopg2.extras.register_uuid() - - -def stored_procedure(stored_proc): - """decorator to execute remote stored procedure, specified as argument - - Generally, the body of the decorated function should be empty. If it is - not, the stored procedure will be executed first; the function body then. - - """ - def wrap(meth): - @functools.wraps(meth) - def _meth(self, *args, **kwargs): - cur = kwargs.get('cur', None) - self._cursor(cur).execute('SELECT %s()' % stored_proc) - meth(self, *args, **kwargs) - return _meth - return wrap - - -def jsonize(value): - """Convert a value to a psycopg2 JSON object if necessary""" - if isinstance(value, dict): - return psycopg2.extras.Json(value) - - return value - - -def entry_to_bytes(entry): - """Convert an entry coming from the database to bytes""" - if isinstance(entry, memoryview): - return entry.tobytes() - if isinstance(entry, list): - return [entry_to_bytes(value) for value in entry] - return entry - - -def line_to_bytes(line): - """Convert a line coming from the database to bytes""" - if not line: - return line - if isinstance(line, dict): - return {k: entry_to_bytes(v) for k, v in line.items()} - return line.__class__(entry_to_bytes(entry) for entry in line) - - -def cursor_to_bytes(cursor): - """Yield all the data from a cursor as bytes""" - yield from (line_to_bytes(line) for line in cursor) - - -def execute_values_to_bytes(*args, **kwargs): - for line in execute_values_generator(*args, **kwargs): - yield line_to_bytes(line) - - -class BaseDb: - """Base class for swh.storage.*Db. - - cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb - - """ - - @classmethod - def connect(cls, *args, **kwargs): - """factory method to create a DB proxy - - Accepts all arguments of psycopg2.connect; only some specific - possibilities are reported below. - - Args: - connstring: libpq2 connection string - - """ - conn = psycopg2.connect(*args, **kwargs) - return cls(conn) - - @classmethod - def from_pool(cls, pool): - return cls(pool.getconn(), pool=pool) - - def _cursor(self, cur_arg): - """get a cursor: from cur_arg if given, or a fresh one otherwise - - meant to avoid boilerplate if/then/else in methods that proxy stored - procedures - - """ - if cur_arg is not None: - return cur_arg - # elif self.cur is not None: - # return self.cur - else: - return self.conn.cursor() - - def __init__(self, conn, pool=None): - """create a DB proxy - - Args: - conn: psycopg2 connection to the SWH DB - pool: psycopg2 pool of connections - - """ - self.conn = conn - self.pool = pool - - def __del__(self): - if self.pool: - self.pool.putconn(self.conn) - - @contextmanager - def transaction(self): - """context manager to execute within a DB transaction - - Yields: - a psycopg2 cursor - - """ - with self.conn.cursor() as cur: - try: - yield cur - self.conn.commit() - except Exception: - if not self.conn.closed: - self.conn.rollback() - raise - - def copy_to(self, items, tblname, columns, cur=None, item_cb=None): - """Copy items' entries to table tblname with columns information. - - Args: - items (dict): dictionary of data to copy over tblname - tblname (str): Destination table's name - columns ([str]): keys to access data in items and also the - column names in the destination table. - item_cb (fn): optional function to apply to items's entry - - """ - def escape(data): - if data is None: - return '' - if isinstance(data, bytes): - return '\\x%s' % binascii.hexlify(data).decode('ascii') - elif isinstance(data, str): - return '"%s"' % data.replace('"', '""') - elif isinstance(data, datetime.datetime): - # We escape twice to make sure the string generated by - # isoformat gets escaped - return escape(data.isoformat()) - elif isinstance(data, dict): - return escape(json.dumps(data)) - elif isinstance(data, list): - return escape("{%s}" % ','.join(escape(d) for d in data)) - elif isinstance(data, psycopg2.extras.Range): - # We escape twice here too, so that we make sure - # everything gets passed to copy properly - return escape( - '%s%s,%s%s' % ( - '[' if data.lower_inc else '(', - '-infinity' if data.lower_inf else escape(data.lower), - 'infinity' if data.upper_inf else escape(data.upper), - ']' if data.upper_inc else ')', - ) - ) - elif isinstance(data, enum.IntEnum): - return escape(int(data)) - else: - # We don't escape here to make sure we pass literals properly - return str(data) - - read_file, write_file = os.pipe() - - def writer(): - cursor = self._cursor(cur) - with open(read_file, 'r') as f: - cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % ( - tblname, ', '.join(columns)), f) - - write_thread = threading.Thread(target=writer) - write_thread.start() - - try: - with open(write_file, 'w') as f: - for d in items: - if item_cb is not None: - item_cb(d) - line = [escape(d.get(k)) for k in columns] - f.write(','.join(line)) - f.write('\n') - finally: - # No problem bubbling up exceptions, but we still need to make sure - # we finish copying, even though we're probably going to cancel the - # transaction. - write_thread.join() - - def mktemp(self, tblname, cur=None): - self._cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,)) +from swh.core.db import BaseDb +from swh.core.db.db_utils import stored_procedure, execute_values_to_bytes +from swh.core.db.db_utils import cursor_to_bytes, line_to_bytes, jsonize class Db(BaseDb): """Proxy to the SWH DB, with wrappers around stored procedures """ def mktemp_dir_entry(self, entry_type, cur=None): self._cursor(cur).execute('SELECT swh_mktemp_dir_entry(%s)', (('directory_entry_%s' % entry_type),)) @stored_procedure('swh_mktemp_revision') def mktemp_revision(self, cur=None): pass @stored_procedure('swh_mktemp_release') def mktemp_release(self, cur=None): pass @stored_procedure('swh_mktemp_snapshot_branch') def mktemp_snapshot_branch(self, cur=None): pass def register_listener(self, notify_queue, cur=None): """Register a listener for NOTIFY queue `notify_queue`""" self._cursor(cur).execute("LISTEN %s" % notify_queue) def listen_notifies(self, timeout): """Listen to notifications for `timeout` seconds""" if select.select([self.conn], [], [], timeout) == ([], [], []): return else: self.conn.poll() while self.conn.notifies: yield self.conn.notifies.pop(0) @stored_procedure('swh_content_add') def content_add_from_temp(self, cur=None): pass @stored_procedure('swh_directory_add') def directory_add_from_temp(self, cur=None): pass @stored_procedure('swh_skipped_content_add') def skipped_content_add_from_temp(self, cur=None): pass @stored_procedure('swh_revision_add') def revision_add_from_temp(self, cur=None): pass @stored_procedure('swh_release_add') def release_add_from_temp(self, cur=None): pass def content_update_from_temp(self, keys_to_update, cur=None): cur = self._cursor(cur) cur.execute("""select swh_content_update(ARRAY[%s] :: text[])""" % keys_to_update) content_get_metadata_keys = [ 'sha1', 'sha1_git', 'sha256', 'blake2s256', 'length', 'status'] skipped_content_keys = [ 'sha1', 'sha1_git', 'sha256', 'blake2s256', 'length', 'reason', 'status', 'origin'] def content_get_metadata_from_sha1s(self, sha1s, cur=None): cur = self._cursor(cur) yield from execute_values_to_bytes( cur, """ select t.sha1, %s from (values %%s) as t (sha1) left join content using (sha1) """ % ', '.join(self.content_get_metadata_keys[1:]), ((sha1,) for sha1 in sha1s), ) def content_get_range(self, start, end, limit=None, cur=None): """Retrieve contents within range [start, end]. """ cur = self._cursor(cur) query = """select %s from content where %%s <= sha1 and sha1 <= %%s order by sha1 limit %%s""" % ', '.join(self.content_get_metadata_keys) cur.execute(query, (start, end, limit)) yield from cursor_to_bytes(cur) content_hash_keys = ['sha1', 'sha1_git', 'sha256', 'blake2s256'] def content_missing_from_list(self, contents, cur=None): cur = self._cursor(cur) keys = ', '.join(self.content_hash_keys) equality = ' AND '.join( ('t.%s = c.%s' % (key, key)) for key in self.content_hash_keys ) yield from execute_values_to_bytes( cur, """ SELECT %s FROM (VALUES %%s) as t(%s) WHERE NOT EXISTS ( SELECT 1 FROM content c WHERE %s ) """ % (keys, keys, equality), (tuple(c[key] for key in self.content_hash_keys) for c in contents) ) def content_missing_per_sha1(self, sha1s, cur=None): cur = self._cursor(cur) yield from execute_values_to_bytes(cur, """ SELECT t.sha1 FROM (VALUES %s) AS t(sha1) WHERE NOT EXISTS ( SELECT 1 FROM content c WHERE c.sha1 = t.sha1 )""", ((sha1,) for sha1 in sha1s)) def skipped_content_missing_from_temp(self, cur=None): cur = self._cursor(cur) cur.execute("""SELECT sha1, sha1_git, sha256, blake2s256 FROM swh_skipped_content_missing()""") yield from cursor_to_bytes(cur) def snapshot_exists(self, snapshot_id, cur=None): """Check whether a snapshot with the given id exists""" cur = self._cursor(cur) cur.execute("""SELECT 1 FROM snapshot where id=%s""", (snapshot_id,)) return bool(cur.fetchone()) def snapshot_add(self, origin, visit, snapshot_id, cur=None): """Add a snapshot for origin/visit from the temporary table""" cur = self._cursor(cur) cur.execute("""SELECT swh_snapshot_add(%s, %s, %s)""", (origin, visit, snapshot_id)) snapshot_count_cols = ['target_type', 'count'] def snapshot_count_branches(self, snapshot_id, cur=None): cur = self._cursor(cur) query = """\ SELECT %s FROM swh_snapshot_count_branches(%%s) """ % ', '.join(self.snapshot_count_cols) cur.execute(query, (snapshot_id,)) yield from cursor_to_bytes(cur) snapshot_get_cols = ['snapshot_id', 'name', 'target', 'target_type'] def snapshot_get_by_id(self, snapshot_id, branches_from=b'', branches_count=None, target_types=None, cur=None): cur = self._cursor(cur) query = """\ SELECT %s FROM swh_snapshot_get_by_id(%%s, %%s, %%s, %%s :: snapshot_target[]) """ % ', '.join(self.snapshot_get_cols) cur.execute(query, (snapshot_id, branches_from, branches_count, target_types)) yield from cursor_to_bytes(cur) def snapshot_get_by_origin_visit(self, origin_id, visit_id, cur=None): cur = self._cursor(cur) query = """\ SELECT swh_snapshot_get_by_origin_visit(%s, %s) """ cur.execute(query, (origin_id, visit_id)) ret = cur.fetchone() if ret: return line_to_bytes(ret)[0] content_find_cols = ['sha1', 'sha1_git', 'sha256', 'blake2s256', 'length', 'ctime', 'status'] def content_find(self, sha1=None, sha1_git=None, sha256=None, blake2s256=None, cur=None): """Find the content optionally on a combination of the following checksums sha1, sha1_git, sha256 or blake2s256. Args: sha1: sha1 content git_sha1: the sha1 computed `a la git` sha1 of the content sha256: sha256 content blake2s256: blake2s256 content Returns: The tuple (sha1, sha1_git, sha256, blake2s256) if found or None. """ cur = self._cursor(cur) cur.execute("""SELECT %s FROM swh_content_find(%%s, %%s, %%s, %%s) LIMIT 1""" % ','.join(self.content_find_cols), (sha1, sha1_git, sha256, blake2s256)) content = line_to_bytes(cur.fetchone()) if set(content) == {None}: return None else: return content def directory_missing_from_list(self, directories, cur=None): cur = self._cursor(cur) yield from execute_values_to_bytes( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( SELECT 1 FROM directory d WHERE d.id = t.id ) """, ((id,) for id in directories)) directory_ls_cols = ['dir_id', 'type', 'target', 'name', 'perms', 'status', 'sha1', 'sha1_git', 'sha256', 'length'] def directory_walk_one(self, directory, cur=None): cur = self._cursor(cur) cols = ', '.join(self.directory_ls_cols) query = 'SELECT %s FROM swh_directory_walk_one(%%s)' % cols cur.execute(query, (directory,)) yield from cursor_to_bytes(cur) def directory_walk(self, directory, cur=None): cur = self._cursor(cur) cols = ', '.join(self.directory_ls_cols) query = 'SELECT %s FROM swh_directory_walk(%%s)' % cols cur.execute(query, (directory,)) yield from cursor_to_bytes(cur) def directory_entry_get_by_path(self, directory, paths, cur=None): """Retrieve a directory entry by path. """ cur = self._cursor(cur) cols = ', '.join(self.directory_ls_cols) query = ( 'SELECT %s FROM swh_find_directory_entry_by_path(%%s, %%s)' % cols) cur.execute(query, (directory, paths)) data = cur.fetchone() if set(data) == {None}: return None return line_to_bytes(data) def revision_missing_from_list(self, revisions, cur=None): cur = self._cursor(cur) yield from execute_values_to_bytes( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( SELECT 1 FROM revision r WHERE r.id = t.id ) """, ((id,) for id in revisions)) revision_add_cols = [ 'id', 'date', 'date_offset', 'date_neg_utc_offset', 'committer_date', 'committer_date_offset', 'committer_date_neg_utc_offset', 'type', 'directory', 'message', 'author_fullname', 'author_name', 'author_email', 'committer_fullname', 'committer_name', 'committer_email', 'metadata', 'synthetic', ] revision_get_cols = revision_add_cols + [ 'author_id', 'committer_id', 'parents'] def origin_visit_add(self, origin, ts, cur=None): """Add a new origin_visit for origin origin at timestamp ts with status 'ongoing'. Args: origin: origin concerned by the visit ts: the date of the visit Returns: The new visit index step for that origin """ cur = self._cursor(cur) self._cursor(cur).execute('SELECT swh_origin_visit_add(%s, %s)', (origin, ts)) return cur.fetchone()[0] def origin_visit_update(self, origin, visit_id, status, metadata, cur=None): """Update origin_visit's status.""" cur = self._cursor(cur) update = """UPDATE origin_visit SET status=%s, metadata=%s WHERE origin=%s AND visit=%s""" cur.execute(update, (status, jsonize(metadata), origin, visit_id)) origin_visit_get_cols = ['origin', 'visit', 'date', 'status', 'metadata', 'snapshot'] def origin_visit_get_all(self, origin_id, last_visit=None, limit=None, cur=None): """Retrieve all visits for origin with id origin_id. Args: origin_id: The occurrence's origin Yields: The occurrence's history visits """ cur = self._cursor(cur) if last_visit: extra_condition = 'and visit > %s' args = (origin_id, last_visit, limit) else: extra_condition = '' args = (origin_id, limit) query = """\ SELECT %s, (select id from snapshot where object_id = snapshot_id) as snapshot FROM origin_visit WHERE origin=%%s %s order by visit asc limit %%s""" % ( ', '.join(self.origin_visit_get_cols[:-1]), extra_condition ) cur.execute(query, args) yield from cursor_to_bytes(cur) def origin_visit_get(self, origin_id, visit_id, cur=None): """Retrieve information on visit visit_id of origin origin_id. Args: origin_id: the origin concerned visit_id: The visit step for that origin Returns: The origin_visit information """ cur = self._cursor(cur) query = """\ SELECT %s, (select id from snapshot where object_id = snapshot_id) as snapshot FROM origin_visit WHERE origin = %%s AND visit = %%s """ % (', '.join(self.origin_visit_get_cols[:-1])) cur.execute(query, (origin_id, visit_id)) r = cur.fetchall() if not r: return None return line_to_bytes(r[0]) def origin_visit_exists(self, origin_id, visit_id, cur=None): """Check whether an origin visit with the given ids exists""" cur = self._cursor(cur) query = "SELECT 1 FROM origin_visit where origin = %s AND visit = %s" cur.execute(query, (origin_id, visit_id)) return bool(cur.fetchone()) def origin_visit_get_latest_snapshot(self, origin_id, allowed_statuses=None, cur=None): """Retrieve the most recent origin_visit which references a snapshot Args: origin_id: the origin concerned allowed_statuses: the visit statuses allowed for the returned visit Returns: The origin_visit information, or None if no visit matches. """ cur = self._cursor(cur) extra_clause = "" if allowed_statuses: extra_clause = cur.mogrify("AND status IN %s", (tuple(allowed_statuses),)).decode() query = """\ SELECT %s, (select id from snapshot where object_id = snapshot_id) as snapshot FROM origin_visit WHERE origin = %%s AND snapshot_id is not null %s ORDER BY date DESC, visit DESC LIMIT 1 """ % (', '.join(self.origin_visit_get_cols[:-1]), extra_clause) cur.execute(query, (origin_id,)) r = cur.fetchone() if not r: return None return line_to_bytes(r) @staticmethod def mangle_query_key(key, main_table): if key == 'id': return 't.id' if key == 'parents': return ''' ARRAY( SELECT rh.parent_id::bytea FROM revision_history rh WHERE rh.id = t.id ORDER BY rh.parent_rank )''' if '_' not in key: return '%s.%s' % (main_table, key) head, tail = key.split('_', 1) if (head in ('author', 'committer') and tail in ('name', 'email', 'id', 'fullname')): return '%s.%s' % (head, tail) return '%s.%s' % (main_table, key) def revision_get_from_list(self, revisions, cur=None): cur = self._cursor(cur) query_keys = ', '.join( self.mangle_query_key(k, 'revision') for k in self.revision_get_cols ) yield from execute_values_to_bytes( cur, """ SELECT %s FROM (VALUES %%s) as t(id) LEFT JOIN revision ON t.id = revision.id LEFT JOIN person author ON revision.author = author.id LEFT JOIN person committer ON revision.committer = committer.id """ % query_keys, ((id,) for id in revisions)) def revision_log(self, root_revisions, limit=None, cur=None): cur = self._cursor(cur) query = """SELECT %s FROM swh_revision_log(%%s, %%s) """ % ', '.join(self.revision_get_cols) cur.execute(query, (root_revisions, limit)) yield from cursor_to_bytes(cur) revision_shortlog_cols = ['id', 'parents'] def revision_shortlog(self, root_revisions, limit=None, cur=None): cur = self._cursor(cur) query = """SELECT %s FROM swh_revision_list(%%s, %%s) """ % ', '.join(self.revision_shortlog_cols) cur.execute(query, (root_revisions, limit)) yield from cursor_to_bytes(cur) def release_missing_from_list(self, releases, cur=None): cur = self._cursor(cur) yield from execute_values_to_bytes( cur, """ SELECT id FROM (VALUES %s) as t(id) WHERE NOT EXISTS ( SELECT 1 FROM release r WHERE r.id = t.id ) """, ((id,) for id in releases)) object_find_by_sha1_git_cols = ['sha1_git', 'type', 'id', 'object_id'] def object_find_by_sha1_git(self, ids, cur=None): cur = self._cursor(cur) yield from execute_values_to_bytes( cur, """ WITH t (id) AS (VALUES %s), known_objects as (( select id as sha1_git, 'release'::object_type as type, id, object_id from release r where exists (select 1 from t where t.id = r.id) ) union all ( select id as sha1_git, 'revision'::object_type as type, id, object_id from revision r where exists (select 1 from t where t.id = r.id) ) union all ( select id as sha1_git, 'directory'::object_type as type, id, object_id from directory d where exists (select 1 from t where t.id = d.id) ) union all ( select sha1_git as sha1_git, 'content'::object_type as type, sha1 as id, object_id from content c where exists (select 1 from t where t.id = c.sha1_git) )) select t.id as sha1_git, k.type, k.id, k.object_id from t left join known_objects k on t.id = k.sha1_git """, ((id,) for id in ids) ) def stat_counters(self, cur=None): cur = self._cursor(cur) cur.execute('SELECT * FROM swh_stat_counters()') yield from cur fetch_history_cols = ['origin', 'date', 'status', 'result', 'stdout', 'stderr', 'duration'] def create_fetch_history(self, fetch_history, cur=None): """Create a fetch_history entry with the data in fetch_history""" cur = self._cursor(cur) query = '''INSERT INTO fetch_history (%s) VALUES (%s) RETURNING id''' % ( ','.join(self.fetch_history_cols), ','.join(['%s'] * len(self.fetch_history_cols)) ) cur.execute(query, [fetch_history.get(col) for col in self.fetch_history_cols]) return cur.fetchone()[0] def get_fetch_history(self, fetch_history_id, cur=None): """Get a fetch_history entry with the given id""" cur = self._cursor(cur) query = '''SELECT %s FROM fetch_history WHERE id=%%s''' % ( ', '.join(self.fetch_history_cols), ) cur.execute(query, (fetch_history_id,)) data = cur.fetchone() if not data: return None ret = {'id': fetch_history_id} for i, col in enumerate(self.fetch_history_cols): ret[col] = data[i] return ret def update_fetch_history(self, fetch_history, cur=None): """Update the fetch_history entry from the data in fetch_history""" cur = self._cursor(cur) query = '''UPDATE fetch_history SET %s WHERE id=%%s''' % ( ','.join('%s=%%s' % col for col in self.fetch_history_cols) ) cur.execute(query, [jsonize(fetch_history.get(col)) for col in self.fetch_history_cols + ['id']]) def origin_add(self, type, url, cur=None): """Insert a new origin and return the new identifier.""" insert = """INSERT INTO origin (type, url) values (%s, %s) RETURNING id""" cur.execute(insert, (type, url)) return cur.fetchone()[0] origin_cols = ['id', 'type', 'url'] def origin_get_with(self, type, url, cur=None): """Retrieve the origin id from its type and url if found.""" cur = self._cursor(cur) query = """SELECT %s FROM origin WHERE type=%%s AND url=%%s """ % ','.join(self.origin_cols) cur.execute(query, (type, url)) data = cur.fetchone() if data: return line_to_bytes(data) return None def origin_get(self, id, cur=None): """Retrieve the origin per its identifier. """ cur = self._cursor(cur) query = """SELECT %s FROM origin WHERE id=%%s """ % ','.join(self.origin_cols) cur.execute(query, (id,)) data = cur.fetchone() if data: return line_to_bytes(data) return None def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, cur=None): """Search for origins whose urls contain a provided string pattern or match a provided regular expression. The search is performed in a case insensitive way. Args: url_pattern (str): the string pattern to search for in origin urls offset (int): number of found origins to skip before returning results limit (int): the maximum number of found origins to return regexp (bool): if True, consider the provided pattern as a regular expression and returns origins whose urls match it with_visit (bool): if True, filter out origins with no visit """ cur = self._cursor(cur) origin_cols = ','.join(self.origin_cols) query = """SELECT %s FROM origin WHERE """ if with_visit: query += """ EXISTS (SELECT 1 from origin_visit WHERE origin=origin.id) AND """ query += """ url %s %%s ORDER BY id OFFSET %%s LIMIT %%s""" if not regexp: query = query % (origin_cols, 'ILIKE') query_params = ('%'+url_pattern+'%', offset, limit) else: query = query % (origin_cols, '~*') query_params = (url_pattern, offset, limit) cur.execute(query, query_params) yield from cursor_to_bytes(cur) person_cols = ['fullname', 'name', 'email'] person_get_cols = person_cols + ['id'] def origin_get_range(self, origin_from=1, origin_count=100, cur=None): """Retrieve ``origin_count`` origins whose ids are greater or equal than ``origin_from``. Origins are sorted by id before retrieving them. Args: origin_from (int): the minimum id of origins to retrieve origin_count (int): the maximum number of origins to retrieve """ cur = self._cursor(cur) query = """SELECT %s FROM origin WHERE id >= %%s ORDER BY id LIMIT %%s """ % ','.join(self.origin_cols) cur.execute(query, (origin_from, origin_count)) yield from cursor_to_bytes(cur) def person_get(self, ids, cur=None): """Retrieve the persons identified by the list of ids. """ cur = self._cursor(cur) query = """SELECT %s FROM person WHERE id IN %%s""" % ', '.join(self.person_get_cols) cur.execute(query, (tuple(ids),)) yield from cursor_to_bytes(cur) release_add_cols = [ 'id', 'target', 'target_type', 'date', 'date_offset', 'date_neg_utc_offset', 'name', 'comment', 'synthetic', 'author_fullname', 'author_name', 'author_email', ] release_get_cols = release_add_cols + ['author_id'] def release_get_from_list(self, releases, cur=None): cur = self._cursor(cur) query_keys = ', '.join( self.mangle_query_key(k, 'release') for k in self.release_get_cols ) yield from execute_values_to_bytes( cur, """ SELECT %s FROM (VALUES %%s) as t(id) LEFT JOIN release ON t.id = release.id LEFT JOIN person author ON release.author = author.id """ % query_keys, ((id,) for id in releases)) def origin_metadata_add(self, origin, ts, provider, tool, metadata, cur=None): """ Add an origin_metadata for the origin at ts with provider, tool and metadata. Args: origin (int): the origin's id for which the metadata is added ts (datetime): time when the metadata was found provider (int): the metadata provider identifier tool (int): the tool's identifier used to extract metadata metadata (jsonb): the metadata retrieved at the time and location Returns: id (int): the origin_metadata unique id """ cur = self._cursor(cur) insert = """INSERT INTO origin_metadata (origin_id, discovery_date, provider_id, tool_id, metadata) values (%s, %s, %s, %s, %s) RETURNING id""" cur.execute(insert, (origin, ts, provider, tool, jsonize(metadata))) return cur.fetchone()[0] origin_metadata_get_cols = ['origin_id', 'discovery_date', 'tool_id', 'metadata', 'provider_id', 'provider_name', 'provider_type', 'provider_url'] def origin_metadata_get_by(self, origin_id, provider_type=None, cur=None): """Retrieve all origin_metadata entries for one origin_id """ cur = self._cursor(cur) if not provider_type: query = '''SELECT %s FROM swh_origin_metadata_get_by_origin( %%s)''' % (','.join( self.origin_metadata_get_cols)) cur.execute(query, (origin_id, )) else: query = '''SELECT %s FROM swh_origin_metadata_get_by_provider_type( %%s, %%s)''' % (','.join( self.origin_metadata_get_cols)) cur.execute(query, (origin_id, provider_type)) yield from cursor_to_bytes(cur) tool_cols = ['id', 'name', 'version', 'configuration'] @stored_procedure('swh_mktemp_tool') def mktemp_tool(self, cur=None): pass def tool_add_from_temp(self, cur=None): cur = self._cursor(cur) cur.execute("SELECT %s from swh_tool_add()" % ( ','.join(self.tool_cols), )) yield from cursor_to_bytes(cur) def tool_get(self, name, version, configuration, cur=None): cur = self._cursor(cur) cur.execute('''select %s from tool where name=%%s and version=%%s and configuration=%%s''' % ( ','.join(self.tool_cols)), (name, version, configuration)) data = cur.fetchone() if not data: return None return line_to_bytes(data) metadata_provider_cols = ['id', 'provider_name', 'provider_type', 'provider_url', 'metadata'] def metadata_provider_add(self, provider_name, provider_type, provider_url, metadata, cur=None): """Insert a new provider and return the new identifier.""" cur = self._cursor(cur) insert = """INSERT INTO metadata_provider (provider_name, provider_type, provider_url, metadata) values (%s, %s, %s, %s) RETURNING id""" cur.execute(insert, (provider_name, provider_type, provider_url, jsonize(metadata))) return cur.fetchone()[0] def metadata_provider_get(self, provider_id, cur=None): cur = self._cursor(cur) cur.execute('''select %s from metadata_provider where id=%%s ''' % ( ','.join(self.metadata_provider_cols)), (provider_id, )) data = cur.fetchone() if not data: return None return line_to_bytes(data) def metadata_provider_get_by(self, provider_name, provider_url, cur=None): cur = self._cursor(cur) cur.execute('''select %s from metadata_provider where provider_name=%%s and provider_url=%%s''' % ( ','.join(self.metadata_provider_cols)), (provider_name, provider_url)) data = cur.fetchone() if not data: return None return line_to_bytes(data) diff --git a/swh/storage/db_utils.py b/swh/storage/db_utils.py deleted file mode 100644 index 404f07b7..00000000 --- a/swh/storage/db_utils.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (C) 2015-2018 The Software Heritage developers -# See the AUTHORS file at the top-level directory of this distribution -# -# This code has been imported from psycopg2, version 2.7.4, -# https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd -# and modified by Software Heritage. -# -# Original file: lib/extras.py -# -# psycopg2 is free software: you can redistribute it and/or modify it under the -# terms of the GNU Lesser General Public License as published by the Free -# Software Foundation, either version 3 of the License, or (at your option) any -# later version. - - -import re - -import psycopg2.extensions - - -def _paginate(seq, page_size): - """Consume an iterable and return it in chunks. - Every chunk is at most `page_size`. Never return an empty chunk. - """ - page = [] - it = iter(seq) - while 1: - try: - for i in range(page_size): - page.append(next(it)) - yield page - page = [] - except StopIteration: - if page: - yield page - return - - -def _split_sql(sql): - """Split *sql* on a single ``%s`` placeholder. - Split on the %s, perform %% replacement and return pre, post lists of - snippets. - """ - curr = pre = [] - post = [] - tokens = re.split(br'(%.)', sql) - for token in tokens: - if len(token) != 2 or token[:1] != b'%': - curr.append(token) - continue - - if token[1:] == b's': - if curr is pre: - curr = post - else: - raise ValueError( - "the query contains more than one '%s' placeholder") - elif token[1:] == b'%': - curr.append(b'%') - else: - raise ValueError("unsupported format character: '%s'" - % token[1:].decode('ascii', 'replace')) - - if curr is pre: - raise ValueError("the query doesn't contain any '%s' placeholder") - - return pre, post - - -def execute_values_generator(cur, sql, argslist, template=None, page_size=100): - '''Execute a statement using SQL ``VALUES`` with a sequence of parameters. - Rows returned by the query are returned through a generator. - You need to consume the generator for the queries to be executed! - - :param cur: the cursor to use to execute the query. - :param sql: the query to execute. It must contain a single ``%s`` - placeholder, which will be replaced by a `VALUES list`__. - Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. - :param argslist: sequence of sequences or dictionaries with the arguments - to send to the query. The type and content must be consistent with - *template*. - :param template: the snippet to merge to every item in *argslist* to - compose the query. - - - If the *argslist* items are sequences it should contain positional - placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there - are constants value...). - - If the *argslist* items are mappings it should contain named - placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). - - If not specified, assume the arguments are sequence and use a simple - positional template (i.e. ``(%s, %s, ...)``), with the number of - placeholders sniffed by the first element in *argslist*. - :param page_size: maximum number of *argslist* items to include in every - statement. If there are more items the function will execute more than - one statement. - :param yield_from_cur: Whether to yield results from the cursor in this - function directly. - - .. __: https://www.postgresql.org/docs/current/static/queries-values.html - - After the execution of the function the `cursor.rowcount` property will - **not** contain a total result. - ''' - # we can't just use sql % vals because vals is bytes: if sql is bytes - # there will be some decoding error because of stupid codec used, and Py3 - # doesn't implement % on bytes. - if not isinstance(sql, bytes): - sql = sql.encode( - psycopg2.extensions.encodings[cur.connection.encoding] - ) - pre, post = _split_sql(sql) - - for page in _paginate(argslist, page_size=page_size): - if template is None: - template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' - parts = pre[:] - for args in page: - parts.append(cur.mogrify(template, args)) - parts.append(b',') - parts[-1:] = post - cur.execute(b''.join(parts)) - yield from cur