diff --git a/requirements.txt b/requirements.txt index 872d546..68aa3ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,19 @@ # Add here external Python modules dependencies, one per line. Module names # should match https://pypi.python.org/pypi names. For the full spec or # dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html arrow attrs attrs-strict celery >= 4.3 Click elasticsearch > 5.4 flask pika >= 1.1.0 psycopg2 pyyaml setuptools +typing-extensions # test dependencies # hypothesis diff --git a/swh/scheduler/__init__.py b/swh/scheduler/__init__.py index 2d3892f..1274868 100644 --- a/swh/scheduler/__init__.py +++ b/swh/scheduler/__init__.py @@ -1,67 +1,88 @@ -# Copyright (C) 2018 The Software Heritage developers +# Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Any, Dict +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any, Dict +import warnings # Percentage of tasks with priority to schedule PRIORITY_SLOT = 0.6 DEFAULT_CONFIG = { "scheduler": ( "dict", - {"cls": "local", "args": {"db": "dbname=softwareheritage-scheduler-dev",},}, + {"cls": "local", "db": "dbname=softwareheritage-scheduler-dev",}, ) } # current configuration. To be set by the config loading mechanism CONFIG = {} # type: Dict[str, Any] +if TYPE_CHECKING: + from swh.scheduler.interface import SchedulerInterface + + def compute_nb_tasks_from(num_tasks): """Compute and returns the tuple, number of tasks without priority, number of tasks with priority. Args: num_tasks (int): Returns: tuple number of tasks without priority (int), number of tasks with priority (int) """ if not num_tasks: return None, None return (int((1 - PRIORITY_SLOT) * num_tasks), int(PRIORITY_SLOT * num_tasks)) -def get_scheduler(cls, args={}): +BACKEND_TYPES: Dict[str, str] = { + "local": ".backend.SchedulerBackend", + "remote": ".api.client.RemoteScheduler", +} + + +def get_scheduler(cls: str, **kwargs) -> SchedulerInterface: """ - Get a scheduler object of class `scheduler_class` with arguments - `scheduler_args`. + Get a scheduler object of class `cls` with arguments `**kwargs`. Args: - scheduler (dict): dictionary with keys: - - cls (str): scheduler's class, either 'local' or 'remote' - args (dict): dictionary with keys, default to empty. + cls: scheduler's class, either 'local' or 'remote' + kwargs: arguments to pass to the class' constructor Returns: an instance of swh.scheduler, either local or remote: local: swh.scheduler.backend.SchedulerBackend remote: swh.scheduler.api.client.RemoteScheduler Raises: ValueError if passed an unknown storage class. """ - if cls == "remote": - from .api.client import RemoteScheduler as SchedulerBackend - elif cls == "local": - from .backend import SchedulerBackend - else: - raise ValueError("Unknown swh.scheduler class `%s`" % cls) - - return SchedulerBackend(**args) + if "args" in kwargs: + warnings.warn( + 'Explicit "args" key is deprecated, use keys directly instead.', + DeprecationWarning, + ) + kwargs = kwargs["args"] + + class_path = BACKEND_TYPES.get(cls) + if class_path is None: + raise ValueError( + f"Unknown Scheduler class `{cls}`. " + f"Supported: {', '.join(BACKEND_TYPES)}" + ) + + (module_path, class_name) = class_path.rsplit(".", 1) + module = import_module(module_path, package=__package__) + BackendClass = getattr(module, class_name) + return BackendClass(**kwargs) diff --git a/swh/scheduler/api/server.py b/swh/scheduler/api/server.py index 3a6c173..f1654a6 100644 --- a/swh/scheduler/api/server.py +++ b/swh/scheduler/api/server.py @@ -1,154 +1,150 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import os from swh.core import config from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp from swh.core.api import encode_data_server as encode_data from swh.core.api import error_handler, negotiate from swh.scheduler import get_scheduler from swh.scheduler.exc import SchedulerException from swh.scheduler.interface import SchedulerInterface from .serializers import DECODERS, ENCODERS scheduler = None def get_global_scheduler(): global scheduler if not scheduler: scheduler = get_scheduler(**app.config["scheduler"]) return scheduler class SchedulerServerApp(RPCServerApp): extra_type_decoders = DECODERS extra_type_encoders = ENCODERS app = SchedulerServerApp( __name__, backend_class=SchedulerInterface, backend_factory=get_global_scheduler ) @app.errorhandler(SchedulerException) def argument_error_handler(exception): return error_handler(exception, encode_data, status_code=400) @app.errorhandler(Exception) def my_error_handler(exception): return error_handler(exception, encode_data) def has_no_empty_params(rule): return len(rule.defaults or ()) >= len(rule.arguments or ()) @app.route("/") def index(): return """ Software Heritage scheduler RPC server

You have reached the Software Heritage scheduler RPC server.
See its documentation and API for more information

""" @app.route("/site-map") @negotiate(MsgpackFormatter) @negotiate(JSONFormatter) def site_map(): links = [] for rule in app.url_map.iter_rules(): if has_no_empty_params(rule) and hasattr(SchedulerInterface, rule.endpoint): links.append( dict( rule=rule.rule, description=getattr(SchedulerInterface, rule.endpoint).__doc__, ) ) # links is now a list of url, endpoint tuples return links -def load_and_check_config(config_file, type="local"): +def load_and_check_config(config_path, type="local"): """Check the minimal configuration is set to run the api or raise an error explanation. Args: - config_file (str): Path to the configuration file to load + config_path (str): Path to the configuration file to load type (str): configuration type. For 'local' type, more checks are done. Raises: Error if the setup is not as expected Returns: configuration as a dict """ - if not config_file: + if not config_path: raise EnvironmentError("Configuration file must be defined") - if not os.path.exists(config_file): - raise FileNotFoundError("Configuration file %s does not exist" % (config_file,)) + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file {config_path} does not exist") - cfg = config.read(config_file) + cfg = config.read(config_path) vcfg = cfg.get("scheduler") if not vcfg: raise KeyError("Missing '%scheduler' configuration") if type == "local": cls = vcfg.get("cls") if cls != "local": raise ValueError( "The scheduler backend can only be started with a 'local' " "configuration" ) - args = vcfg.get("args") - if not args: - raise KeyError("Invalid configuration; missing 'args' config entry") - - db = args.get("db") + db = vcfg.get("db") if not db: raise KeyError("Invalid configuration; missing 'db' config entry") return cfg api_cfg = None def make_app_from_configfile(): """Run the WSGI app from the webserver, loading the configuration from a configuration file. SWH_CONFIG_FILENAME environment variable defines the configuration path to load. """ global api_cfg if not api_cfg: - config_file = os.environ.get("SWH_CONFIG_FILENAME") - api_cfg = load_and_check_config(config_file) + config_path = os.environ.get("SWH_CONFIG_FILENAME") + api_cfg = load_and_check_config(config_path) app.config.update(api_cfg) handler = logging.StreamHandler() app.logger.addHandler(handler) return app if __name__ == "__main__": print('Please use the "swh-scheduler api-server" command') diff --git a/swh/scheduler/interface.py b/swh/scheduler/interface.py index 0496093..a2f8198 100644 --- a/swh/scheduler/interface.py +++ b/swh/scheduler/interface.py @@ -1,311 +1,314 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Any, Dict, Iterable, List, Optional from uuid import UUID +from typing_extensions import Protocol, runtime_checkable + from swh.core.api import remote_api_endpoint from swh.scheduler.model import ( ListedOrigin, ListedOriginPageToken, Lister, PaginatedListedOriginList, ) -class SchedulerInterface: +@runtime_checkable +class SchedulerInterface(Protocol): @remote_api_endpoint("task_type/create") def create_task_type(self, task_type): """Create a new task type ready for scheduling. Args: task_type (dict): a dictionary with the following keys: - type (str): an identifier for the task type - description (str): a human-readable description of what the task does - backend_name (str): the name of the task in the job-scheduling backend - default_interval (datetime.timedelta): the default interval between two task runs - min_interval (datetime.timedelta): the minimum interval between two task runs - max_interval (datetime.timedelta): the maximum interval between two task runs - backoff_factor (float): the factor by which the interval changes at each run - max_queue_length (int): the maximum length of the task queue for this task type """ ... @remote_api_endpoint("task_type/get") def get_task_type(self, task_type_name): """Retrieve the task type with id task_type_name""" ... @remote_api_endpoint("task_type/get_all") def get_task_types(self): """Retrieve all registered task types""" ... @remote_api_endpoint("task/create") def create_tasks(self, tasks, policy="recurring"): """Create new tasks. Args: tasks (list): each task is a dictionary with the following keys: - type (str): the task type - arguments (dict): the arguments for the task runner, keys: - args (list of str): arguments - kwargs (dict str -> str): keyword arguments - next_run (datetime.datetime): the next scheduled run for the task Returns: a list of created tasks. """ ... @remote_api_endpoint("task/set_status") def set_status_tasks(self, task_ids, status="disabled", next_run=None): """Set the tasks' status whose ids are listed. If given, also set the next_run date. """ ... @remote_api_endpoint("task/disable") def disable_tasks(self, task_ids): """Disable the tasks whose ids are listed.""" ... @remote_api_endpoint("task/search") def search_tasks( self, task_id=None, task_type=None, status=None, priority=None, policy=None, before=None, after=None, limit=None, ): """Search tasks from selected criterions""" ... @remote_api_endpoint("task/get") def get_tasks(self, task_ids): """Retrieve the info of tasks whose ids are listed.""" ... @remote_api_endpoint("task/peek_ready") def peek_ready_tasks( self, task_type, timestamp=None, num_tasks=None, num_tasks_priority=None, ): """Fetch the list of ready tasks Args: task_type (str): filtering task per their type timestamp (datetime.datetime): peek tasks that need to be executed before that timestamp num_tasks (int): only peek at num_tasks tasks (with no priority) num_tasks_priority (int): only peek at num_tasks_priority tasks (with priority) Returns: a list of tasks """ ... @remote_api_endpoint("task/grab_ready") def grab_ready_tasks( self, task_type, timestamp=None, num_tasks=None, num_tasks_priority=None, ): """Fetch the list of ready tasks, and mark them as scheduled Args: task_type (str): filtering task per their type timestamp (datetime.datetime): grab tasks that need to be executed before that timestamp num_tasks (int): only grab num_tasks tasks (with no priority) num_tasks_priority (int): only grab oneshot num_tasks tasks (with priorities) Returns: a list of tasks """ ... @remote_api_endpoint("task_run/schedule_one") def schedule_task_run(self, task_id, backend_id, metadata=None, timestamp=None): """Mark a given task as scheduled, adding a task_run entry in the database. Args: task_id (int): the identifier for the task being scheduled backend_id (str): the identifier of the job in the backend metadata (dict): metadata to add to the task_run entry timestamp (datetime.datetime): the instant the event occurred Returns: a fresh task_run entry """ ... @remote_api_endpoint("task_run/schedule") def mass_schedule_task_runs(self, task_runs): """Schedule a bunch of task runs. Args: task_runs (list): a list of dicts with keys: - task (int): the identifier for the task being scheduled - backend_id (str): the identifier of the job in the backend - metadata (dict): metadata to add to the task_run entry - scheduled (datetime.datetime): the instant the event occurred Returns: None """ ... @remote_api_endpoint("task_run/start") def start_task_run(self, backend_id, metadata=None, timestamp=None): """Mark a given task as started, updating the corresponding task_run entry in the database. Args: backend_id (str): the identifier of the job in the backend metadata (dict): metadata to add to the task_run entry timestamp (datetime.datetime): the instant the event occurred Returns: the updated task_run entry """ ... @remote_api_endpoint("task_run/end") def end_task_run( self, backend_id, status, metadata=None, timestamp=None, result=None, ): """Mark a given task as ended, updating the corresponding task_run entry in the database. Args: backend_id (str): the identifier of the job in the backend status (str): how the task ended; one of: 'eventful', 'uneventful', 'failed' metadata (dict): metadata to add to the task_run entry timestamp (datetime.datetime): the instant the event occurred Returns: the updated task_run entry """ ... @remote_api_endpoint("task/filter_for_archive") def filter_task_to_archive( self, after_ts: str, before_ts: str, limit: int = 10, page_token: Optional[str] = None, ) -> Dict[str, Any]: """Compute the tasks to archive within the datetime interval [after_ts, before_ts[. The method returns a paginated result. Returns: dict with the following keys: - **next_page_token**: opaque token to be used as `page_token` to retrieve the next page of result. If absent, there is no more pages to gather. - **tasks**: list of task dictionaries with the following keys: **id** (str): origin task id **started** (Optional[datetime]): started date **scheduled** (datetime): scheduled date **arguments** (json dict): task's arguments ... """ ... @remote_api_endpoint("task/delete_archived") def delete_archived_tasks(self, task_ids): """Delete archived tasks as much as possible. Only the task_ids whose complete associated task_run have been cleaned up will be. """ ... @remote_api_endpoint("task_run/get") def get_task_runs(self, task_ids, limit=None): """Search task run for a task id""" ... @remote_api_endpoint("lister/get_or_create") def get_or_create_lister( self, name: str, instance_name: Optional[str] = None ) -> Lister: """Retrieve information about the given instance of the lister from the database, or create the entry if it did not exist. """ ... @remote_api_endpoint("lister/update") def update_lister(self, lister: Lister) -> Lister: """Update the state for the given lister instance in the database. Returns: a new Lister object, with all fields updated from the database Raises: StaleData if the `updated` timestamp for the lister instance in database doesn't match the one passed by the user. """ ... @remote_api_endpoint("origins/record") def record_listed_origins( self, listed_origins: Iterable[ListedOrigin] ) -> List[ListedOrigin]: """Record a set of origins that a lister has listed. This performs an "upsert": origins with the same (lister_id, url, visit_type) values are updated with new values for extra_loader_arguments, last_update and last_seen. """ ... @remote_api_endpoint("origins/get") def get_listed_origins( self, lister_id: Optional[UUID] = None, url: Optional[str] = None, limit: int = 1000, page_token: Optional[ListedOriginPageToken] = None, ) -> PaginatedListedOriginList: """Get information on the listed origins matching either the `url` or `lister_id`, or both arguments. Use the `limit` and `page_token` arguments for continuation. The next page token, if any, is returned in the PaginatedListedOriginList object. """ ... @remote_api_endpoint("priority_ratios/get") def get_priority_ratios(self): ... diff --git a/swh/scheduler/pytest_plugin.py b/swh/scheduler/pytest_plugin.py index a7ae36c..24d7876 100644 --- a/swh/scheduler/pytest_plugin.py +++ b/swh/scheduler/pytest_plugin.py @@ -1,108 +1,108 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import timedelta import glob import os from celery.contrib.testing import worker from celery.contrib.testing.app import TestApp, setup_default_app import pkg_resources import pytest from swh.core.utils import numfile_sortkey as sortkey import swh.scheduler from swh.scheduler import get_scheduler SQL_DIR = os.path.join(os.path.dirname(swh.scheduler.__file__), "sql") DUMP_FILES = os.path.join(SQL_DIR, "*.sql") # celery tasks for testing purpose; tasks themselves should be # in swh/scheduler/tests/tasks.py TASK_NAMES = ["ping", "multiping", "add", "error", "echo"] @pytest.fixture def swh_scheduler_config(request, postgresql): scheduler_config = { "db": postgresql.dsn, } all_dump_files = sorted(glob.glob(DUMP_FILES), key=sortkey) cursor = postgresql.cursor() for fname in all_dump_files: with open(fname) as fobj: cursor.execute(fobj.read()) postgresql.commit() return scheduler_config @pytest.fixture def swh_scheduler(swh_scheduler_config): - scheduler = get_scheduler("local", swh_scheduler_config) + scheduler = get_scheduler("local", **swh_scheduler_config) for taskname in TASK_NAMES: scheduler.create_task_type( { "type": "swh-test-{}".format(taskname), "description": "The {} testing task".format(taskname), "backend_name": "swh.scheduler.tests.tasks.{}".format(taskname), "default_interval": timedelta(days=1), "min_interval": timedelta(hours=6), "max_interval": timedelta(days=12), } ) return scheduler # this alias is used to be able to easily instantiate a db-backed Scheduler # eg. for the RPC client/server test suite. swh_db_scheduler = swh_scheduler @pytest.fixture(scope="session") def swh_scheduler_celery_app(): """Set up a Celery app as swh.scheduler and swh worker tests would expect it""" test_app = TestApp( set_as_current=True, enable_logging=True, task_cls="swh.scheduler.task:SWHTask", config={ "accept_content": ["application/x-msgpack", "application/json"], "task_serializer": "msgpack", "result_serializer": "json", }, ) with setup_default_app(test_app, use_trap=False): from swh.scheduler.celery_backend import config config.app = test_app test_app.set_default() test_app.set_current() yield test_app @pytest.fixture(scope="session") def swh_scheduler_celery_includes(): """List of task modules that should be loaded by the swh_scheduler_celery_worker on startup.""" task_modules = ["swh.scheduler.tests.tasks"] for entrypoint in pkg_resources.iter_entry_points("swh.workers"): task_modules.extend(entrypoint.load()().get("task_modules", [])) return task_modules @pytest.fixture(scope="session") def swh_scheduler_celery_worker( swh_scheduler_celery_app, swh_scheduler_celery_includes, ): """Spawn a worker""" for module in swh_scheduler_celery_includes: swh_scheduler_celery_app.loader.import_task_module(module) with worker.start_worker(swh_scheduler_celery_app, pool="solo") as w: yield w diff --git a/swh/scheduler/tests/es/conftest.py b/swh/scheduler/tests/es/conftest.py index 6b3028d..389dfe8 100644 --- a/swh/scheduler/tests/es/conftest.py +++ b/swh/scheduler/tests/es/conftest.py @@ -1,48 +1,48 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest import yaml from swh.scheduler import get_scheduler @pytest.fixture def swh_sched_config(swh_scheduler_config): return { - "scheduler": {"cls": "local", "args": swh_scheduler_config,}, + "scheduler": {"cls": "local", **swh_scheduler_config,}, "elasticsearch": { "cls": "memory", "args": {"index_name_prefix": "swh-tasks",}, }, } @pytest.fixture def swh_sched_config_file(swh_sched_config, monkeypatch, tmp_path): conffile = str(tmp_path / "elastic.yml") with open(conffile, "w") as f: f.write(yaml.dump(swh_sched_config)) monkeypatch.setenv("SWH_CONFIG_FILENAME", conffile) return conffile @pytest.fixture def swh_sched(swh_sched_config): return get_scheduler(**swh_sched_config["scheduler"]) @pytest.fixture def swh_elasticsearch_backend(swh_sched_config): from swh.scheduler.backend_es import ElasticSearchBackend backend = ElasticSearchBackend(**swh_sched_config) backend.initialize() return backend @pytest.fixture def swh_elasticsearch_memory(swh_elasticsearch_backend): return swh_elasticsearch_backend.storage diff --git a/swh/scheduler/tests/test_cli_task_type.py b/swh/scheduler/tests/test_cli_task_type.py index adf2ebe..4b2f6cf 100644 --- a/swh/scheduler/tests/test_cli_task_type.py +++ b/swh/scheduler/tests/test_cli_task_type.py @@ -1,127 +1,127 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import traceback from click.testing import CliRunner import pkg_resources import pytest import yaml from swh.scheduler import get_scheduler from swh.scheduler.cli import cli FAKE_MODULE_ENTRY_POINTS = { "lister.gnu=swh.lister.gnu:register", "lister.pypi=swh.lister.pypi:register", } @pytest.fixture def mock_pkg_resources(monkeypatch): """Monkey patch swh.scheduler's mock_pkg_resources.iter_entry_point call """ def fake_iter_entry_points(*args, **kwargs): """Substitute fake function to return a fixed set of entrypoints """ from pkg_resources import Distribution, EntryPoint d = Distribution() return [EntryPoint.parse(entry, dist=d) for entry in FAKE_MODULE_ENTRY_POINTS] original_method = pkg_resources.iter_entry_points monkeypatch.setattr(pkg_resources, "iter_entry_points", fake_iter_entry_points) yield # reset monkeypatch: is that needed? monkeypatch.setattr(pkg_resources, "iter_entry_points", original_method) @pytest.fixture def local_sched_config(swh_scheduler_config): """Expose the local scheduler configuration """ - return {"scheduler": {"cls": "local", "args": swh_scheduler_config}} + return {"scheduler": {"cls": "local", **swh_scheduler_config}} @pytest.fixture def local_sched_configfile(local_sched_config, tmp_path): """Write in temporary location the local scheduler configuration """ configfile = tmp_path / "config.yml" configfile.write_text(yaml.dump(local_sched_config)) return configfile.as_posix() def test_register_ttypes_all( mock_pkg_resources, local_sched_config, local_sched_configfile ): """Registering all task types""" for command in [ ["--config-file", local_sched_configfile, "task-type", "register"], ["--config-file", local_sched_configfile, "task-type", "register", "-p", "all"], [ "--config-file", local_sched_configfile, "task-type", "register", "-p", "lister.gnu", "-p", "lister.pypi", ], ]: result = CliRunner().invoke(cli, command) assert result.exit_code == 0, traceback.print_exception(*result.exc_info) scheduler = get_scheduler(**local_sched_config["scheduler"]) all_tasks = [ "list-gnu-full", "list-pypi", ] for task in all_tasks: task_type_desc = scheduler.get_task_type(task) assert task_type_desc assert task_type_desc["type"] == task assert task_type_desc["backoff_factor"] == 1 def test_register_ttypes_filter( mock_pkg_resources, local_sched_config, local_sched_configfile ): """Filtering on one worker should only register its associated task type """ result = CliRunner().invoke( cli, [ "--config-file", local_sched_configfile, "task-type", "register", "--plugins", "lister.gnu", ], ) assert result.exit_code == 0, traceback.print_exception(*result.exc_info) scheduler = get_scheduler(**local_sched_config["scheduler"]) all_tasks = [ "list-gnu-full", ] for task in all_tasks: task_type_desc = scheduler.get_task_type(task) assert task_type_desc assert task_type_desc["type"] == task assert task_type_desc["backoff_factor"] == 1 diff --git a/swh/scheduler/tests/test_init.py b/swh/scheduler/tests/test_init.py new file mode 100644 index 0000000..9a97548 --- /dev/null +++ b/swh/scheduler/tests/test_init.py @@ -0,0 +1,77 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import inspect + +import pytest + +from swh.scheduler import get_scheduler +from swh.scheduler.api.client import RemoteScheduler +from swh.scheduler.backend import SchedulerBackend +from swh.scheduler.interface import SchedulerInterface + +SERVER_IMPLEMENTATIONS = [ + ("remote", RemoteScheduler, {"url": "localhost"}), + ("local", SchedulerBackend, {"db": "something"}), +] + + +@pytest.fixture +def mock_psycopg2(mocker): + mocker.patch("swh.scheduler.backend.psycopg2.pool") + + +def test_init_get_scheduler_failure(): + with pytest.raises(ValueError, match="Unknown Scheduler class"): + get_scheduler("unknown-scheduler-storage") + + +@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS) +def test_init_get_scheduler(class_name, expected_class, kwargs, mock_psycopg2): + concrete_scheduler = get_scheduler(class_name, **kwargs) + assert isinstance(concrete_scheduler, expected_class) + assert isinstance(concrete_scheduler, SchedulerInterface) + + +@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS) +def test_init_get_scheduler_deprecation_warning( + class_name, expected_class, kwargs, mock_psycopg2 +): + with pytest.warns(DeprecationWarning): + concrete_scheduler = get_scheduler(class_name, args=kwargs) + assert isinstance(concrete_scheduler, expected_class) + + +def test_types(swh_scheduler) -> None: + """Checks all methods of SchedulerInterface are implemented by this + backend, and that they have the same signature.""" + # Create an instance of the protocol (which cannot be instantiated + # directly, so this creates a subclass, then instantiates it) + interface = type("_", (SchedulerInterface,), {})() + + missing_methods = [] + + for meth_name in dir(interface): + if meth_name.startswith("_"): + continue + interface_meth = getattr(interface, meth_name) + try: + concrete_meth = getattr(swh_scheduler, meth_name) + except AttributeError: + missing_methods.append(meth_name) + continue + + expected_signature = inspect.signature(interface_meth) + actual_signature = inspect.signature(concrete_meth) + + assert expected_signature == actual_signature, meth_name + + assert missing_methods == [] + + # If all the assertions above succeed, then this one should too. + # But there's no harm in double-checking. + # And we could replace the assertions above by this one, but unlike + # the assertions above, it doesn't explain what is missing. + assert isinstance(swh_scheduler, SchedulerInterface) diff --git a/swh/scheduler/tests/test_server.py b/swh/scheduler/tests/test_server.py index 989c91c..b5e1166 100644 --- a/swh/scheduler/tests/test_server.py +++ b/swh/scheduler/tests/test_server.py @@ -1,105 +1,90 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import copy - import pytest import yaml from swh.scheduler.api.server import load_and_check_config def prepare_config_file(tmpdir, content, name="config.yml"): """Prepare configuration file in `$tmpdir/name` with content `content`. Args: tmpdir (LocalPath): root directory content (str/dict): Content of the file either as string or as a dict. If a dict, converts the dict into a yaml string. name (str): configuration filename Returns path (str) of the configuration file prepared. """ config_path = tmpdir / name if isinstance(content, dict): # convert if needed content = yaml.dump(content) config_path.write_text(content, encoding="utf-8") # pytest on python3.5 does not support LocalPath manipulation, so # convert path to string return str(config_path) @pytest.mark.parametrize("scheduler_class", [None, ""]) def test_load_and_check_config_no_configuration(scheduler_class): """Inexistent configuration files raises""" with pytest.raises(EnvironmentError, match="Configuration file must be defined"): load_and_check_config(scheduler_class) def test_load_and_check_config_inexistent_fil(): """Inexistent config filepath should raise""" config_path = "/some/inexistent/config.yml" expected_error = f"Configuration file {config_path} does not exist" with pytest.raises(FileNotFoundError, match=expected_error): load_and_check_config(config_path) def test_load_and_check_config_wrong_configuration(tmpdir): """Wrong configuration raises""" config_path = prepare_config_file(tmpdir, "something: useless") with pytest.raises(KeyError, match="Missing '%scheduler' configuration"): load_and_check_config(config_path) def test_load_and_check_config_remote_config_local_type_raise(tmpdir): """Configuration without 'local' storage is rejected""" - config = {"scheduler": {"cls": "remote", "args": {}}} + config = {"scheduler": {"cls": "remote"}} config_path = prepare_config_file(tmpdir, config) expected_error = ( "The scheduler backend can only be started with a 'local'" " configuration" ) with pytest.raises(ValueError, match=expected_error): load_and_check_config(config_path, type="local") def test_load_and_check_config_local_incomplete_configuration(tmpdir): """Incomplete 'local' configuration should raise""" - config = { - "scheduler": { - "cls": "local", - "args": {"db": "database", "something": "needed-for-test",}, - } - } - - for key in ["db", "args"]: - c = copy.deepcopy(config) - if key == "db": - source = c["scheduler"]["args"] - else: - source = c["scheduler"] - source.pop(key) - config_path = prepare_config_file(tmpdir, c) - expected_error = f"Invalid configuration; missing '{key}' config entry" - with pytest.raises(KeyError, match=expected_error): - load_and_check_config(config_path) + config = {"scheduler": {"cls": "local", "something": "needed-for-test",}} + + config_path = prepare_config_file(tmpdir, config) + expected_error = "Invalid configuration; missing 'db' config entry" + with pytest.raises(KeyError, match=expected_error): + load_and_check_config(config_path) def test_load_and_check_config_local_config_fine(tmpdir): """Local configuration is fine""" - config = {"scheduler": {"cls": "local", "args": {"db": "db",}}} + config = {"scheduler": {"cls": "local", "db": "db",}} config_path = prepare_config_file(tmpdir, config) cfg = load_and_check_config(config_path, type="local") assert cfg == config def test_load_and_check_config_remote_config_fine(tmpdir): """Remote configuration is fine""" - config = {"scheduler": {"cls": "remote", "args": {}}} + config = {"scheduler": {"cls": "remote"}} config_path = prepare_config_file(tmpdir, config) cfg = load_and_check_config(config_path, type="any") - assert cfg == config