Changeset View
Changeset View
Standalone View
Standalone View
swh/core/db/tests/test_db.py
# Copyright (C) 2019 The Software Heritage developers | # Copyright (C) 2019 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
import os.path | import os.path | ||||
import tempfile | import tempfile | ||||
import unittest | import unittest | ||||
from unittest.mock import Mock, MagicMock | |||||
from hypothesis import strategies, given | from hypothesis import strategies, given | ||||
import psycopg2 | import psycopg2 | ||||
import pytest | import pytest | ||||
from swh.core.db import BaseDb | from swh.core.db import BaseDb | ||||
from swh.core.db.common import db_transaction, db_transaction_generator | |||||
from .db_testing import ( | from .db_testing import ( | ||||
SingleDbTestFixture, db_create, db_destroy, db_close, | SingleDbTestFixture, db_create, db_destroy, db_close, | ||||
) | ) | ||||
INIT_SQL = ''' | INIT_SQL = ''' | ||||
create table test_table | create table test_table | ||||
( | ( | ||||
▲ Show 20 Lines • Show All 80 Lines • ▼ Show 20 Lines | def test_copy_to(self, data): | ||||
self.assertCountEqual(list(cur), data) | self.assertCountEqual(list(cur), data) | ||||
def test_copy_to_thread_exception(self): | def test_copy_to_thread_exception(self): | ||||
data = [(2**65, 'foo', b'bar')] | data = [(2**65, 'foo', b'bar')] | ||||
items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data] | items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data] | ||||
with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): | with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): | ||||
self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes']) | self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes']) | ||||
def test_db_transaction(mocker): | |||||
expected_cur = object() | |||||
called = False | |||||
class Storage: | |||||
@db_transaction() | |||||
def endpoint(self, cur=None, db=None): | |||||
nonlocal called | |||||
called = True | |||||
assert cur is expected_cur | |||||
storage = Storage() | |||||
# 'with storage.get_db().transaction() as cur:' should cause | |||||
# 'cur' to be 'expected_cur' | |||||
db_mock = Mock() | |||||
db_mock.transaction.return_value = MagicMock() | |||||
db_mock.transaction.return_value.__enter__.return_value = expected_cur | |||||
mocker.patch.object( | |||||
storage, 'get_db', return_value=db_mock, create=True) | |||||
put_db_mock = mocker.patch.object( | |||||
storage, 'put_db', create=True) | |||||
storage.endpoint() | |||||
assert called | |||||
put_db_mock.assert_called_once_with(db_mock) | |||||
def test_db_transaction__with_generator(): | |||||
with pytest.raises(ValueError, match='generator'): | |||||
class Storage: | |||||
@db_transaction() | |||||
def endpoint(self, cur=None, db=None): | |||||
yield None | |||||
def test_db_transaction_generator(mocker): | |||||
expected_cur = object() | |||||
called = False | |||||
class Storage: | |||||
@db_transaction_generator() | |||||
def endpoint(self, cur=None, db=None): | |||||
nonlocal called | |||||
called = True | |||||
assert cur is expected_cur | |||||
yield None | |||||
storage = Storage() | |||||
# 'with storage.get_db().transaction() as cur:' should cause | |||||
# 'cur' to be 'expected_cur' | |||||
db_mock = Mock() | |||||
db_mock.transaction.return_value = MagicMock() | |||||
db_mock.transaction.return_value.__enter__.return_value = expected_cur | |||||
mocker.patch.object( | |||||
storage, 'get_db', return_value=db_mock, create=True) | |||||
put_db_mock = mocker.patch.object( | |||||
storage, 'put_db', create=True) | |||||
list(storage.endpoint()) | |||||
assert called | |||||
put_db_mock.assert_called_once_with(db_mock) | |||||
def test_db_transaction_generator__with_nongenerator(): | |||||
with pytest.raises(ValueError, match='generator'): | |||||
class Storage: | |||||
@db_transaction_generator() | |||||
def endpoint(self, cur=None, db=None): | |||||
ardumont: not used, most probably wrong commit. | |||||
Done Inline Actionsit is used. The test checks it raises a ValueError. vlorentz: it is used. The test checks it raises a ValueError. | |||||
Not Done Inline Actionsheh, missed it ;) ardumont: heh, missed it ;) | |||||
pass |
not used, most probably wrong commit.