diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -6,12 +6,14 @@
 import os.path
 import tempfile
 import unittest
+from unittest.mock import Mock, MagicMock
 
 from hypothesis import strategies, given
 import psycopg2
 import pytest
 
 from swh.core.db import BaseDb
+from swh.core.db.common import db_transaction, db_transaction_generator
 from .db_testing import (
     SingleDbTestFixture, db_create, db_destroy, db_close,
 )
@@ -108,3 +110,91 @@
         items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data]
         with self.assertRaises(psycopg2.errors.NumericValueOutOfRange):
             self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes'])
+
+
+def _mock_storage_db(mocker, storage, cur):
+    """Patch mock ``get_db``/``put_db`` methods onto ``storage``.
+
+    ``storage.get_db().transaction()`` returns a context manager whose
+    ``__enter__`` yields ``cur``, so that ``with storage.get_db()
+    .transaction() as c:`` binds ``c`` to ``cur``.
+
+    Returns ``(db_mock, put_db_mock)`` for the caller's assertions.
+    """
+    db_mock = Mock()
+    db_mock.transaction.return_value = MagicMock()
+    db_mock.transaction.return_value.__enter__.return_value = cur
+    mocker.patch.object(
+        storage, 'get_db', return_value=db_mock, create=True)
+    put_db_mock = mocker.patch.object(
+        storage, 'put_db', create=True)
+    return db_mock, put_db_mock
+
+
+def test_db_transaction(mocker):
+    """The decorated method receives the transaction's cursor, the
+    transaction context is exited, and the db is passed to put_db."""
+    expected_cur = object()
+
+    called = False
+
+    class Storage:
+        @db_transaction()
+        def endpoint(self, cur=None, db=None):
+            nonlocal called
+            called = True
+            assert cur is expected_cur
+
+    storage = Storage()
+    db_mock, put_db_mock = _mock_storage_db(mocker, storage, expected_cur)
+
+    storage.endpoint()
+
+    assert called
+    # the 'with' block must have been exited, not just entered
+    assert db_mock.transaction.return_value.__exit__.call_count == 1
+    put_db_mock.assert_called_once_with(db_mock)
+
+
+def test_db_transaction__with_generator():
+    """db_transaction raises ValueError on a generator function."""
+    with pytest.raises(ValueError, match='generator'):
+        class Storage:
+            @db_transaction()
+            def endpoint(self, cur=None, db=None):
+                yield None
+
+
+def test_db_transaction_generator(mocker):
+    """Same contract as test_db_transaction, for a generator method
+    whose results are fully consumed by the caller."""
+    expected_cur = object()
+
+    called = False
+
+    class Storage:
+        @db_transaction_generator()
+        def endpoint(self, cur=None, db=None):
+            nonlocal called
+            called = True
+            assert cur is expected_cur
+            yield None
+
+    storage = Storage()
+    db_mock, put_db_mock = _mock_storage_db(mocker, storage, expected_cur)
+
+    list(storage.endpoint())
+
+    assert called
+    # the 'with' block must have been exited once the generator is drained
+    assert db_mock.transaction.return_value.__exit__.call_count == 1
+    put_db_mock.assert_called_once_with(db_mock)
+
+
+def test_db_transaction_generator__with_nongenerator():
+    """db_transaction_generator raises ValueError on a plain function."""
+    with pytest.raises(ValueError, match='generator'):
+        class Storage:
+            @db_transaction_generator()
+            def endpoint(self, cur=None, db=None):
+                pass