diff --git a/swh/storage/db.py b/swh/storage/db.py index c409d5f5..e269cf20 100644 --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -1,105 +1,111 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import binascii import functools import psycopg2 import tempfile from contextlib import contextmanager TMP_CONTENT_TABLE = 'tmp_content' def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get('cur', None) self._cursor(cur).execute('SELECT %s()' % stored_proc) meth(self, *args, **kwargs) return _meth return wrap class Db: """Proxy to the SWH DB, with wrappers around stored procedures """ @classmethod def connect(cls, *args, **kwargs): """factory method to create a DB proxy Accepts all arguments of psycopg2.connect; only some specific possibilities are reported below. Args: connstring: libpq2 connection string """ conn = psycopg2.connect(*args, **kwargs) return cls(conn) def _cursor(self, cur_arg): """get a cursor: from cur_arg if given, or a fresh one otherwise meant to avoid boilerplate if/then/else in methods that proxy stored procedures """ if cur_arg is not None: return cur_arg # elif self.cur is not None: # return self.cur else: return self.conn.cursor() def __init__(self, conn): """create a DB proxy Args: conn: psycopg2 connection to the SWH DB """ self.conn = conn @contextmanager def transaction(self): """context manager to execute within a DB transaction Yields: a psycopg2 cursor """ with self.conn.cursor() as cur: try: yield cur self.conn.commit() except: if not self.conn.closed: self.conn.rollback() raise def mktemp(self, tblname, cur=None): self._cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,)) def copy_to(self, items, tblname, columns, cur=None, item_cb=None): + def escape(data): + if isinstance(data, bytes): + return '\\\\x%s' % binascii.hexlify(data).decode('ascii') + else: + return str(data) with tempfile.TemporaryFile('w+') as f: for d in items: if item_cb is not None: item_cb(d) - line = '\t'.join([str(d[k]) for k in columns]) + '\n' + line = '\t'.join([escape(d[k]) for k in columns]) + '\n' f.write(line) f.seek(0) self._cursor(cur).copy_from(f, tblname, columns=columns) @stored_procedure('swh_content_add') def content_add_from_temp(self, cur=None): pass diff --git a/swh/storage/tests/test_db.py b/swh/storage/tests/test_db.py index 12680931..8c3a75f5 100644 --- a/swh/storage/tests/test_db.py +++ b/swh/storage/tests/test_db.py @@ -1,43 +1,46 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import unittest from nose.tools import istest from nose.plugins.attrib import attr from .db_testing import DbTestFixture +from swh.core.hashutil import hex_to_hash from swh.storage.db import Db @attr('db') class TestDb(DbTestFixture, unittest.TestCase): def setUp(self): super().setUp() self.db = Db(self.conn) def tearDown(self): self.db.conn.close() super().tearDown() @istest def add_content(self): cur = self.cursor - sha1 = '34973274ccef6ab4dfaaf86599792fa9c3fe4689' + sha1 = hex_to_hash('34973274ccef6ab4dfaaf86599792fa9c3fe4689') self.db.mktemp('content', cur) self.db.copy_to([{ 'sha1': sha1, - 'sha1_git': 'd81cc0710eb6cf9efd5b920a8453e1e07157b6cd', - 'sha256': '673650f936cb3b0a2f93ce09d81be10748b1b203' - 'c19e8176b4eefc1964a0cf3a', + 'sha1_git': hex_to_hash( + 'd81cc0710eb6cf9efd5b920a8453e1e07157b6cd'), + 'sha256': hex_to_hash( + '673650f936cb3b0a2f93ce09d81be107' + '48b1b203c19e8176b4eefc1964a0cf3a'), 'length': 3}], 'tmp_content', ['sha1', 'sha1_git', 'sha256', 'length'], cur) self.db.content_add_from_temp(cur) self.cursor.execute('SELECT sha1 FROM content WHERE sha1 = %s', (sha1,)) - self.assertEqual(self.cursor.fetchone(), (sha1,)) + self.assertEqual(self.cursor.fetchone()[0].tobytes(), sha1)