diff --git a/swh/storage/__init__.py b/swh/storage/__init__.py
--- a/swh/storage/__init__.py
+++ b/swh/storage/__init__.py
@@ -3,9 +3,9 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

-from . import storage
+from . import pg_storage

-Storage = storage.Storage
+Storage = pg_storage.PgStorage


 def get_storage(cls, args):
@@ -29,7 +29,7 @@
     if cls == 'remote':
         from .api.client import RemoteStorage as Storage
     elif cls == 'local':
-        from .storage import Storage
+        from .pg_storage import PgStorage as Storage
     else:
         raise ValueError('Unknown storage class `%s`' % cls)
diff --git a/swh/storage/common.py b/swh/storage/common.py
--- a/swh/storage/common.py
+++ b/swh/storage/common.py
@@ -7,20 +7,6 @@
 import functools


-def apply_options(cursor, options):
-    """Applies the given postgresql client options to the given cursor.
-
-    Returns a dictionary with the old values if they changed."""
-    old_options = {}
-    for option, value in options.items():
-        cursor.execute('SHOW %s' % option)
-        old_value = cursor.fetchall()[0][0]
-        if old_value != value:
-            cursor.execute('SET LOCAL %s TO %%s' % option, (value,))
-            old_options[option] = old_value
-    return old_options
-
-
 def db_transaction(**client_options):
     """decorator to execute Storage methods within DB transactions

@@ -36,15 +22,16 @@
         @functools.wraps(meth)
         def _meth(self, *args, **kwargs):
             if 'cur' in kwargs and kwargs['cur']:
+                db = kwargs['db']
                 cur = kwargs['cur']
-                old_options = apply_options(cur, __client_options)
+                old_options = db.apply_options(cur, __client_options)
                 ret = meth(self, *args, **kwargs)
-                apply_options(cur, old_options)
+                db.apply_options(cur, old_options)
                 return ret
             else:
                 db = self.get_db()
                 with db.transaction() as cur:
-                    apply_options(cur, __client_options)
+                    db.apply_options(cur, __client_options)
                     return meth(self, *args, db=db, cur=cur, **kwargs)
         return _meth
@@ -67,14 +54,15 @@
         @functools.wraps(meth)
         def _meth(self, *args, **kwargs):
             if 'cur' in kwargs and kwargs['cur']:
+                db = kwargs['db']
                 cur = kwargs['cur']
-                old_options = apply_options(cur, __client_options)
+                old_options = db.apply_options(cur, __client_options)
                 yield from meth(self, *args, **kwargs)
-                apply_options(cur, old_options)
+                db.apply_options(cur, old_options)
             else:
                 db = self.get_db()
                 with db.transaction() as cur:
-                    apply_options(cur, __client_options)
+                    db.apply_options(cur, __client_options)
                     yield from meth(self, *args, db=db, cur=cur, **kwargs)
         return _meth
     return decorator
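Note on the rework above: option handling moves from the module-level
`apply_options` helper onto the db object itself, so a non-Postgres backend
can emulate client options in whatever way suits it. A minimal usage sketch
of the decorator, assuming the semantics in this patch; the `ExampleStorage`
and `origin_count` names and the SQL are illustrative, not part of the patch
(the `statement_timeout=2000` option is the one the patch itself uses):

    from swh.storage.common import db_transaction

    class ExampleStorage:
        def get_db(self):
            return self._db  # any BaseDb subclass, e.g. a PgDb instance

        @db_transaction(statement_timeout=2000)
        def origin_count(self, db=None, cur=None):
            # db and cur are injected by the decorator, which first applies
            # the client options through db.apply_options(cur, ...).
            cur.execute('SELECT count(*) FROM origin')
            return cur.fetchone()[0]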
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

+import abc
 import binascii
 import datetime
 import enum
@@ -11,6 +12,7 @@
 import os
 import select
 import threading
+import warnings

 from contextlib import contextmanager
@@ -78,76 +80,27 @@
         yield line_to_bytes(line)


-class BaseDb:
+class BaseDb(metaclass=abc.ABCMeta):
     """Base class for swh.storage.*Db.

     cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb

     """

-    @classmethod
-    def connect(cls, *args, **kwargs):
-        """factory method to create a DB proxy
-
-        Accepts all arguments of psycopg2.connect; only some specific
-        possibilities are reported below.
-
-        Args:
-            connstring: libpq2 connection string
-
-        """
-        conn = psycopg2.connect(*args, **kwargs)
-        return cls(conn)
-
-    @classmethod
-    def from_pool(cls, pool):
-        return cls(pool.getconn(), pool=pool)
+    @abc.abstractmethod
+    def apply_options(self, cursor, options):
+        """Applies the given postgresql client options to the given cursor.

-    def _cursor(self, cur_arg):
-        """get a cursor: from cur_arg if given, or a fresh one otherwise
-
-        meant to avoid boilerplate if/then/else in methods that proxy stored
-        procedures
-
-        """
-        if cur_arg is not None:
-            return cur_arg
-        # elif self.cur is not None:
-        #     return self.cur
-        else:
-            return self.conn.cursor()
-
-    def __init__(self, conn, pool=None):
-        """create a DB proxy
+        Returns a dictionary with the old values if they changed.

-        Args:
-            conn: psycopg2 connection to the SWH DB
-            pool: psycopg2 pool of connections
-
-        """
-        self.conn = conn
-        self.pool = pool
-
-    def __del__(self):
-        if self.pool:
-            self.pool.putconn(self.conn)
+        Non-postgresql subclasses should emulate these options as much
+        as possible."""
+        pass

-    @contextmanager
+    @abc.abstractmethod
     def transaction(self):
-        """context manager to execute within a DB transaction
-
-        Yields:
-            a psycopg2 cursor
-
-        """
-        with self.conn.cursor() as cur:
-            try:
-                yield cur
-                self.conn.commit()
-            except Exception:
-                if not self.conn.closed:
-                    self.conn.rollback()
-                raise
+        """context manager to execute within a DB transaction"""
+        pass

     def copy_to(self, items, tblname, columns, cur=None, item_cb=None):
         """Copy items' entries to table tblname with columns information.
@@ -221,10 +174,85 @@
         self._cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,))


-class Db(BaseDb):
+class PgDb(BaseDb):
     """Proxy to the SWH DB, with wrappers around stored procedures

     """
+
+    @classmethod
+    def connect(cls, *args, **kwargs):
+        """factory method to create a DB proxy
+
+        Accepts all arguments of psycopg2.connect; only some specific
+        possibilities are reported below.
+
+        Args:
+            connstring: libpq2 connection string
+
+        """
+        conn = psycopg2.connect(*args, **kwargs)
+        return cls(conn)
+
+    @classmethod
+    def from_pool(cls, pool):
+        return cls(pool.getconn(), pool=pool)
+
+    def _cursor(self, cur_arg):
+        """get a cursor: from cur_arg if given, or a fresh one otherwise
+
+        meant to avoid boilerplate if/then/else in methods that proxy stored
+        procedures
+
+        """
+        if cur_arg is not None:
+            return cur_arg
+        # elif self.cur is not None:
+        #     return self.cur
+        else:
+            return self.conn.cursor()
+
+    def __init__(self, conn, pool=None):
+        """create a DB proxy
+
+        Args:
+            conn: psycopg2 connection to the SWH DB
+            pool: psycopg2 pool of connections
+
+        """
+        self.conn = conn
+        self.pool = pool
+
+    def __del__(self):
+        if self.pool:
+            self.pool.putconn(self.conn)
+
+    @contextmanager
+    def transaction(self):
+        """context manager to execute within a DB transaction
+
+        Yields:
+            a psycopg2 cursor
+
+        """
+        with self.conn.cursor() as cur:
+            try:
+                yield cur
+                self.conn.commit()
+            except Exception:
+                if not self.conn.closed:
+                    self.conn.rollback()
+                raise
+
+    def apply_options(self, cursor, options):
+        old_options = {}
+        for option, value in options.items():
+            cursor.execute('SHOW %s' % option)
+            old_value = cursor.fetchall()[0][0]
+            if old_value != value:
+                cursor.execute('SET LOCAL %s TO %%s' % option, (value,))
+                old_options[option] = old_value
+        return old_options
+
     def mktemp_dir_entry(self, entry_type, cur=None):
         self._cursor(cur).execute('SELECT swh_mktemp_dir_entry(%s)',
                                   (('directory_entry_%s' % entry_type),))
@@ -1035,3 +1063,10 @@
         if not data:
             return None
         return line_to_bytes(data)
+
+
+class Db(PgDb):
+    def __init__(self, *args, **kwargs):
+        warnings.warn("Db was renamed to PgDb in v0.0.109.",
+                      DeprecationWarning)
+        super().__init__(*args, **kwargs)
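With this split, BaseDb keeps only the backend-agnostic surface
(`apply_options`, `transaction`, `copy_to`, `mktemp`, ...) while the
psycopg2-specific lifecycle lives in PgDb, and the old Db name stays
importable as a deprecated alias. A sketch of direct PgDb use, assuming a
reachable database; the connection string and option value are placeholders:

    from swh.storage.db import PgDb

    db = PgDb.connect('dbname=softwareheritage-dev')
    with db.transaction() as cur:
        # SET LOCAL only lasts until the end of the current transaction.
        old = db.apply_options(cur, {'statement_timeout': '2000'})
        cur.execute('SELECT count(*) FROM content')
        print(cur.fetchone()[0])
        db.apply_options(cur, old)  # restore the previous values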
diff --git a/swh/storage/pg_storage.py b/swh/storage/pg_storage.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/pg_storage.py
@@ -0,0 +1,240 @@
+# Copyright (C) 2015-2018  The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+from collections import defaultdict
+import itertools
+
+import psycopg2
+import psycopg2.pool
+
+from . import converters
+from .common import db_transaction_generator, db_transaction
+from .db import PgDb
+from .exc import StorageDBError
+from .storage import BaseStorage
+
+from swh.model.hashutil import hash_to_bytes
+
+# Max block size of contents to return
+BULK_BLOCK_CONTENT_LEN_MAX = 10000
+
+EMPTY_SNAPSHOT_ID = hash_to_bytes('1a8893e6a86f444e8be8e7bda6cb34fb1735a00e')
+"""Identifier for the empty snapshot"""
+
+
+class PgStorage(BaseStorage):
+    """SWH storage proxy, encompassing DB and object storage
+
+    """
+
+    def _init_db(self, db, min_pool_conns=1, max_pool_conns=10):
+        """
+        Args:
+            db: either a libpq connection string, or a psycopg2
+                connection
+
+        """
+        try:
+            if isinstance(db, psycopg2.extensions.connection):
+                self._pool = None
+                return PgDb(db)
+            else:
+                self._pool = psycopg2.pool.ThreadedConnectionPool(
+                    min_pool_conns, max_pool_conns, db
+                )
+                return None
+        except psycopg2.OperationalError as e:
+            raise StorageDBError(e)
+
+    def get_db(self):
+        if self._db:
+            return self._db
+        else:
+            return PgDb.from_pool(self._pool)
+
+    def check_config(self, *, check_write):
+        if not self.objstorage.check_config(check_write=check_write):
+            return False
+
+        # Check permissions on one of the tables
+        with self.get_db().transaction() as cur:
+            if check_write:
+                check = 'INSERT'
+            else:
+                check = 'SELECT'
+
+            cur.execute(
+                "select has_table_privilege(current_user, 'content', %s)",
+                (check,)
+            )
+            return cur.fetchone()[0]
+
+        return True
+
+    @db_transaction()
+    def _add_missing_content_to_db(self, content, db=None, cur=None):
+        # create temporary table for metadata injection
+        db.mktemp('content', cur)
+
+        db.copy_to(content, 'tmp_content',
+                   db.content_get_metadata_keys, cur)
+
+        # move metadata in place
+        db.content_add_from_temp(cur)
+
+    @db_transaction()
+    def _add_skipped_content_to_db(self, skipped_content, db=None, cur=None):
+
+        db.mktemp('skipped_content', cur)
+        db.copy_to(skipped_content, 'tmp_skipped_content',
+                   db.skipped_content_keys, cur)
+
+        # move metadata in place
+        db.skipped_content_add_from_temp(cur)
+
+    @db_transaction()
+    def content_update(self, content, keys=[], db=None, cur=None):
+        # TODO: Add a check on input keys. How to properly implement
+        # this? We don't know yet the new columns.
+
+        db.mktemp('content', cur)
+        select_keys = list(set(db.content_get_metadata_keys).union(set(keys)))
+        db.copy_to(content, 'tmp_content', select_keys, cur)
+        db.content_update_from_temp(keys_to_update=keys,
+                                    cur=cur)
+
+    @db_transaction_generator()
+    def content_missing_per_sha1(self, contents, db=None, cur=None):
+        for obj in db.content_missing_per_sha1(contents, cur):
+            yield obj[0]
+
+    @db_transaction_generator()
+    def skipped_content_missing(self, content, db=None, cur=None):
+        keys = db.content_hash_keys
+
+        db.mktemp('skipped_content', cur)
+        db.copy_to(content, 'tmp_skipped_content',
+                   keys + ['length', 'reason'], cur)
+
+        yield from db.skipped_content_missing_from_temp(cur)
+
+    def directory_add(self, directories):
+        dirs = set()
+        dir_entries = {
+            'file': defaultdict(list),
+            'dir': defaultdict(list),
+            'rev': defaultdict(list),
+        }
+
+        for cur_dir in directories:
+            dir_id = cur_dir['id']
+            dirs.add(dir_id)
+            for src_entry in cur_dir['entries']:
+                entry = src_entry.copy()
+                entry['dir_id'] = dir_id
+                dir_entries[entry['type']][dir_id].append(entry)
+
+        dirs_missing = set(self.directory_missing(dirs))
+        if not dirs_missing:
+            return
+
+        db = self.get_db()
+        with db.transaction() as cur:
+            # Copy directory ids
+            dirs_missing_dict = ({'id': dir} for dir in dirs_missing)
+            db.mktemp('directory', cur)
+            db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur)
+
+            # Copy entries
+            for entry_type, entry_list in dir_entries.items():
+                entries = itertools.chain.from_iterable(
+                    entries_for_dir
+                    for dir_id, entries_for_dir
+                    in entry_list.items()
+                    if dir_id in dirs_missing)
+
+                db.mktemp_dir_entry(entry_type)
+
+                db.copy_to(
+                    entries,
+                    'tmp_directory_entry_%s' % entry_type,
+                    ['target', 'name', 'perms', 'dir_id'],
+                    cur,
+                )
+
+            # Do the final copy
+            db.directory_add_from_temp(cur)
+
+    def revision_add(self, revisions):
+        db = self.get_db()
+
+        revisions_missing = set(self.revision_missing(
+            set(revision['id'] for revision in revisions)))
+
+        if not revisions_missing:
+            return
+
+        with db.transaction() as cur:
+            db.mktemp_revision(cur)
+
+            revisions_filtered = (
+                converters.revision_to_db(revision) for revision in revisions
+                if revision['id'] in revisions_missing)
+
+            parents_filtered = []
+
+            db.copy_to(
+                revisions_filtered, 'tmp_revision', db.revision_add_cols,
+                cur,
+                lambda rev: parents_filtered.extend(rev['parents']))
+
+            db.revision_add_from_temp(cur)
+
+            db.copy_to(parents_filtered, 'revision_history',
+                       ['id', 'parent_id', 'parent_rank'], cur)
+
+    def release_add(self, releases):
+        db = self.get_db()
+
+        release_ids = set(release['id'] for release in releases)
+        releases_missing = set(self.release_missing(release_ids))
+
+        if not releases_missing:
+            return
+
+        with db.transaction() as cur:
+            db.mktemp_release(cur)
+
+            releases_filtered = (
+                converters.release_to_db(release) for release in releases
+                if release['id'] in releases_missing
+            )
+
+            db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols,
+                       cur)
+
+            db.release_add_from_temp(cur)
+
+    @db_transaction()
+    def snapshot_add(self, origin, visit, snapshot,
+                     db=None, cur=None):
+        if not db.snapshot_exists(snapshot['id'], cur):
+            db.mktemp_snapshot_branch(cur)
+            db.copy_to(
+                (
+                    {
+                        'name': name,
+                        'target': info['target'] if info else None,
+                        'target_type': info['target_type'] if info else None,
+                    }
+                    for name, info in snapshot['branches'].items()
+                ),
+                'tmp_snapshot_branch',
+                ['name', 'target', 'target_type'],
+                cur,
+            )
+
+        db.snapshot_add(origin, visit, snapshot['id'], cur)
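PgStorage now only carries the Postgres-specific pieces; everything generic
stays in BaseStorage (see swh/storage/storage.py below). Instantiation is
unchanged for callers going through `get_storage`. A sketch, assuming a
local database; the connection string and the objstorage configuration
values are placeholders:

    from swh.storage import get_storage

    storage = get_storage('local', {
        'db': 'dbname=softwareheritage-dev',
        'objstorage': {
            'cls': 'pathslicing',
            'args': {'root': '/tmp/objects', 'slicing': '0:2/2:4/4:6'},
        },
    })
    assert storage.check_config(check_write=True)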
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -4,20 +4,16 @@
 # See top-level LICENSE file for more information

+import abc
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 import datetime
-import itertools
 import json

 import dateutil.parser
-import psycopg2
-import psycopg2.pool

 from . import converters
 from .common import db_transaction_generator, db_transaction
-from .db import Db
-from .exc import StorageDBError
 from .algos import diff

 from swh.model.hashutil import ALGORITHMS, hash_to_bytes
@@ -31,58 +27,25 @@
 """Identifier for the empty snapshot"""


-class Storage():
-    """SWH storage proxy, encompassing DB and object storage
+class BaseStorage(metaclass=abc.ABCMeta):
+    """Abstract class for storage backends.

     """
-
-    def __init__(self, db, objstorage, min_pool_conns=1, max_pool_conns=10):
-        """
-        Args:
-            db_conn: either a libpq connection string, or a psycopg2 connection
-            obj_root: path to the root of the object storage
-
-        """
-        try:
-            if isinstance(db, psycopg2.extensions.connection):
-                self._pool = None
-                self._db = Db(db)
-            else:
-                self._pool = psycopg2.pool.ThreadedConnectionPool(
-                    min_pool_conns, max_pool_conns, db
-                )
-                self._db = None
-        except psycopg2.OperationalError as e:
-            raise StorageDBError(e)
-
+    def __init__(self, db, objstorage, **kwargs):
+        self._db = self._init_db(db, **kwargs)
         self.objstorage = get_objstorage(**objstorage)

-    def get_db(self):
-        if self._db:
-            return self._db
-        else:
-            return Db.from_pool(self._pool)
+    @abc.abstractmethod
+    def _init_db(self, db, **kwargs):
+        pass
+
+    def get_db(self):
+        return self._db

+    @abc.abstractmethod
     def check_config(self, *, check_write):
         """Check that the storage is configured and ready to go."""
-
-        if not self.objstorage.check_config(check_write=check_write):
-            return False
-
-        # Check permissions on one of the tables
-        with self.get_db().transaction() as cur:
-            if check_write:
-                check = 'INSERT'
-            else:
-                check = 'SELECT'
-
-            cur.execute(
-                "select has_table_privilege(current_user, 'content', %s)",
-                (check,)
-            )
-            return cur.fetchone()[0]
-
-        return True
+        pass

     def content_add(self, content):
         """Add content blobs to the storage

@@ -134,6 +97,11 @@
             in self.skipped_content_missing(
                 content_without_data))

+        missing_filtered = (cont for cont in content_with_data
+                            if cont['sha1'] in missing_content)
+        skipped_filtered = (cont for cont in content_without_data
+                            if _unique_key(cont) in missing_skipped)
+
         def add_to_objstorage():
             data = {
                 cont['sha1']: cont['data']
@@ -145,38 +113,40 @@
         with db.transaction() as cur:
             with ThreadPoolExecutor(max_workers=1) as executor:
                 added_to_objstorage = executor.submit(add_to_objstorage)
-                if missing_content:
-                    # create temporary table for metadata injection
-                    db.mktemp('content', cur)
-
-                    content_filtered = (cont for cont in content_with_data
-                                        if cont['sha1'] in missing_content)
-                    db.copy_to(content_filtered, 'tmp_content',
-                               db.content_get_metadata_keys, cur)
+                if missing_content:
+                    self._add_missing_content_to_db(missing_filtered,
+                                                    db=db, cur=cur)
+                if missing_skipped:
+                    self._add_skipped_content_to_db(skipped_filtered,
+                                                    db=db, cur=cur)

-                    # move metadata in place
-                    db.content_add_from_temp(cur)
+                # Wait for objstorage addition before returning from the
+                # transaction, bubbling up any exception
+                added_to_objstorage.result()

-                if missing_skipped:
-                    missing_filtered = (
-                        cont for cont in content_without_data
-                        if _unique_key(cont) in missing_skipped
-                    )
+
+    @abc.abstractmethod
+    def _add_missing_content_to_db(self, content, db=None, cur=None):
+        """Insert in the database the list of content that was
+        previously missing from the archive and is now being added
+        to the objstorage.

-                    db.mktemp('skipped_content', cur)
-                    db.copy_to(missing_filtered, 'tmp_skipped_content',
-                               db.skipped_content_keys, cur)
+        Args:
+            content: see `content_add`."""
+        pass

-                    # move metadata in place
-                    db.skipped_content_add_from_temp(cur)
+    @abc.abstractmethod
+    def _add_skipped_content_to_db(self, skipped_content, db=None, cur=None):
+        """Insert in the database the list of content that was
+        previously missing from the archive and is now explicitly
+        skipped -- meaning it's now referenced but not stored.

-            # Wait for objstorage addition before returning from the
-            # transaction, bubbling up any exception
-            added_to_objstorage.result()
+        Args:
+            skipped_content: see `content_add`."""
+        pass

-    @db_transaction()
-    def content_update(self, content, keys=[], db=None, cur=None):
+    @abc.abstractmethod
+    def content_update(self, content, keys=[]):
         """Update content blobs to the storage. Does nothing for unknown
         contents or skipped ones.

@@ -196,14 +166,7 @@
                 new hash column

         """
-        # TODO: Add a check on input keys. How to properly implement
-        # this? We don't know yet the new columns.
-
-        db.mktemp('content', cur)
-        select_keys = list(set(db.content_get_metadata_keys).union(set(keys)))
-        db.copy_to(content, 'tmp_content', select_keys, cur)
-        db.content_update_from_temp(keys_to_update=keys,
-                                    cur=cur)
+        pass

     def content_get(self, content):
         """Retrieve in bulk contents and their data.
@@ -304,7 +267,9 @@

     @db_transaction_generator()
     def skipped_content_missing(self, content, db=None, cur=None):
-        """List skipped_content missing from storage
+        """List skipped_content missing from storage, i.e. content
+        that is known to the archive but not stored for some reason
+        (e.g. a file whose size is too large).

         Args:
             content: iterable of dictionaries containing the data for each
@@ -353,6 +318,7 @@
                 return dict(zip(db.content_find_cols, c))
         return None

+    @abc.abstractmethod
     def directory_add(self, directories):
         """Add directories to the storage

@@ -372,51 +338,7 @@
               directory entry
             - perms (int): entry permissions

         """
-        dirs = set()
-        dir_entries = {
-            'file': defaultdict(list),
-            'dir': defaultdict(list),
-            'rev': defaultdict(list),
-        }
-
-        for cur_dir in directories:
-            dir_id = cur_dir['id']
-            dirs.add(dir_id)
-            for src_entry in cur_dir['entries']:
-                entry = src_entry.copy()
-                entry['dir_id'] = dir_id
-                dir_entries[entry['type']][dir_id].append(entry)
-
-        dirs_missing = set(self.directory_missing(dirs))
-        if not dirs_missing:
-            return
-
-        db = self.get_db()
-        with db.transaction() as cur:
-            # Copy directory ids
-            dirs_missing_dict = ({'id': dir} for dir in dirs_missing)
-            db.mktemp('directory', cur)
-            db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur)
-
-            # Copy entries
-            for entry_type, entry_list in dir_entries.items():
-                entries = itertools.chain.from_iterable(
-                    entries_for_dir
-                    for dir_id, entries_for_dir
-                    in entry_list.items()
-                    if dir_id in dirs_missing)
-
-                db.mktemp_dir_entry(entry_type)
-
-                db.copy_to(
-                    entries,
-                    'tmp_directory_entry_%s' % entry_type,
-                    ['target', 'name', 'perms', 'dir_id'],
-                    cur,
-                )
-
-            # Do the final copy
-            db.directory_add_from_temp(cur)
+        pass

     @db_transaction_generator()
     def directory_missing(self, directories, db=None, cur=None):
@@ -469,6 +391,7 @@
             if res:
                 return dict(zip(db.directory_ls_cols, res))

+    @abc.abstractmethod
     def revision_add(self, revisions):
         """Add revisions to the storage

@@ -501,32 +424,7 @@
             - parents (list of sha1_git): the parents of this revision

         """
-        db = self.get_db()
-
-        revisions_missing = set(self.revision_missing(
-            set(revision['id'] for revision in revisions)))
-
-        if not revisions_missing:
-            return
-
-        with db.transaction() as cur:
-            db.mktemp_revision(cur)
-
-            revisions_filtered = (
-                converters.revision_to_db(revision) for revision in revisions
-                if revision['id'] in revisions_missing)
-
-            parents_filtered = []
-
-            db.copy_to(
-                revisions_filtered, 'tmp_revision', db.revision_add_cols,
-                cur,
-                lambda rev: parents_filtered.extend(rev['parents']))
-
-            db.revision_add_from_temp(cur)
-
-            db.copy_to(parents_filtered, 'revision_history',
-                       ['id', 'parent_id', 'parent_rank'], cur)
+        pass

     @db_transaction_generator()
     def revision_missing(self, revisions, db=None, cur=None):
@@ -683,9 +581,8 @@
                 dict(zip(db.release_get_cols, release))
             )

-    @db_transaction()
-    def snapshot_add(self, origin, visit, snapshot,
-                     db=None, cur=None):
+    @abc.abstractmethod
+    def snapshot_add(self, origin, visit, snapshot):
         """Add a snapshot for the given origin/visit couple

         Args:
@@ -707,23 +604,7 @@
             (currently a ``sha1_git`` for all object kinds, or the name
             of the target branch for aliases)
         """
-        if not db.snapshot_exists(snapshot['id'], cur):
-            db.mktemp_snapshot_branch(cur)
-            db.copy_to(
-                (
-                    {
-                        'name': name,
-                        'target': info['target'] if info else None,
-                        'target_type': info['target_type'] if info else None,
-                    }
-                    for name, info in snapshot['branches'].items()
-                ),
-                'tmp_snapshot_branch',
-                ['name', 'target', 'target_type'],
-                cur,
-            )
-
-        db.snapshot_add(origin, visit, snapshot['id'], cur)
+        pass

     @db_transaction(statement_timeout=2000)
     def snapshot_get(self, snapshot_id, db=None, cur=None):
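The abstract surface extracted above is deliberately small: a backend fills
in only these methods, while the non-abstract methods are inherited (their
decorated implementations still go through the backend's db object). A
hypothetical skeleton, with all names other than the abstract methods
illustrative and the bodies elided:

    from swh.storage.storage import BaseStorage

    class ExampleStorage(BaseStorage):
        def _init_db(self, db, **kwargs):
            # must return an object providing BaseDb's abstract surface
            return ExampleDb(db)  # ExampleDb is hypothetical

        def check_config(self, *, check_write):
            return True

        def _add_missing_content_to_db(self, content, db=None, cur=None): ...
        def _add_skipped_content_to_db(self, skipped_content,
                                       db=None, cur=None): ...
        def content_update(self, content, keys=[]): ...
        def directory_add(self, directories): ...
        def revision_add(self, revisions): ...
        def snapshot_add(self, origin, visit, snapshot): ...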
diff --git a/swh/storage/tests/algos/test_snapshot.py b/swh/storage/tests/algos/test_snapshot.py
--- a/swh/storage/tests/algos/test_snapshot.py
+++ b/swh/storage/tests/algos/test_snapshot.py
@@ -12,7 +12,7 @@
     from_regex, none, one_of, sampled_from)

 from swh.model.identifiers import snapshot_identifier, identifier_to_bytes
-from swh.storage.tests.storage_testing import StorageTestFixture
+from swh.storage.tests.storage_testing import PgStorageTestFixture
 from swh.storage.algos.snapshot import snapshot_get_all_branches

@@ -94,7 +94,7 @@

 @pytest.mark.db
-class TestSnapshotAllBranches(StorageTestFixture, unittest.TestCase):
+class TestSnapshotAllBranches(PgStorageTestFixture, unittest.TestCase):
     @given(origins(), datetimes(),
            snapshots(min_size=0, max_size=10, only_objects=False))
     def test_snapshot_small(self, origin, ts, snapshot):
diff --git a/swh/storage/tests/storage_testing.py b/swh/storage/tests/storage_testing.py
--- a/swh/storage/tests/storage_testing.py
+++ b/swh/storage/tests/storage_testing.py
@@ -5,6 +5,7 @@

 import os
 import tempfile
+import warnings

 from swh.storage import get_storage

@@ -12,7 +13,7 @@
 from swh.storage.tests import SQL_DIR


-class StorageTestFixture(SingleDbTestFixture):
+class PgStorageTestFixture(SingleDbTestFixture):
     """Mix this in a test subject class to get Storage testing support.

     This fixture requires to come before SingleDbTestFixture in the
@@ -50,8 +51,17 @@
     def tearDown(self):
         self.objtmp.cleanup()
         self.storage = None
+        self.reset_storage_tables()
         super().tearDown()

     def reset_storage_tables(self):
         excluded = {'dbversion', 'tool'}
         self.reset_db_tables(self.TEST_DB_NAME, excluded=excluded)
+
+
+class StorageTestFixture(PgStorageTestFixture):
+    def __init__(self, *args, **kwargs):
+        warnings.warn("StorageTestFixture was renamed to "
+                      "PgStorageTestFixture in v0.0.109.",
+                      DeprecationWarning)
+        super().__init__(*args, **kwargs)
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -7,14 +7,17 @@
 import tempfile
 import unittest

+import pytest
+
 from swh.core.tests.server_testing import ServerTestFixture
 from swh.storage.api.client import RemoteStorage
 from swh.storage.api.server import app
 from swh.storage.tests.test_storage import CommonTestStorage
+from swh.storage.tests.storage_testing import PgStorageTestFixture


 class TestRemoteStorage(CommonTestStorage, ServerTestFixture,
-                        unittest.TestCase):
+                        PgStorageTestFixture, unittest.TestCase):
     """Test the remote storage API.

     This class doesn't define any tests as we want identical
@@ -52,3 +55,18 @@
     def tearDown(self):
         super().tearDown()
         shutil.rmtree(self.storage_base)
+
+    @pytest.mark.skip("Can only be tested with local storage as you "
+                      "can't mock datetimes for the remote server")
+    def test_fetch_history(self):
+        pass
+
+    @pytest.mark.skip("The remote API doesn't expose _person_add")
+    def test_person_get(self):
+        pass
+
+    @pytest.mark.skip("This test is only relevant on the local "
+                      "storage, with an actual objstorage raising an "
+                      "exception")
+    def test_content_add_objstorage_exception(self):
+        pass
diff --git a/swh/storage/tests/test_db.py b/swh/storage/tests/test_db.py
--- a/swh/storage/tests/test_db.py
+++ b/swh/storage/tests/test_db.py
@@ -10,7 +10,7 @@
 from swh.core.tests.db_testing import SingleDbTestFixture
 from swh.model.hashutil import hash_to_bytes

-from swh.storage.db import Db
+from swh.storage.db import PgDb

 from . import SQL_DIR

@@ -21,7 +21,7 @@
     def setUp(self):
         super().setUp()

-        self.db = Db(self.conn)
+        self.db = PgDb(self.conn)

     def tearDown(self):
         self.db.conn.close()
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -6,27 +6,23 @@
 import copy
 import datetime
 import unittest
+import warnings
 from collections import defaultdict
 from operator import itemgetter
 from unittest.mock import Mock, patch

 import psycopg2
-import pytest

 from swh.model import from_disk, identifiers
 from swh.model.hashutil import hash_to_bytes
-from swh.storage.tests.storage_testing import StorageTestFixture
+from swh.storage.tests.storage_testing import PgStorageTestFixture


-@pytest.mark.db
-class BaseTestStorage(StorageTestFixture):
+class DataTestStorage:
+    """Base class which provides data to tests."""
     def setUp(self):
         super().setUp()

-        db = self.test_db[self.TEST_DB_NAME]
-        self.conn = db.conn
-        self.cursor = db.cursor
-
         self.maxDiff = None

         self.cont = {
@@ -509,17 +505,22 @@
             'next_branch': None
         }

-    def tearDown(self):
-        self.reset_storage_tables()
-        super().tearDown()

+class BaseTestFixture(DataTestStorage, PgStorageTestFixture):
+    def __init__(self, *args, **kwargs):
+        warnings.warn("BaseTestFixture was renamed to "
+                      "DataTestStorage in v0.0.109, and no longer inherits "
+                      "from (Pg)StorageTestFixture",
+                      DeprecationWarning)
+        super().__init__(*args, **kwargs)

-class CommonTestStorage(BaseTestStorage):
+
+class CommonTestStorage(DataTestStorage):
     """Base class for Storage testing.

     This class is used as-is to test local storage (see TestLocalStorage
-    below) and remote storage (see TestRemoteStorage in
-    test_remote_storage.py.
+    and TestMemStorage below) and remote storage (see TestRemoteStorage in
+    test_api_client.py).

     We need to have the two classes inherit from this base class
     separately to avoid nosetests running the tests from the base
@@ -1841,12 +1842,6 @@
         self.assertEqual(m_by_provider[0]['id'], o_m2)
         self.assertIsNotNone(o_m1)

-
-class TestLocalStorage(CommonTestStorage, unittest.TestCase):
-    """Test the local storage"""
-
-    # Can only be tested with local storage as you can't mock
-    # datetimes for the remote server
     def test_fetch_history(self):
         origin = self.storage.origin_add_one(self.origin)
         with patch('datetime.datetime'):
@@ -1869,7 +1864,6 @@

         self.assertEqual(expected_fetch_history, fetch_history)

-    # The remote API doesn't expose _person_add
     def test_person_get(self):
         # given
         person0 = {
@@ -1906,8 +1900,6 @@
             },
         ])

-    # This test is only relevant on the local storage, with an actual
-    # objstorage raising an exception
    def test_content_add_objstorage_exception(self):
         self.storage.objstorage.add = Mock(
             side_effect=Exception('mocked broken objstorage')
@@ -1921,7 +1913,18 @@
         self.assertEqual(missing, [self.cont['sha1']])


-class AlteringSchemaTest(BaseTestStorage, unittest.TestCase):
+class TestLocalStorage(CommonTestStorage, PgStorageTestFixture,
+                       unittest.TestCase):
+    """Test the local storage"""
+    def setUp(self):
+        super().setUp()
+        db = self.test_db[self.TEST_DB_NAME]
+        self.conn = db.conn
+        self.cursor = db.cursor
+
+
+class AlteringSchemaTest(DataTestStorage, PgStorageTestFixture,
+                         unittest.TestCase):
     """This class is dedicated for the rare case where the schema
     needs to be altered dynamically.
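Since the renames ship with backwards-compatible shims (Db,
StorageTestFixture, BaseTestFixture), the deprecation path itself is
testable. A sketch, assuming pytest; PgDb.__init__ only stores the
connection, so None is enough here:

    import pytest

    from swh.storage.db import Db, PgDb

    def test_db_rename_warns():
        with pytest.warns(DeprecationWarning):
            db = Db(conn=None)
        assert isinstance(db, PgDb)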