diff --git a/swh/vault/backend.py b/swh/vault/backend.py index d5cf91b..4c0a05c 100644 --- a/swh/vault/backend.py +++ b/swh/vault/backend.py @@ -1,254 +1,258 @@ # Copyright (C) 2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import smtplib import celery import psycopg2 import psycopg2.extras from functools import wraps from email.mime.text import MIMEText from swh.model import hashutil from swh.scheduler.utils import get_task from swh.vault.cache import VaultCache from swh.vault.cookers import get_cooker from swh.vault.cooking_tasks import SWHCookingTask # noqa cooking_task_name = 'swh.vault.cooking_tasks.SWHCookingTask' NOTIF_EMAIL_FROM = ('"Software Heritage Vault" ' '<info@softwareheritage.org>') NOTIF_EMAIL_SUBJECT = ("Bundle ready: {obj_type} {short_id}") NOTIF_EMAIL_BODY = """ You have requested the following bundle from the Software Heritage Vault: Object Type: {obj_type} Object ID: {hex_id} This bundle is now available for download at the following address: {url} Please keep in mind that this link might expire at some point, in which case you will need to request the bundle again. --\x20 The Software Heritage Developers """ # TODO: Imported from swh.scheduler.backend. Factorization needed. def autocommit(fn): @wraps(fn) def wrapped(self, *args, **kwargs): autocommit = False # TODO: I don't like using None, it's confusing for the user. how about # a NEW_CURSOR object()? if 'cursor' not in kwargs or not kwargs['cursor']: autocommit = True kwargs['cursor'] = self.cursor() try: ret = fn(self, *args, **kwargs) except: if autocommit: self.rollback() raise if autocommit: self.commit() return ret return wrapped # TODO: This has to be factorized with other database base classes and helpers # (swh.scheduler.backend.SchedulerBackend, swh.storage.db.BaseDb, ...) 
# The three first methods are imported from swh.scheduler.backend. class VaultBackend: """ Backend for the Software Heritage vault. """ def __init__(self, config): self.config = config self.cache = VaultCache(**self.config['cache']) self.db = None self.reconnect() self.smtp_server = smtplib.SMTP('localhost') def reconnect(self): if not self.db or self.db.closed: self.db = psycopg2.connect( dsn=self.config['vault_db'], cursor_factory=psycopg2.extras.RealDictCursor, ) def close(self): self.db.close() def cursor(self): """Return a fresh cursor on the database, with auto-reconnection in case of failure""" cur = None # Get a fresh cursor and reconnect at most three times tries = 0 while True: tries += 1 try: cur = self.db.cursor() cur.execute('select 1') break except psycopg2.OperationalError: if tries < 3: self.reconnect() else: raise return cur def commit(self): """Commit a transaction""" self.db.commit() def rollback(self): """Rollback a transaction""" self.db.rollback() @autocommit def task_info(self, obj_type, obj_id, cursor=None): obj_id = hashutil.hash_to_bytes(obj_id) cursor.execute(''' SELECT id, type, object_id, task_uuid, task_status, - ts_created, ts_done, progress_msg + ts_created, ts_done, ts_last_access, progress_msg FROM vault_bundle WHERE type = %s AND object_id = %s''', (obj_type, obj_id)) res = cursor.fetchone() if res: res['object_id'] = bytes(res['object_id']) return res + def _send_task(self, task_uuid, args): + task = get_task(cooking_task_name) + task.apply_async(args, task_id=task_uuid) + @autocommit def create_task(self, obj_type, obj_id, cursor=None): obj_id = hashutil.hash_to_bytes(obj_id) args = [self.config, obj_type, obj_id] - cooker = get_cooker(obj_type)(*args) + CookerCls = get_cooker(obj_type) + cooker = CookerCls(*args) cooker.check_exists() task_uuid = celery.uuid() cursor.execute(''' INSERT INTO vault_bundle (type, object_id, task_uuid) VALUES (%s, %s, %s)''', (obj_type, obj_id, task_uuid)) self.commit() - task = 
get_task(cooking_task_name) - task.apply_async(args, task_id=task_uuid) + self._send_task(task_uuid, args) @autocommit def add_notif_email(self, obj_type, obj_id, email, cursor=None): obj_id = hashutil.hash_to_bytes(obj_id) cursor.execute(''' INSERT INTO vault_notif_email (email, bundle_id) VALUES (%s, (SELECT id FROM vault_bundle WHERE type = %s AND object_id = %s))''', (email, obj_type, obj_id)) @autocommit def cook_request(self, obj_type, obj_id, email=None, cursor=None): info = self.task_info(obj_type, obj_id) if info is None: self.create_task(obj_type, obj_id) if email is not None: if info is not None and info['task_status'] == 'done': self.send_notification(None, email, obj_type, obj_id) else: self.add_notif_email(obj_type, obj_id, email) info = self.task_info(obj_type, obj_id) return info @autocommit def is_available(self, obj_type, obj_id, cursor=None): info = self.task_info(obj_type, obj_id, cursor=cursor) return (info is not None and info['task_status'] == 'done' and self.cache.is_cached(obj_type, obj_id)) @autocommit def fetch(self, obj_type, obj_id, cursor=None): if not self.is_available(obj_type, obj_id, cursor=cursor): return None self.update_access_ts(obj_type, obj_id, cursor=cursor) return self.cache.get(obj_type, obj_id) @autocommit def update_access_ts(self, obj_type, obj_id, cursor=None): obj_id = hashutil.hash_to_bytes(obj_id) cursor.execute(''' UPDATE vault_bundle SET ts_last_access = NOW() WHERE type = %s AND object_id = %s''', (obj_type, obj_id)) @autocommit def set_status(self, obj_type, obj_id, status, cursor=None): obj_id = hashutil.hash_to_bytes(obj_id) req = (''' UPDATE vault_bundle SET task_status = %s ''' + (''', ts_done = NOW() ''' if status == 'done' else '') + '''WHERE type = %s AND object_id = %s''') cursor.execute(req, (status, obj_type, obj_id)) @autocommit def set_progress(self, obj_type, obj_id, progress, cursor=None): obj_id = hashutil.hash_to_bytes(obj_id) cursor.execute(''' UPDATE vault_bundle SET progress_msg = %s WHERE 
type = %s AND object_id = %s''', (progress, obj_type, obj_id)) @autocommit def send_all_notifications(self, obj_type, obj_id, cursor=None): obj_id = hashutil.hash_to_bytes(obj_id) cursor.execute(''' SELECT vault_notif_email.id AS id, email FROM vault_notif_email INNER JOIN vault_bundle ON bundle_id = vault_bundle.id WHERE vault_bundle.type = %s AND vault_bundle.object_id = %s''', (obj_type, obj_id)) for d in cursor: self.send_notification(d['id'], d['email'], obj_type, obj_id) @autocommit def send_notification(self, n_id, email, obj_type, obj_id, cursor=None): hex_id = hashutil.hash_to_hex(obj_id) short_id = hex_id[:7] # TODO: instead of hardcoding this, we should probably: # * add a "fetch_url" field in the vault_notif_email table # * generate the url with flask.url_for() on the web-ui side # * send this url as part of the cook request and store it in # the table # * use this url for the notification e-mail url = ('https://archive.softwareheritage.org/api/1/vault/{}/{}/' 'raw'.format(obj_type, hex_id)) text = NOTIF_EMAIL_BODY.strip() text = text.format(obj_type=obj_type, hex_id=hex_id, url=url) msg = MIMEText(text) msg['Subject'] = (NOTIF_EMAIL_SUBJECT .format(obj_type=obj_type, short_id=short_id)) msg['From'] = NOTIF_EMAIL_FROM msg['To'] = email self.smtp_server.send_message(msg) if n_id is not None: cursor.execute(''' DELETE FROM vault_notif_email WHERE id = %s''', (n_id,)) diff --git a/swh/vault/tests/test_backend.py b/swh/vault/tests/test_backend.py new file mode 100644 index 0000000..4838a95 --- /dev/null +++ b/swh/vault/tests/test_backend.py @@ -0,0 +1,205 @@ +# Copyright (C) 2017 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import contextlib +import datetime +import psycopg2 +import unittest + +from unittest.mock import patch + +from swh.core.tests.db_testing import 
DbTestFixture +from swh.model import hashutil +from swh.storage.tests.storage_testing import StorageTestFixture +from swh.vault.tests.vault_testing import VaultTestFixture + + +class BaseTestBackend(VaultTestFixture, StorageTestFixture, DbTestFixture): + @contextlib.contextmanager + def mock_cooking(self): + with patch.object(self.vault_backend, '_send_task') as mt: + with patch('swh.vault.backend.get_cooker') as mg: + mcc = unittest.mock.MagicMock() + mc = unittest.mock.MagicMock() + mg.return_value = mcc + mcc.return_value = mc + mc.check_exists.return_value = True + + yield {'send_task': mt, + 'get_cooker': mg, + 'cooker_cls': mcc, + 'cooker': mc} + + def assertTimestampAlmostNow(self, ts, tolerance_secs=1.0): + now = datetime.datetime.now(datetime.timezone.utc) + creation_delta_secs = abs((ts - now).total_seconds()) + self.assertLess(creation_delta_secs, tolerance_secs) + + +TEST_TYPE = 'revision_gitfast' +TEST_HEX_ID = '4a4b9771542143cf070386f86b4b92d42966bdbc' +TEST_OBJ_ID = hashutil.hash_to_bytes(TEST_HEX_ID) +TEST_PROGRESS = ("Mr. White, You're telling me you're cooking again?" 
+ " \N{ASTONISHED FACE} ") +TEST_EMAIL = 'ouiche@example.com' + + +class TestBackend(BaseTestBackend, unittest.TestCase): + def test_create_task_simple(self): + with self.mock_cooking() as m: + self.vault_backend.create_task(TEST_TYPE, TEST_OBJ_ID) + + m['get_cooker'].assert_called_once_with(TEST_TYPE) + + args = m['cooker_cls'].call_args[0] + self.assertEqual(args[0], self.vault_backend.config) + self.assertEqual(args[1], TEST_TYPE) + self.assertEqual(args[2], TEST_OBJ_ID) + + self.assertEqual(m['cooker'].check_exists.call_count, 1) + + self.assertEqual(m['send_task'].call_count, 1) + args = m['send_task'].call_args[0][1] + self.assertEqual(args[0], self.vault_backend.config) + self.assertEqual(args[1], TEST_TYPE) + self.assertEqual(args[2], TEST_OBJ_ID) + + info = self.vault_backend.task_info(TEST_TYPE, TEST_OBJ_ID) + self.assertEqual(info['object_id'], TEST_OBJ_ID) + self.assertEqual(info['type'], TEST_TYPE) + self.assertEqual(str(info['task_uuid']), + m['send_task'].call_args[0][0]) + self.assertEqual(info['task_status'], 'new') + + self.assertTimestampAlmostNow(info['ts_created']) + + self.assertEqual(info['ts_done'], None) + self.assertEqual(info['progress_msg'], None) + + def test_create_fail_duplicate_task(self): + with self.mock_cooking(): + self.vault_backend.create_task(TEST_TYPE, TEST_OBJ_ID) + with self.assertRaises(psycopg2.IntegrityError): + self.vault_backend.create_task(TEST_TYPE, TEST_OBJ_ID) + + def test_create_fail_nonexisting_object(self): + with self.mock_cooking() as m: + m['cooker'].check_exists.side_effect = ValueError('Nothing here.') + with self.assertRaises(ValueError): + self.vault_backend.create_task(TEST_TYPE, TEST_OBJ_ID) + + def test_create_set_progress(self): + with self.mock_cooking(): + self.vault_backend.create_task(TEST_TYPE, TEST_OBJ_ID) + + info = self.vault_backend.task_info(TEST_TYPE, TEST_OBJ_ID) + self.assertEqual(info['progress_msg'], None) + self.vault_backend.set_progress(TEST_TYPE, TEST_OBJ_ID, + TEST_PROGRESS) + info 
= self.vault_backend.task_info(TEST_TYPE, TEST_OBJ_ID) + self.assertEqual(info['progress_msg'], TEST_PROGRESS) + + def test_create_set_status(self): + with self.mock_cooking(): + self.vault_backend.create_task(TEST_TYPE, TEST_OBJ_ID) + + info = self.vault_backend.task_info(TEST_TYPE, TEST_OBJ_ID) + self.assertEqual(info['task_status'], 'new') + self.assertEqual(info['ts_done'], None) + + self.vault_backend.set_status(TEST_TYPE, TEST_OBJ_ID, 'pending') + info = self.vault_backend.task_info(TEST_TYPE, TEST_OBJ_ID) + self.assertEqual(info['task_status'], 'pending') + self.assertEqual(info['ts_done'], None) + + self.vault_backend.set_status(TEST_TYPE, TEST_OBJ_ID, 'done') + info = self.vault_backend.task_info(TEST_TYPE, TEST_OBJ_ID) + self.assertEqual(info['task_status'], 'done') + self.assertTimestampAlmostNow(info['ts_done']) + + def test_create_update_access_ts(self): + with self.mock_cooking(): + self.vault_backend.create_task(TEST_TYPE, TEST_OBJ_ID) + + info = self.vault_backend.task_info(TEST_TYPE, TEST_OBJ_ID) + access_ts_1 = info['ts_last_access'] + self.assertTimestampAlmostNow(access_ts_1) + + self.vault_backend.update_access_ts(TEST_TYPE, TEST_OBJ_ID) + info = self.vault_backend.task_info(TEST_TYPE, TEST_OBJ_ID) + access_ts_2 = info['ts_last_access'] + self.assertTimestampAlmostNow(access_ts_2) + + self.vault_backend.update_access_ts(TEST_TYPE, TEST_OBJ_ID) + info = self.vault_backend.task_info(TEST_TYPE, TEST_OBJ_ID) + access_ts_3 = info['ts_last_access'] + self.assertTimestampAlmostNow(access_ts_3) + + self.assertLess(access_ts_1, access_ts_2) + self.assertLess(access_ts_2, access_ts_3) + + def test_cook_request_idempotent(self): + with self.mock_cooking(): + info1 = self.vault_backend.cook_request(TEST_TYPE, TEST_OBJ_ID) + info2 = self.vault_backend.cook_request(TEST_TYPE, TEST_OBJ_ID) + info3 = self.vault_backend.cook_request(TEST_TYPE, TEST_OBJ_ID) + self.assertEqual(info1, info2) + self.assertEqual(info1, info3) + + def 
test_cook_email_pending_done(self): + with self.mock_cooking(), \ + patch.object(self.vault_backend, 'add_notif_email') as madd, \ + patch.object(self.vault_backend, 'send_notification') as msend: + + self.vault_backend.cook_request(TEST_TYPE, TEST_OBJ_ID) + madd.assert_not_called() + msend.assert_not_called() + + madd.reset_mock() + msend.reset_mock() + + self.vault_backend.cook_request(TEST_TYPE, TEST_OBJ_ID, TEST_EMAIL) + madd.assert_called_once_with(TEST_TYPE, TEST_OBJ_ID, TEST_EMAIL) + msend.assert_not_called() + + madd.reset_mock() + msend.reset_mock() + + self.vault_backend.set_status(TEST_TYPE, TEST_OBJ_ID, 'done') + self.vault_backend.cook_request(TEST_TYPE, TEST_OBJ_ID, TEST_EMAIL) + msend.assert_called_once_with(None, TEST_EMAIL, + TEST_TYPE, TEST_OBJ_ID) + madd.assert_not_called() + + def test_send_all_emails(self): + with self.mock_cooking(): + emails = ('a@example.com', + 'billg@example.com', + 'test+42@example.org') + for email in emails: + self.vault_backend.cook_request(TEST_TYPE, TEST_OBJ_ID, email) + + self.vault_backend.set_status(TEST_TYPE, TEST_OBJ_ID, 'done') + + with patch.object(self.vault_backend, 'smtp_server') as m: + self.vault_backend.send_all_notifications(TEST_TYPE, TEST_OBJ_ID) + + sent_emails = {k[0][0] for k in m.send_message.call_args_list} + self.assertEqual({k['To'] for k in sent_emails}, set(emails)) + + for e in sent_emails: + self.assertIn('info@softwareheritage.org', e['From']) + self.assertIn(TEST_TYPE, e['Subject']) + self.assertIn(TEST_HEX_ID[:5], e['Subject']) + self.assertIn(TEST_TYPE, str(e)) + self.assertIn('https://archive.softwareheritage.org/', str(e)) + self.assertIn(TEST_HEX_ID[:5], str(e)) + self.assertIn('--\x20\n', str(e)) # Well-formatted signature!!! 
+ + # Check that the entries have been deleted and recalling the + # function does not re-send the e-mails + m.reset_mock() + self.vault_backend.send_all_notifications(TEST_TYPE, TEST_OBJ_ID) + m.send_message.assert_not_called() diff --git a/swh/vault/tests/test_cookers.py b/swh/vault/tests/test_cookers.py index cc81028..a90ed54 100644 --- a/swh/vault/tests/test_cookers.py +++ b/swh/vault/tests/test_cookers.py @@ -1,274 +1,274 @@ # Copyright (C) 2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import contextlib import datetime import gzip import io import os import pathlib import subprocess import tarfile import tempfile import unittest import dulwich.fastexport import dulwich.index import dulwich.objects import dulwich.porcelain import dulwich.repo from swh.core.tests.db_testing import DbTestFixture from swh.loader.git.loader import GitLoader from swh.model import hashutil from swh.model.git import compute_hashes_from_directory from swh.storage.tests.storage_testing import StorageTestFixture from swh.vault.cookers import DirectoryCooker, RevisionGitfastCooker from swh.vault.tests.vault_testing import VaultTestFixture class TestRepo: """A tiny context manager for a test git repository, with some utility functions to perform basic git stuff. 
""" def __enter__(self): self.tmp_dir = tempfile.TemporaryDirectory(prefix='tmp-vault-repo-') self.repo_dir = self.tmp_dir.__enter__() self.repo = dulwich.repo.Repo.init(self.repo_dir) self.author = '"Test Author" '.encode() return pathlib.Path(self.repo_dir) def __exit__(self, exc, value, tb): self.tmp_dir.__exit__(exc, value, tb) def checkout(self, rev_sha): rev = self.repo[rev_sha] dulwich.index.build_index_from_tree(self.repo_dir, self.repo.index_path(), self.repo.object_store, rev.tree) def git_shell(self, *cmd, stdout=subprocess.DEVNULL, **kwargs): subprocess.check_call(('git', '-C', self.repo_dir) + cmd, stdout=stdout, **kwargs) def commit(self, message='Commit test\n', ref=b'HEAD'): self.git_shell('add', '.') message = message.encode() + b'\n' return self.repo.do_commit(message=message, committer=self.author, ref=ref) def merge(self, parent_sha_list, message='Merge branches.'): self.git_shell('merge', '--allow-unrelated-histories', '-m', message, *[p.decode() for p in parent_sha_list]) return self.repo.refs[b'HEAD'] def print_debug_graph(self, reflog=False): args = ['log', '--all', '--graph', '--decorate'] if reflog: args.append('--reflog') self.git_shell(*args, stdout=None) -TEST_CONTENT = (" test content\n" - "and unicode \N{BLACK HEART SUIT}\n" - " and trailing spaces ") -TEST_EXECUTABLE = b'\x42\x40\x00\x00\x05' - - class BaseTestCookers(VaultTestFixture, StorageTestFixture, DbTestFixture): """Base class of cookers unit tests""" def setUp(self): super().setUp() self.loader = GitLoader() self.loader.storage = self.storage def load(self, repo_path): """Load a repository in the test storage""" self.loader.load('fake_origin', repo_path, datetime.datetime.now()) @contextlib.contextmanager def cook_extract_directory(self, obj_id): """Context manager that cooks a directory and extract it.""" cooker = DirectoryCooker(self.vault_config, 'directory', obj_id) with cooker: cooker.check_exists() # Raises if false tarball = b''.join(cooker.prepare_bundle()) with 
tempfile.TemporaryDirectory('tmp-vault-extract-') as td: fobj = io.BytesIO(tarball) with tarfile.open(fileobj=fobj, mode='r') as tar: tar.extractall(td) p = pathlib.Path(td) / hashutil.hash_to_hex(obj_id) yield p @contextlib.contextmanager def cook_extract_revision_gitfast(self, obj_id): """Context manager that cooks a revision and extract it.""" cooker = RevisionGitfastCooker(self.vault_config, 'revision_gitfast', obj_id) with cooker: cooker.check_exists() # Raises if false fastexport = b''.join(cooker.prepare_bundle()) fastexport_stream = gzip.GzipFile(fileobj=io.BytesIO(fastexport)) test_repo = TestRepo() with test_repo as p: processor = dulwich.fastexport.GitImportProcessor(test_repo.repo) processor.import_stream(fastexport_stream) yield test_repo, p +TEST_CONTENT = (" test content\n" + "and unicode \N{BLACK HEART SUIT}\n" + " and trailing spaces ") +TEST_EXECUTABLE = b'\x42\x40\x00\x00\x05' + + class TestDirectoryCooker(BaseTestCookers, unittest.TestCase): def test_directory_simple(self): repo = TestRepo() with repo as rp: (rp / 'file').write_text(TEST_CONTENT) (rp / 'executable').write_bytes(TEST_EXECUTABLE) (rp / 'executable').chmod(0o755) (rp / 'link').symlink_to('file') (rp / 'dir1/dir2').mkdir(parents=True) (rp / 'dir1/dir2/file').write_text(TEST_CONTENT) c = repo.commit() self.load(str(rp)) obj_id_hex = repo.repo[c].tree.decode() obj_id = hashutil.hash_to_bytes(obj_id_hex) with self.cook_extract_directory(obj_id) as p: self.assertEqual((p / 'file').stat().st_mode, 0o100644) self.assertEqual((p / 'file').read_text(), TEST_CONTENT) self.assertEqual((p / 'executable').stat().st_mode, 0o100755) self.assertEqual((p / 'executable').read_bytes(), TEST_EXECUTABLE) self.assertTrue((p / 'link').is_symlink) self.assertEqual(os.readlink(str(p / 'link')), 'file') self.assertEqual((p / 'dir1/dir2/file').stat().st_mode, 0o100644) self.assertEqual((p / 'dir1/dir2/file').read_text(), TEST_CONTENT) dir_pb = bytes(p) dir_hashes = 
compute_hashes_from_directory(dir_pb)[dir_pb] dir_hash = dir_hashes['checksums']['sha1_git'] self.assertEqual(obj_id_hex, hashutil.hash_to_hex(dir_hash)) class TestRevisionGitfastCooker(BaseTestCookers, unittest.TestCase): def test_revision_simple(self): # # 1--2--3--4--5--6--7 # repo = TestRepo() with repo as rp: (rp / 'file1').write_text(TEST_CONTENT) repo.commit('add file1') (rp / 'file2').write_text(TEST_CONTENT) repo.commit('add file2') (rp / 'dir1/dir2').mkdir(parents=True) (rp / 'dir1/dir2/file').write_text(TEST_CONTENT) repo.commit('add dir1/dir2/file') (rp / 'bin1').write_bytes(TEST_EXECUTABLE) (rp / 'bin1').chmod(0o755) repo.commit('add bin1') (rp / 'link1').symlink_to('file1') repo.commit('link link1 to file1') (rp / 'file2').unlink() repo.commit('remove file2') (rp / 'bin1').rename(rp / 'bin') repo.commit('rename bin1 to bin') self.load(str(rp)) obj_id_hex = repo.repo.refs[b'HEAD'].decode() obj_id = hashutil.hash_to_bytes(obj_id_hex) with self.cook_extract_revision_gitfast(obj_id) as (ert, p): ert.checkout(b'HEAD') self.assertEqual((p / 'file1').stat().st_mode, 0o100644) self.assertEqual((p / 'file1').read_text(), TEST_CONTENT) self.assertTrue((p / 'link1').is_symlink) self.assertEqual(os.readlink(str(p / 'link1')), 'file1') self.assertEqual((p / 'bin').stat().st_mode, 0o100755) self.assertEqual((p / 'bin').read_bytes(), TEST_EXECUTABLE) self.assertEqual((p / 'dir1/dir2/file').read_text(), TEST_CONTENT) self.assertEqual((p / 'dir1/dir2/file').stat().st_mode, 0o100644) self.assertEqual(ert.repo.refs[b'HEAD'].decode(), obj_id_hex) def test_revision_two_roots(self): # # 1----3---4 # / # 2---- # repo = TestRepo() with repo as rp: (rp / 'file1').write_text(TEST_CONTENT) c1 = repo.commit('Add file1') del repo.repo.refs[b'refs/heads/master'] # git update-ref -d HEAD (rp / 'file2').write_text(TEST_CONTENT) repo.commit('Add file2') repo.merge([c1]) (rp / 'file3').write_text(TEST_CONTENT) repo.commit('add file3') obj_id_hex = repo.repo.refs[b'HEAD'].decode() 
obj_id = hashutil.hash_to_bytes(obj_id_hex) self.load(str(rp)) with self.cook_extract_revision_gitfast(obj_id) as (ert, p): self.assertEqual(ert.repo.refs[b'HEAD'].decode(), obj_id_hex) def test_revision_two_double_fork_merge(self): # # 2---4---6 # / / / # 1---3---5 # repo = TestRepo() with repo as rp: (rp / 'file1').write_text(TEST_CONTENT) c1 = repo.commit('Add file1') repo.repo.refs[b'refs/heads/c1'] = c1 (rp / 'file2').write_text(TEST_CONTENT) repo.commit('Add file2') (rp / 'file3').write_text(TEST_CONTENT) c3 = repo.commit('Add file3', ref=b'refs/heads/c1') repo.repo.refs[b'refs/heads/c3'] = c3 repo.merge([c3]) (rp / 'file5').write_text(TEST_CONTENT) c5 = repo.commit('Add file3', ref=b'refs/heads/c3') repo.merge([c5]) obj_id_hex = repo.repo.refs[b'HEAD'].decode() obj_id = hashutil.hash_to_bytes(obj_id_hex) self.load(str(rp)) with self.cook_extract_revision_gitfast(obj_id) as (ert, p): self.assertEqual(ert.repo.refs[b'HEAD'].decode(), obj_id_hex) def test_revision_triple_merge(self): # # .---.---5 # / / / # 2 3 4 # / / / # 1---.---. 
# repo = TestRepo() with repo as rp: (rp / 'file1').write_text(TEST_CONTENT) c1 = repo.commit('Commit 1') repo.repo.refs[b'refs/heads/b1'] = c1 repo.repo.refs[b'refs/heads/b2'] = c1 repo.commit('Commit 2') c3 = repo.commit('Commit 3', ref=b'refs/heads/b1') c4 = repo.commit('Commit 4', ref=b'refs/heads/b2') repo.merge([c3, c4]) obj_id_hex = repo.repo.refs[b'HEAD'].decode() obj_id = hashutil.hash_to_bytes(obj_id_hex) self.load(str(rp)) with self.cook_extract_revision_gitfast(obj_id) as (ert, p): self.assertEqual(ert.repo.refs[b'HEAD'].decode(), obj_id_hex) diff --git a/swh/vault/tests/vault_testing.py b/swh/vault/tests/vault_testing.py index 3dedd37..956d5c2 100644 --- a/swh/vault/tests/vault_testing.py +++ b/swh/vault/tests/vault_testing.py @@ -1,51 +1,56 @@ # Copyright (C) 2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import tempfile import pathlib from swh.vault.backend import VaultBackend class VaultTestFixture: """Mix this in a test subject class to get Vault Database testing support. This fixture requires to come before DbTestFixture and StorageTestFixture in the inheritance list as it uses their methods to setup its own internal components. Usage example: class TestVault(VaultTestFixture, StorageTestFixture, DbTestFixture): ... 
""" TEST_VAULT_DB_NAME = 'softwareheritage-test-vault' @classmethod def setUpClass(cls): if not hasattr(cls, 'DB_TEST_FIXTURE_IMPORTED'): raise RuntimeError("VaultTestFixture needs to be followed by " "DbTestFixture in the inheritance list.") test_dir = pathlib.Path(__file__).absolute().parent test_db_dump = test_dir / '../../../sql/swh-vault-schema.sql' test_db_dump = test_db_dump.absolute() cls.add_db(cls.TEST_VAULT_DB_NAME, str(test_db_dump), 'psql') super().setUpClass() def setUp(self): super().setUp() self.cache_root = tempfile.TemporaryDirectory('vault-cache-') self.vault_config = { 'storage': self.storage_config, 'vault_db': 'postgresql:///' + self.TEST_VAULT_DB_NAME, 'cache': {'root': self.cache_root.name} } self.vault_backend = VaultBackend(self.vault_config) def tearDown(self): - self.reset_tables() self.cache_root.cleanup() self.vault_backend.close() + self.reset_storage_tables() + self.reset_vault_tables() super().tearDown() + + def reset_vault_tables(self): + excluded = {'dbversion'} + self.reset_db_tables(self.TEST_VAULT_DB_NAME, excluded=excluded)