diff --git a/swh/journal/cli.py b/swh/journal/cli.py
--- a/swh/journal/cli.py
+++ b/swh/journal/cli.py
@@ -7,6 +7,7 @@
 import logging
 import mmap
 import os
+import warnings
 
 import click
 
@@ -21,7 +22,7 @@
 from swh.storage import get_storage
 from swh.objstorage import get_objstorage
 
-from swh.journal.client import JournalClient
+from swh.journal.client import get_journal_client as get_client
 from swh.journal.replay import is_hash_in_bytearray
 from swh.journal.replay import process_replay_objects
 from swh.journal.replay import process_replay_objects_content
@@ -60,13 +61,35 @@
 
 
 def get_journal_client(ctx, **kwargs):
-    conf = ctx.obj["config"].get("journal", {})
-    conf.update({k: v for (k, v) in kwargs.items() if v not in (None, ())})
-    if not conf.get("brokers"):
-        ctx.fail("You must specify at least one kafka broker.")
-    if not isinstance(conf["brokers"], (list, tuple)):
-        conf["brokers"] = [conf["brokers"]]
-    return JournalClient(**conf)
+    """Instantiate a journal client from the CLI configuration.
+
+    The client settings are read from the ``journal_client`` entry of
+    ``ctx.obj["config"]``. The legacy ``journal`` entry is still accepted,
+    with a DeprecationWarning, and is treated as a "kafka" client.
+    Keyword arguments override the configured settings, except those left
+    unset (``None`` or ``()``). Configuration errors go to ``ctx.fail``.
+    """
+    conf = ctx.obj["config"].copy()
+    if "journal" in conf:
+        warnings.warn(
+            "Journal client configuration should now be under the "
+            "`journal_client` field and have a `cls` argument.",
+            DeprecationWarning,
+        )
+        conf["journal_client"] = {"cls": "kafka", **conf.pop("journal")}
+
+    if not conf.get("journal_client"):
+        # Without this guard a missing section fails with an AttributeError.
+        ctx.fail("You must specify a `journal_client` configuration.")
+    client_conf = dict(conf["journal_client"])
+    # Unset command-line options must not shadow values from the config file.
+    overrides = {k: v for (k, v) in kwargs.items() if v not in (None, ())}
+    client_conf.update(overrides)
+
+    try:
+        return get_client(**client_conf)
+    except ValueError as exc:
+        ctx.fail(str(exc))
 
 
 @cli.command()
diff --git a/swh/journal/client.py b/swh/journal/client.py
--- a/swh/journal/client.py
+++ b/swh/journal/client.py
@@ -39,6 +39,16 @@
 ]
 
 
+def get_journal_client(cls: str, **kwargs: Any):
+    """Factory function to instantiate a journal client object.
+
+    Currently, only the "kafka" journal client is supported.
+    """
+    if cls == "kafka":
+        return JournalClient(**kwargs)
+    raise ValueError("Unknown journal client class `%s`" % cls)
+
+
 def _error_cb(error):
     if error.fatal():
         raise KafkaException(error)
diff --git a/swh/journal/tests/test_cli.py b/swh/journal/tests/test_cli.py
--- a/swh/journal/tests/test_cli.py
+++ b/swh/journal/tests/test_cli.py
@@ -10,7 +10,7 @@
 import re
 import tempfile
 from typing import Any, Dict
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
 
 from click.testing import CliRunner
 from confluent_kafka import Producer
@@ -21,7 +21,7 @@
 from swh.objstorage.backends.in_memory import InMemoryObjStorage
 from swh.storage import get_storage
 
-from swh.journal.cli import cli
+from swh.journal.cli import cli, get_journal_client
 from swh.journal.replay import CONTENT_REPLAY_RETRIES
 from swh.journal.serializers import key_to_kafka, value_to_kafka
 
@@ -57,7 +57,8 @@
 def invoke(*args, env=None, journal_config=None):
     config = copy.deepcopy(CLI_CONFIG)
     if journal_config:
-        config["journal"] = journal_config
+        config["journal_client"] = journal_config.copy()
+        config["journal_client"]["cls"] = "kafka"
 
     runner = CliRunner()
     with tempfile.NamedTemporaryFile("a", suffix=".yml") as config_fd:
@@ -67,6 +68,42 @@
     return runner.invoke(cli, args, obj={"log_level": logging.DEBUG}, env=env,)
 
 
+def test_get_journal_client_config_bwcompat(kafka_server):
+    cfg = {
+        "journal": {
+            "brokers": [kafka_server],
+            "group_id": "toto",
+            "prefix": "xiferp",
+            "object_types": ["content"],
+            "batch_size": 50,
+        }
+    }
+    ctx = MagicMock(obj={"config": cfg})
+    with pytest.deprecated_call():
+        client = get_journal_client(ctx, stop_after_objects=10, prefix="prefix")
+    assert client.subscription == ["prefix.content"]
+    assert client.stop_after_objects == 10
+    assert client.batch_size == 50
+
+
+def test_get_journal_client_config(kafka_server):
+    cfg = {
+        "journal_client": {
+            "cls": "kafka",
+            "brokers": [kafka_server],
+            "group_id": "toto",
+            "prefix": "xiferp",
+            "object_types": ["content"],
+            "batch_size": 50,
+        }
+    }
+    ctx = MagicMock(obj={"config": cfg})
+    client = get_journal_client(ctx, stop_after_objects=10, prefix="prefix")
+    assert client.subscription == ["prefix.content"]
+    assert client.stop_after_objects == 10
+    assert client.batch_size == 50
+
+
 def test_replay(
     storage, kafka_prefix: str, kafka_consumer_group: str, kafka_server: str,
 ):