diff --git a/swh/scrubber/__init__.py b/swh/scrubber/__init__.py index 2527e35..1c4c0a2 100644 --- a/swh/scrubber/__init__.py +++ b/swh/scrubber/__init__.py @@ -1,23 +1,25 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: - from .db import ScrubberDb + from swh.scrubber.db import ScrubberDb def get_scrubber_db(cls: str, **kwargs) -> ScrubberDb: - if cls != "local": - raise ValueError(f"Unknown scrubber db class '{cls}', use 'local' instead.") + if cls not in ("local", "postgresql"): + raise ValueError( + f"Unknown scrubber db class '{cls}', use 'postgresql' instead." + ) - from .db import ScrubberDb + from swh.scrubber.db import ScrubberDb return ScrubberDb.connect(kwargs.pop("db"), **kwargs) get_datastore = get_scrubber_db diff --git a/swh/scrubber/tests/test_cli.py b/swh/scrubber/tests/test_cli.py index b54a5c8..4b85237 100644 --- a/swh/scrubber/tests/test_cli.py +++ b/swh/scrubber/tests/test_cli.py @@ -1,159 +1,159 @@ # Copyright (C) 2020-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import tempfile from unittest.mock import MagicMock, call from click.testing import CliRunner import yaml from swh.model.swhids import CoreSWHID from swh.scrubber.cli import scrubber_cli_group from swh.scrubber.storage_checker import storage_db def invoke( scrubber_db, args, storage=None, kafka_server=None, kafka_prefix=None, kafka_consumer_group=None, ): runner = CliRunner() config = { - "scrubber_db": {"cls": "local", "db": scrubber_db.conn.dsn}, + "scrubber_db": {"cls": "postgresql", "db": scrubber_db.conn.dsn}, "graph": {"url": "http://graph.example.org:5009/"}, } if storage: with storage_db(storage) as db: config["storage"] = { "cls": "postgresql", "db": db.conn.dsn, "objstorage": {"cls": "memory"}, } assert ( (kafka_server is None) == (kafka_prefix is None) == (kafka_consumer_group is None) ) if kafka_server: config["journal_client"] = dict( cls="kafka", brokers=kafka_server, group_id=kafka_consumer_group, prefix=kafka_prefix, stop_on_eof=True, ) with tempfile.NamedTemporaryFile("a", suffix=".yml") as config_fd: yaml.dump(config, config_fd) config_fd.seek(0) args = ["-C" + config_fd.name] + list(args) result = runner.invoke(scrubber_cli_group, args, catch_exceptions=False) return result def test_check_storage(mocker, scrubber_db, swh_storage): storage_checker = MagicMock() StorageChecker = mocker.patch( "swh.scrubber.storage_checker.StorageChecker", return_value=storage_checker ) get_scrubber_db = mocker.patch( "swh.scrubber.get_scrubber_db", return_value=scrubber_db ) result = invoke( scrubber_db, ["check", "storage", "--object-type=snapshot"], storage=swh_storage ) assert result.exit_code == 0, result.output assert result.output == "" - get_scrubber_db.assert_called_once_with(cls="local", db=scrubber_db.conn.dsn) + get_scrubber_db.assert_called_once_with(cls="postgresql", db=scrubber_db.conn.dsn) StorageChecker.assert_called_once_with( db=scrubber_db, storage=StorageChecker.mock_calls[0][2]["storage"], object_type="snapshot", start_object="0" * 40, end_object="f" * 40, ) assert storage_checker.method_calls == [call.run()] def test_check_journal( mocker, scrubber_db, kafka_server, kafka_prefix, kafka_consumer_group ): journal_checker = MagicMock() JournalChecker = mocker.patch( "swh.scrubber.journal_checker.JournalChecker", return_value=journal_checker ) get_scrubber_db = mocker.patch( "swh.scrubber.get_scrubber_db", return_value=scrubber_db ) result = invoke( scrubber_db, ["check", "journal"], kafka_server=kafka_server, kafka_prefix=kafka_prefix, kafka_consumer_group=kafka_consumer_group, ) assert result.exit_code == 0, result.output assert result.output == "" - get_scrubber_db.assert_called_once_with(cls="local", db=scrubber_db.conn.dsn) + get_scrubber_db.assert_called_once_with(cls="postgresql", db=scrubber_db.conn.dsn) JournalChecker.assert_called_once_with( db=scrubber_db, journal_client={ "brokers": kafka_server, "cls": "kafka", "group_id": kafka_consumer_group, "prefix": kafka_prefix, "stop_on_eof": True, }, ) assert journal_checker.method_calls == [call.run()] def test_locate_origins(mocker, scrubber_db, swh_storage): origin_locator = MagicMock() OriginLocator = mocker.patch( "swh.scrubber.origin_locator.OriginLocator", return_value=origin_locator ) get_scrubber_db = mocker.patch( "swh.scrubber.get_scrubber_db", return_value=scrubber_db ) result = invoke(scrubber_db, ["locate"], storage=swh_storage) assert result.exit_code == 0, result.output assert result.output == "" - get_scrubber_db.assert_called_once_with(cls="local", db=scrubber_db.conn.dsn) + get_scrubber_db.assert_called_once_with(cls="postgresql", db=scrubber_db.conn.dsn) OriginLocator.assert_called_once_with( db=scrubber_db, storage=OriginLocator.mock_calls[0][2]["storage"], graph=OriginLocator.mock_calls[0][2]["graph"], start_object=CoreSWHID.from_string("swh:1:cnt:" + "00" * 20), end_object=CoreSWHID.from_string("swh:1:snp:" + "ff" * 20), ) assert origin_locator.method_calls == [call.run()] def test_fix_objects(mocker, scrubber_db): fixer = MagicMock() Fixer = mocker.patch("swh.scrubber.fixer.Fixer", return_value=fixer) get_scrubber_db = mocker.patch( "swh.scrubber.get_scrubber_db", return_value=scrubber_db ) result = invoke(scrubber_db, ["fix"]) assert result.exit_code == 0, result.output assert result.output == "" - get_scrubber_db.assert_called_once_with(cls="local", db=scrubber_db.conn.dsn) + get_scrubber_db.assert_called_once_with(cls="postgresql", db=scrubber_db.conn.dsn) Fixer.assert_called_once_with( db=scrubber_db, start_object=CoreSWHID.from_string("swh:1:cnt:" + "00" * 20), end_object=CoreSWHID.from_string("swh:1:snp:" + "ff" * 20), ) assert fixer.method_calls == [call.run()] diff --git a/swh/scrubber/tests/test_init.py b/swh/scrubber/tests/test_init.py new file mode 100644 index 0000000..1f80823 --- /dev/null +++ b/swh/scrubber/tests/test_init.py @@ -0,0 +1,33 @@ +# Copyright (C) 2020-2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from typing import Any + +import pytest + +from swh.scrubber import get_scrubber_db + + +@pytest.mark.parametrize("clz", ["local", "postgresql"]) +def test_get_scrubber_db(mocker, clz): + mock_scrubber = mocker.patch("swh.scrubber.db.ScrubberDb") + + def test_connect(db_str: str, **kwargs) -> Any: + return "connection-result" + + mock_scrubber.connect.side_effect = test_connect + + actual_result = get_scrubber_db(clz, db="service=scrubber-db") + + assert mock_scrubber.connect.called is True + assert actual_result == "connection-result" + + +@pytest.mark.parametrize("clz", ["something", "anything"]) +def test_get_scrubber_db_raise(clz): + assert clz not in ["local", "postgresql"] + + with pytest.raises(ValueError, match="Unknown"): + get_scrubber_db(clz, db="service=scrubber-db")