diff --git a/swh/scrubber/__init__.py b/swh/scrubber/__init__.py
--- a/swh/scrubber/__init__.py
+++ b/swh/scrubber/__init__.py
@@ -8,14 +8,23 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from .db import ScrubberDb
+    from swh.scrubber.db import ScrubberDb
 
 
 def get_scrubber_db(cls: str, **kwargs) -> ScrubberDb:
-    if cls != "local":
-        raise ValueError(f"Unknown scrubber db class '{cls}', use 'local' instead.")
+    """Return a :class:`ScrubberDb` connected to ``kwargs["db"]``.
+
+    ``cls`` must be ``"postgresql"``; ``"local"`` is still accepted for
+    backward compatibility. Any other value raises :exc:`ValueError`.
+    The remaining keyword arguments are forwarded to
+    :meth:`ScrubberDb.connect`.
+    """
+    if cls not in ("local", "postgresql"):
+        raise ValueError(
+            f"Unknown scrubber db class '{cls}', use 'postgresql' instead."
+        )
 
-    from .db import ScrubberDb
+    from swh.scrubber.db import ScrubberDb
 
     return ScrubberDb.connect(kwargs.pop("db"), **kwargs)
 
diff --git a/swh/scrubber/tests/test_cli.py b/swh/scrubber/tests/test_cli.py
--- a/swh/scrubber/tests/test_cli.py
+++ b/swh/scrubber/tests/test_cli.py
@@ -25,7 +25,7 @@
     runner = CliRunner()
 
     config = {
-        "scrubber_db": {"cls": "local", "db": scrubber_db.conn.dsn},
+        "scrubber_db": {"cls": "postgresql", "db": scrubber_db.conn.dsn},
         "graph": {"url": "http://graph.example.org:5009/"},
     }
     if storage:
@@ -72,7 +72,7 @@
     assert result.exit_code == 0, result.output
     assert result.output == ""
 
-    get_scrubber_db.assert_called_once_with(cls="local", db=scrubber_db.conn.dsn)
+    get_scrubber_db.assert_called_once_with(cls="postgresql", db=scrubber_db.conn.dsn)
     StorageChecker.assert_called_once_with(
         db=scrubber_db,
         storage=StorageChecker.mock_calls[0][2]["storage"],
@@ -103,7 +103,7 @@
     assert result.exit_code == 0, result.output
     assert result.output == ""
 
-    get_scrubber_db.assert_called_once_with(cls="local", db=scrubber_db.conn.dsn)
+    get_scrubber_db.assert_called_once_with(cls="postgresql", db=scrubber_db.conn.dsn)
     JournalChecker.assert_called_once_with(
         db=scrubber_db,
         journal_client={
@@ -129,7 +129,7 @@
     assert result.exit_code == 0, result.output
     assert result.output == ""
 
-    get_scrubber_db.assert_called_once_with(cls="local", db=scrubber_db.conn.dsn)
+    get_scrubber_db.assert_called_once_with(cls="postgresql", db=scrubber_db.conn.dsn)
     OriginLocator.assert_called_once_with(
         db=scrubber_db,
         storage=OriginLocator.mock_calls[0][2]["storage"],
@@ -150,7 +150,7 @@
     assert result.exit_code == 0, result.output
     assert result.output == ""
 
-    get_scrubber_db.assert_called_once_with(cls="local", db=scrubber_db.conn.dsn)
+    get_scrubber_db.assert_called_once_with(cls="postgresql", db=scrubber_db.conn.dsn)
     Fixer.assert_called_once_with(
         db=scrubber_db,
         start_object=CoreSWHID.from_string("swh:1:cnt:" + "00" * 20),
diff --git a/swh/scrubber/tests/test_init.py b/swh/scrubber/tests/test_init.py
new file mode 100644
--- /dev/null
+++ b/swh/scrubber/tests/test_init.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2020-2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+from swh.scrubber import get_scrubber_db
+
+
+@pytest.mark.parametrize("clz", ["local", "postgresql"])
+def test_get_scrubber_db(mocker, clz):
+    mock_scrubber = mocker.patch("swh.scrubber.db.ScrubberDb")
+    mock_scrubber.connect.return_value = "connection-result"
+
+    actual_result = get_scrubber_db(clz, db="service=scrubber-db")
+
+    mock_scrubber.connect.assert_called_once_with("service=scrubber-db")
+    assert actual_result == "connection-result"
+
+
+@pytest.mark.parametrize("clz", ["something", "anything"])
+def test_get_scrubber_db_raise(clz):
+    assert clz not in ["local", "postgresql"]
+
+    with pytest.raises(ValueError, match="Unknown"):
+        get_scrubber_db(clz, db="service=scrubber-db")