diff --git a/swh/lister/cli.py b/swh/lister/cli.py --- a/swh/lister/cli.py +++ b/swh/lister/cli.py @@ -194,7 +194,8 @@ config = deepcopy(ctx.obj['config']) if options: - config.update(parse_options(options)[1]) + _, kw = parse_options(options) + config.update(kw) config['priority'] = priority config['policy'] = 'oneshot' diff --git a/swh/lister/debian/__init__.py b/swh/lister/debian/__init__.py --- a/swh/lister/debian/__init__.py +++ b/swh/lister/debian/__init__.py @@ -3,11 +3,11 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Any, List, Mapping, Optional +from typing import Any, List, Mapping def debian_init(db_engine, lister=None, - override_conf: Optional[Mapping[str, Any]] = None, + override_conf: Mapping[str, Any] = {}, distributions: List[str] = ['stretch', 'buster'], area_names: List[str] = ['main', 'contrib', 'non-free']): """Initialize the debian data model. @@ -15,26 +15,27 @@ Args: db_engine: SQLAlchemy manipulation database object lister: Debian lister instance. None by default. - override_conf: Override conf to pass to instantiate a lister. - None by default + override_conf: Override conf to pass to instantiate a lister distributions: Default distribution to build """ + distribution_name = 'Debian' from swh.storage.schemata.distribution import ( Distribution, Area) if lister is None: from .lister import DebianLister - lister = DebianLister(override_config=override_conf) + lister = DebianLister(distribution=distribution_name, + override_config=override_conf) if not lister.db_session\ .query(Distribution)\ - .filter(Distribution.name == 'Debian')\ + .filter(Distribution.name == distribution_name)\ .one_or_none(): d = Distribution( - name='Debian', + name=distribution_name, type='deb', mirror_uri='http://deb.debian.org/debian/') lister.db_session.add(d) diff --git a/swh/lister/debian/lister.py b/swh/lister/debian/lister.py --- a/swh/lister/debian/lister.py +++ b/swh/lister/debian/lister.py @@ -13,6 +13,7 @@ from debian.deb822 import Sources from sqlalchemy.orm import joinedload, load_only from sqlalchemy.schema import CreateTable, DropTable +from typing import Mapping from swh.lister.debian.models import ( AreaSnapshot, Distribution, DistributionSnapshot, Package, @@ -38,9 +39,24 @@ LISTER_NAME = 'debian' instance = 'debian' - def __init__(self, override_config=None): + def __init__(self, distribution: str = 'Debian', + date: datetime.datetime = None, + override_config: Mapping = {}): + """Initialize the debian lister for a given distribution at a given + date. + + Args: + distribution: name of the distribution (e.g. "Debian") + date: date the snapshot is taken (defaults to + now) + override_config: Override configuration + + """ ListerHttpTransport.__init__(self, url="notused") ListerBase.__init__(self, override_config=override_config) + self.distribution = override_config.get('distribution', distribution) + self.date = override_config.get('date', date) or datetime.datetime.now( + tz=datetime.timezone.utc) def transport_request(self, identifier): """Subvert ListerHttpTransport.transport_request, to try several @@ -189,29 +205,25 @@ return self.scheduler.create_tasks(tasks) - def run(self, distribution='Debian', date=None): + def run(self): """Run the lister for a given (distribution, area) tuple. - Args: - distribution (str): name of the distribution (e.g. "Debian") - date (datetime.datetime): date the snapshot is taken (defaults to - now) """ distribution = self.db_session\ .query(Distribution)\ .options(joinedload(Distribution.areas))\ - .filter(Distribution.name == distribution)\ + .filter(Distribution.name == self.distribution)\ .one_or_none() if not distribution: raise ValueError("Distribution %s is not registered" % - distribution) + self.distribution) if not distribution.type == 'deb': raise ValueError("Distribution %s is not a Debian derivative" % distribution) - date = date or datetime.datetime.now(tz=datetime.timezone.utc) + date = self.date logger.debug('Creating snapshot for distribution %s on date %s' % (distribution, date)) diff --git a/swh/lister/debian/tasks.py b/swh/lister/debian/tasks.py --- a/swh/lister/debian/tasks.py +++ b/swh/lister/debian/tasks.py @@ -10,7 +10,7 @@ @shared_task(name=__name__ + '.DebianListerTask') def list_debian_distribution(distribution, **lister_args): '''List a Debian distribution''' - DebianLister(**lister_args).run(distribution) + DebianLister(distribution=distribution, **lister_args).run() @shared_task(name=__name__ + '.ping') diff --git a/swh/lister/debian/tests/test_lister.py b/swh/lister/debian/tests/test_lister.py --- a/swh/lister/debian/tests/test_lister.py +++ b/swh/lister/debian/tests/test_lister.py @@ -14,7 +14,7 @@ """ # Run the lister - lister_debian.run(distribution="Debian") + lister_debian.run() r = lister_debian.scheduler.search_tasks(task_type='load-deb-package') assert len(r) == 151 diff --git a/swh/lister/debian/tests/test_tasks.py b/swh/lister/debian/tests/test_tasks.py --- a/swh/lister/debian/tests/test_tasks.py +++ b/swh/lister/debian/tests/test_tasks.py @@ -1,3 +1,8 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from unittest.mock import patch @@ -22,5 +27,5 @@ res.wait() assert res.successful() - lister.assert_called_once_with() - lister.run.assert_called_once_with('stretch') + lister.assert_called_once_with(distribution='stretch') + lister.run.assert_called_once_with()