diff --git a/swh/indexer/orchestrator.py b/swh/indexer/orchestrator.py
--- a/swh/indexer/orchestrator.py
+++ b/swh/indexer/orchestrator.py
@@ -10,7 +10,6 @@
 from swh.core.config import SWHConfig
 from swh.core.utils import grouper
 from swh.scheduler import utils
-from . import TASK_NAMES, INDEXER_CLASSES
 
 
 def get_class(clazz):
@@ -57,6 +56,9 @@
     """
     CONFIG_BASE_FILENAME = 'indexer/orchestrator'
 
+    # Overridable in child classes.
+    from . import TASK_NAMES, INDEXER_CLASSES
+
     DEFAULT_CONFIG = {
         'indexers': ('dict', {
             'mimetype': {
@@ -66,25 +68,39 @@
         }),
     }
 
     def __init__(self):
         super().__init__()
+        # Delegate to prepare() so that child classes can override how
+        # the configuration is obtained (e.g. tests inject it directly).
+        self.prepare()
+
+    def prepare(self):
+        """Load the configuration, then set up the indexers and tasks."""
         self.config = self.parse_config_file()
-        indexer_names = list(self.config['indexers'].keys())
-        random.shuffle(indexer_names)
+        self.prepare_tasks()
 
+    def prepare_tasks(self):
+        """Build self.indexers and self.tasks from self.config['indexers'].
+
+        Raises:
+            ValueError: if a configured indexer is not a key of TASK_NAMES.
+
+        """
+        indexer_names = list(self.config['indexers'])
+        random.shuffle(indexer_names)
         indexers = {}
         tasks = {}
         for name in indexer_names:
-            if name not in TASK_NAMES:
+            if name not in self.TASK_NAMES:
                 raise ValueError('%s must be one of %s' % (
-                    name, TASK_NAMES.keys()))
+                    name, ', '.join(self.TASK_NAMES)))
 
             opts = self.config['indexers'][name]
             indexers[name] = (
-                INDEXER_CLASSES[name],
+                self.INDEXER_CLASSES[name],
                 opts['check_presence'],
                 opts['batch_size'])
-            tasks[name] = utils.get_task(TASK_NAMES[name])
+            tasks[name] = utils.get_task(self.TASK_NAMES[name])
 
         self.indexers = indexers
         self.tasks = tasks
@@ -108,7 +124,11 @@
                     policy_update=policy_update)
                 celery_tasks.append(celery_task)
 
-            group(celery_tasks).delay()
+            self._run_tasks(celery_tasks)
+
+    def _run_tasks(self, celery_tasks):
+        """Dispatch the batched celery signatures as a single group."""
+        group(celery_tasks).delay()
 
 
 class OrchestratorAllContentsIndexer(BaseOrchestratorIndexer):
diff --git a/swh/indexer/tests/test_orchestrator.py b/swh/indexer/tests/test_orchestrator.py
new file mode 100644
--- /dev/null
+++ b/swh/indexer/tests/test_orchestrator.py
@@ -0,0 +1,177 @@
+# Copyright (C) 2018 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import unittest
+from nose.tools import istest
+
+from swh.indexer.orchestrator import BaseOrchestratorIndexer
+from swh.indexer.indexer import RevisionIndexer
+from swh.indexer.tests.test_utils import MockIndexerStorage
+from swh.scheduler.task import Task
+
+
+class BaseTestIndexer(RevisionIndexer):
+    ADDITIONAL_CONFIG = {
+        'tools': ('dict', {
+            'name': 'foo',
+            'version': 'bar',
+            'configuration': {}
+        }),
+    }
+
+    def prepare(self):
+        self.idx_storage = MockIndexerStorage()
+
+    def check(self):
+        pass
+
+    def filter(self, ids):
+        self.filtered = ids
+        return ids
+
+    def index(self, ids):
+        self.indexed = ids
+        return [id_ + '_indexed_by_' + self.__class__.__name__
+                for id_ in ids]
+
+    def persist_index_computations(self, result, policy_update):
+        self.persisted = result
+
+
+class Indexer1(BaseTestIndexer):
+    def filter(self, ids):
+        return super().filter([id_ for id_ in ids if '1' in id_])
+
+
+class Indexer2(BaseTestIndexer):
+    def filter(self, ids):
+        return super().filter([id_ for id_ in ids if '2' in id_])
+
+
+class Indexer3(BaseTestIndexer):
+    def filter(self, ids):
+        return super().filter([id_ for id_ in ids if '3' in id_])
+
+
+class Indexer1Task(Task):
+    pass
+
+
+class Indexer2Task(Task):
+    pass
+
+
+class Indexer3Task(Task):
+    pass
+
+
+class TestOrchestrator12(BaseOrchestratorIndexer):
+    TASK_NAMES = {
+        'indexer1': 'swh.indexer.tests.test_orchestrator.Indexer1Task',
+        'indexer2': 'swh.indexer.tests.test_orchestrator.Indexer2Task',
+        'indexer3': 'swh.indexer.tests.test_orchestrator.Indexer3Task',
+    }
+
+    INDEXER_CLASSES = {
+        'indexer1': 'swh.indexer.tests.test_orchestrator.Indexer1',
+        'indexer2': 'swh.indexer.tests.test_orchestrator.Indexer2',
+        'indexer3': 'swh.indexer.tests.test_orchestrator.Indexer3',
+    }
+
+    def __init__(self):
+        super().__init__()
+        self.running_tasks = []
+
+    def prepare(self):
+        # Inject the configuration instead of reading it from a file.
+        self.config = {
+            'indexers': {
+                'indexer1': {
+                    'batch_size': 2,
+                    'check_presence': True,
+                },
+                'indexer2': {
+                    'batch_size': 2,
+                    'check_presence': True,
+                },
+            }
+        }
+        self.prepare_tasks()
+
+    def _run_tasks(self, celery_tasks):
+        # Record the tasks instead of sending them to celery.
+        self.running_tasks.extend(celery_tasks)
+
+
+class OrchestratorTest(unittest.TestCase):
+    maxDiff = None
+
+    @istest
+    def orchestrator_filter(self):
+        o = TestOrchestrator12()
+        o.run(['id12', 'id2'])
+        self.assertCountEqual(o.running_tasks, [
+            {'args': (),
+             'chord_size': None,
+             'immutable': False,
+             'kwargs': {'ids': ['id12'],
+                        'policy_update': 'ignore-dups'},
+             'options': {},
+             'subtask_type': None,
+             'task': 'swh.indexer.tests.test_orchestrator.Indexer1Task'},
+            {'args': (),
+             'chord_size': None,
+             'immutable': False,
+             'kwargs': {'ids': ['id12', 'id2'],
+                        'policy_update': 'ignore-dups'},
+             'options': {},
+             'subtask_type': None,
+             'task': 'swh.indexer.tests.test_orchestrator.Indexer2Task'},
+        ])
+
+    @istest
+    def orchestrator_batch(self):
+        o = TestOrchestrator12()
+        o.run(['id12', 'id2a', 'id2b', 'id2c'])
+        self.assertCountEqual(o.running_tasks, [
+            {'args': (),
+             'chord_size': None,
+             'immutable': False,
+             'kwargs': {'ids': ['id12'],
+                        'policy_update': 'ignore-dups'},
+             'options': {},
+             'subtask_type': None,
+             'task': 'swh.indexer.tests.test_orchestrator.Indexer1Task'},
+            {'args': (),
+             'chord_size': None,
+             'immutable': False,
+             'kwargs': {'ids': ['id12', 'id2a'],
+                        'policy_update': 'ignore-dups'},
+             'options': {},
+             'subtask_type': None,
+             'task': 'swh.indexer.tests.test_orchestrator.Indexer2Task'},
+            {'args': (),
+             'chord_size': None,
+             'immutable': False,
+             'kwargs': {'ids': ['id2b', 'id2c'],
+                        'policy_update': 'ignore-dups'},
+             'options': {},
+             'subtask_type': None,
+             'task': 'swh.indexer.tests.test_orchestrator.Indexer2Task'},
+        ])
+
+    @istest
+    def orchestrator_unknown_indexer(self):
+        o = TestOrchestrator12()
+        o.config = {
+            'indexers': {
+                'unknown': {
+                    'batch_size': 2,
+                    'check_presence': True,
+                },
+            }
+        }
+        with self.assertRaises(ValueError):
+            o.prepare_tasks()