diff --git a/swh/indexer/orchestrator.py b/swh/indexer/orchestrator.py --- a/swh/indexer/orchestrator.py +++ b/swh/indexer/orchestrator.py @@ -93,6 +93,7 @@ self.tasks = tasks def run(self, ids): + all_results = [] for name, (idx_class, filtering, batch_size) in self.indexers.items(): if filtering: policy_update = 'ignore-dups' @@ -111,10 +112,12 @@ policy_update=policy_update) celery_tasks.append(celery_task) - self._run_tasks(celery_tasks) + all_results.append(self._run_tasks(celery_tasks)) + + return all_results def _run_tasks(self, celery_tasks): - group(celery_tasks).delay() + return group(celery_tasks).delay() class OrchestratorAllContentsIndexer(BaseOrchestratorIndexer): diff --git a/swh/indexer/tests/__init__.py b/swh/indexer/tests/__init__.py --- a/swh/indexer/tests/__init__.py +++ b/swh/indexer/tests/__init__.py @@ -1,5 +1,21 @@ from os import path import swh.indexer +from celery import shared_task +from celery.contrib.testing.worker import _start_worker_thread +from celery import current_app + +__all__ = ['start_worker_thread'] SQL_DIR = path.join(path.dirname(swh.indexer.__file__), 'sql') + + +def start_worker_thread(): + return _start_worker_thread(current_app) + + +# Needed to pass an assertion, see +# https://github.com/celery/celery/pull/5111 +@shared_task(name='celery.ping') +def ping(): + return 'pong' diff --git a/swh/indexer/tests/test_orchestrator.py b/swh/indexer/tests/test_orchestrator.py --- a/swh/indexer/tests/test_orchestrator.py +++ b/swh/indexer/tests/test_orchestrator.py @@ -6,12 +6,15 @@ import unittest from swh.indexer.orchestrator import BaseOrchestratorIndexer -from swh.indexer.indexer import RevisionIndexer -from swh.indexer.tests.test_utils import MockIndexerStorage +from swh.indexer.indexer import BaseIndexer +from swh.indexer.tests.test_utils import MockIndexerStorage, MockStorage from swh.scheduler.task import Task +from swh.scheduler.tests.celery_testing import CeleryTestFixture +from . import start_worker_thread -class BaseTestIndexer(RevisionIndexer): + +class BaseTestIndexer(BaseIndexer): ADDITIONAL_CONFIG = { 'tools': ('dict', { 'name': 'foo', @@ -22,16 +25,20 @@ def prepare(self): self.idx_storage = MockIndexerStorage() + self.storage = MockStorage() def check(self): pass def filter(self, ids): - self.filtered = ids + self.filtered.append(ids) return ids + def run(self, ids, policy_update): + return self.index(ids) + def index(self, ids): - self.indexed = ids + self.indexed.append(ids) return [id_ + '_indexed_by_' + self.__class__.__name__ for id_ in ids] @@ -40,30 +47,42 @@ class Indexer1(BaseTestIndexer): + filtered = [] + indexed = [] + def filter(self, ids): return super().filter([id_ for id_ in ids if '1' in id_]) class Indexer2(BaseTestIndexer): + filtered = [] + indexed = [] + def filter(self, ids): return super().filter([id_ for id_ in ids if '2' in id_]) class Indexer3(BaseTestIndexer): + filtered = [] + indexed = [] + def filter(self, ids): return super().filter([id_ for id_ in ids if '3' in id_]) class Indexer1Task(Task): - pass + def run(self, *args, **kwargs): + return Indexer1().run(*args, **kwargs) class Indexer2Task(Task): - pass + def run(self, *args, **kwargs): + return Indexer2().run(*args, **kwargs) class Indexer3Task(Task): - pass + def run(self, *args, **kwargs): + return Indexer3().run(*args, **kwargs) class TestOrchestrator12(BaseOrchestratorIndexer): @@ -98,15 +117,35 @@ } self.prepare_tasks() + +class MockedTestOrchestrator12(TestOrchestrator12): def _run_tasks(self, celery_tasks): self.running_tasks.extend(celery_tasks) -class OrchestratorTest(unittest.TestCase): +class OrchestratorTest(CeleryTestFixture, unittest.TestCase): + def test_orchestrator_filter(self): + with start_worker_thread(): + o = TestOrchestrator12() + o.prepare() + promises = o.run(['id12', 'id2']) + results = [] + for promise in promises: + results.append(promise.get(timeout=10)) + self.assertCountEqual( + results, + [[['id12_indexed_by_Indexer1']], + [['id12_indexed_by_Indexer2', + 'id2_indexed_by_Indexer2']]]) + self.assertEqual(Indexer2.indexed, [['id12', 'id2']]) + self.assertEqual(Indexer1.indexed, [['id12']]) + + +class MockedOrchestratorTest(unittest.TestCase): maxDiff = None - def test_orchestrator_filter(self): - o = TestOrchestrator12() + def test_mocked_orchestrator_filter(self): + o = MockedTestOrchestrator12() o.prepare() o.run(['id12', 'id2']) self.assertCountEqual(o.running_tasks, [ @@ -128,8 +167,8 @@ 'task': 'swh.indexer.tests.test_orchestrator.Indexer2Task'}, ]) - def test_orchestrator_batch(self): - o = TestOrchestrator12() + def test_mocked_orchestrator_batch(self): + o = MockedTestOrchestrator12() o.prepare() o.run(['id12', 'id2a', 'id2b', 'id2c']) self.assertCountEqual(o.running_tasks, [