diff --git a/swh/lister/bitbucket/lister.py b/swh/lister/bitbucket/lister.py index 0d7dcf3..45b573c 100644 --- a/swh/lister/bitbucket/lister.py +++ b/swh/lister/bitbucket/lister.py @@ -1,80 +1,79 @@ # Copyright (C) 2017-2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import iso8601 from datetime import datetime from urllib import parse from swh.lister.bitbucket.models import BitBucketModel from swh.lister.core.indexing_lister import IndexingHttpLister logger = logging.getLogger(__name__) class BitBucketLister(IndexingHttpLister): PATH_TEMPLATE = '/repositories?after=%s' MODEL = BitBucketModel LISTER_NAME = 'bitbucket' DEFAULT_URL = 'https://api.bitbucket.org/2.0' instance = 'bitbucket' default_min_bound = datetime.utcfromtimestamp(0) - def __init__(self, api_baseurl=None, override_config=None, per_page=100): - super().__init__( - api_baseurl=api_baseurl, override_config=override_config) + def __init__(self, url=None, override_config=None, per_page=100): + super().__init__(url=url, override_config=override_config) per_page = self.config.get('per_page', per_page) self.PATH_TEMPLATE = '%s&pagelen=%s' % ( self.PATH_TEMPLATE, per_page) def get_model_from_repo(self, repo): return { 'uid': repo['uuid'], 'indexable': iso8601.parse_date(repo['created_on']), 'name': repo['name'], 'full_name': repo['full_name'], 'html_url': repo['links']['html']['href'], 'origin_url': repo['links']['clone'][0]['href'], 'origin_type': repo['scm'], } def get_next_target_from_response(self, response): """This will read the 'next' link from the api response if any and return it as a datetime. Args: response (Response): requests' response from api call Returns: next date as a datetime """ body = response.json() next_ = body.get('next') if next_ is not None: next_ = parse.urlparse(next_) return iso8601.parse_date(parse.parse_qs(next_.query)['after'][0]) def transport_response_simplified(self, response): repos = response.json()['values'] return [self.get_model_from_repo(repo) for repo in repos] def request_uri(self, identifier): identifier = parse.quote(identifier.isoformat()) return super().request_uri(identifier or '1970-01-01') def is_within_bounds(self, inner, lower=None, upper=None): # values are expected to be datetimes if lower is None and upper is None: ret = True elif lower is None: ret = inner <= upper elif upper is None: ret = inner >= lower else: ret = lower <= inner <= upper return ret diff --git a/swh/lister/bitbucket/tasks.py b/swh/lister/bitbucket/tasks.py index f3f415e..b8fa316 100644 --- a/swh/lister/bitbucket/tasks.py +++ b/swh/lister/bitbucket/tasks.py @@ -1,58 +1,54 @@ # Copyright (C) 2017-2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import random from celery import group from swh.scheduler.celery_backend.config import app from .lister import BitBucketLister GROUP_SPLIT = 10000 -def new_lister(api_baseurl='https://api.bitbucket.org/2.0', per_page=100): - return BitBucketLister(api_baseurl=api_baseurl, per_page=per_page) - - @app.task(name=__name__ + '.IncrementalBitBucketLister') def list_bitbucket_incremental(**lister_args): '''Incremental update of the BitBucket forge''' - lister = new_lister(**lister_args) + lister = BitBucketLister(**lister_args) lister.run(min_bound=lister.db_last_index(), max_bound=None) @app.task(name=__name__ + '.RangeBitBucketLister') def _range_bitbucket_lister(start, end, **lister_args): - lister = new_lister(**lister_args) + lister = BitBucketLister(**lister_args) lister.run(min_bound=start, max_bound=end) @app.task(name=__name__ + '.FullBitBucketRelister', bind=True) def list_bitbucket_full(self, split=None, **lister_args): """Full update of the BitBucket forge It's not to be called for an initial listing. """ - lister = new_lister(**lister_args) + lister = BitBucketLister(**lister_args) ranges = lister.db_partition_indices(split or GROUP_SPLIT) if not ranges: self.log.info('Nothing to list') return random.shuffle(ranges) promise = group(_range_bitbucket_lister.s(minv, maxv, **lister_args) for minv, maxv in ranges)() self.log.debug('%s OK (spawned %s subtasks)', (self.name, len(ranges))) try: promise.save() # so that we can restore the GroupResult in tests except (NotImplementedError, AttributeError): self.log.info('Unable to call save_group with current result backend.') return promise.id @app.task(name=__name__ + '.ping') def _ping(): return 'OK' diff --git a/swh/lister/bitbucket/tests/test_tasks.py b/swh/lister/bitbucket/tests/test_tasks.py index 1e02b6f..bd881ab 100644 --- a/swh/lister/bitbucket/tests/test_tasks.py +++ b/swh/lister/bitbucket/tests/test_tasks.py @@ -1,92 +1,89 @@ from time import sleep from celery.result import GroupResult from unittest.mock import patch def test_ping(swh_app, celery_session_worker): res = swh_app.send_task( 'swh.lister.bitbucket.tasks.ping') assert res res.wait() assert res.successful() assert res.result == 'OK' @patch('swh.lister.bitbucket.tasks.BitBucketLister') def test_incremental(lister, swh_app, celery_session_worker): # setup the mocked BitbucketLister lister.return_value = lister lister.db_last_index.return_value = 42 lister.run.return_value = None res = swh_app.send_task( 'swh.lister.bitbucket.tasks.IncrementalBitBucketLister') assert res res.wait() assert res.successful() - lister.assert_called_once_with( - api_baseurl='https://api.bitbucket.org/2.0', per_page=100) + lister.assert_called_once_with() lister.db_last_index.assert_called_once_with() lister.run.assert_called_once_with(min_bound=42, max_bound=None) @patch('swh.lister.bitbucket.tasks.BitBucketLister') def test_range(lister, swh_app, celery_session_worker): # setup the mocked BitbucketLister lister.return_value = lister lister.run.return_value = None res = swh_app.send_task( 'swh.lister.bitbucket.tasks.RangeBitBucketLister', kwargs=dict(start=12, end=42)) assert res res.wait() assert res.successful() - lister.assert_called_once_with( - api_baseurl='https://api.bitbucket.org/2.0', per_page=100) + lister.assert_called_once_with() lister.db_last_index.assert_not_called() lister.run.assert_called_once_with(min_bound=12, max_bound=42) @patch('swh.lister.bitbucket.tasks.BitBucketLister') def test_relister(lister, swh_app, celery_session_worker): # setup the mocked BitbucketLister lister.return_value = lister lister.run.return_value = None lister.db_partition_indices.return_value = [ (i, i+9) for i in range(0, 50, 10)] res = swh_app.send_task( 'swh.lister.bitbucket.tasks.FullBitBucketRelister') assert res res.wait() assert res.successful() # retrieve the GroupResult for this task and wait for all the subtasks # to complete promise_id = res.result assert promise_id promise = GroupResult.restore(promise_id, app=swh_app) for i in range(5): if promise.ready(): break sleep(1) - lister.assert_called_with( - api_baseurl='https://api.bitbucket.org/2.0', per_page=100) + lister.assert_called_with() # one by the FullBitbucketRelister task # + 5 for the RangeBitbucketLister subtasks assert lister.call_count == 6 lister.db_last_index.assert_not_called() lister.db_partition_indices.assert_called_once_with(10000) # lister.run should have been called once per partition interval for i in range(5): assert (dict(min_bound=10*i, max_bound=10*i + 9),) \ in lister.run.call_args_list diff --git a/swh/lister/cli.py b/swh/lister/cli.py index 50bde74..9e1dad4 100644 --- a/swh/lister/cli.py +++ b/swh/lister/cli.py @@ -1,235 +1,235 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import logging import pkg_resources from copy import deepcopy from importlib import import_module import click from sqlalchemy import create_engine from swh.core.cli import CONTEXT_SETTINGS from swh.scheduler import get_scheduler from swh.scheduler.task import SWHTask from swh.lister.core.models import initialize logger = logging.getLogger(__name__) LISTERS = {entry_point.name.split('.', 1)[1]: entry_point for entry_point in pkg_resources.iter_entry_points('swh.workers') if entry_point.name.split('.', 1)[0] == 'lister'} SUPPORTED_LISTERS = list(LISTERS) # the key in this dict is the suffix used to match new task-type to be added. # For example for a task which function name is "list_gitlab_full', the default # value used when inserting a new task-type in the scheduler db will be the one # under the 'full' key below (because it matches xxx_full). DEFAULT_TASK_TYPE = { 'full': { # for tasks like 'list_xxx_full()' 'default_interval': '90 days', 'min_interval': '90 days', 'max_interval': '90 days', 'backoff_factor': 1 }, '*': { # value if not suffix matches 'default_interval': '1 day', 'min_interval': '1 day', 'max_interval': '1 day', 'backoff_factor': 1 }, } def get_lister(lister_name, db_url=None, **conf): """Instantiate a lister given its name. Args: lister_name (str): Lister's name conf (dict): Configuration dict (lister db cnx, policy, priority...) Returns: Tuple (instantiated lister, drop_tables function, init schema function, insert minimum data function) """ if lister_name not in LISTERS: raise ValueError( 'Invalid lister %s: only supported listers are %s' % (lister_name, SUPPORTED_LISTERS)) if db_url: conf['lister'] = {'cls': 'local', 'args': {'db': db_url}} - # To allow api_baseurl override per lister + registry_entry = LISTERS[lister_name].load()() lister_cls = registry_entry['lister'] lister = lister_cls(override_config=conf) return lister @click.group(name='lister', context_settings=CONTEXT_SETTINGS) @click.option('--config-file', '-C', default=None, type=click.Path(exists=True, dir_okay=False,), help="Configuration file.") @click.option('--db-url', '-d', default=None, help='SQLAlchemy DB URL; see ' '') # noqa @click.pass_context def lister(ctx, config_file, db_url): '''Software Heritage Lister tools.''' from swh.core import config ctx.ensure_object(dict) override_conf = {} if db_url: override_conf['lister'] = { 'cls': 'local', 'args': {'db': db_url} } if not config_file: config_file = os.environ.get('SWH_CONFIG_FILENAME') conf = config.read(config_file, override_conf) ctx.obj['config'] = conf ctx.obj['override_conf'] = override_conf @lister.command(name='db-init', context_settings=CONTEXT_SETTINGS) @click.option('--drop-tables', '-D', is_flag=True, default=False, help='Drop tables before creating the database schema') @click.pass_context def db_init(ctx, drop_tables): """Initialize the database model for given listers. """ cfg = ctx.obj['config'] lister_cfg = cfg['lister'] if lister_cfg['cls'] != 'local': click.echo('A local lister configuration is required') ctx.exit(1) db_url = lister_cfg['args']['db'] db_engine = create_engine(db_url) registry = {} for lister, entrypoint in LISTERS.items(): logger.info('Loading lister %s', lister) registry[lister] = entrypoint.load()() logger.info('Initializing database') initialize(db_engine, drop_tables) for lister, entrypoint in LISTERS.items(): registry_entry = registry[lister] init_hook = registry_entry.get('init') if callable(init_hook): logger.info('Calling init hook for %s', lister) init_hook(db_engine) @lister.command(name='register-task-types', context_settings=CONTEXT_SETTINGS) @click.option('--lister', '-l', 'listers', multiple=True, default=('all', ), show_default=True, help='Only registers task-types for these listers', type=click.Choice(['all'] + SUPPORTED_LISTERS)) @click.pass_context def register_task_types(ctx, listers): """Insert missing task-type entries in the scheduler According to declared tasks in each loaded lister plugin. """ cfg = ctx.obj['config'] scheduler = get_scheduler(**cfg['scheduler']) for lister, entrypoint in LISTERS.items(): if 'all' not in listers and lister not in listers: continue logger.info('Loading lister %s', lister) registry_entry = entrypoint.load()() for task_module in registry_entry['task_modules']: mod = import_module(task_module) for task_name in (x for x in dir(mod) if not x.startswith('_')): taskobj = getattr(mod, task_name) if isinstance(taskobj, SWHTask): task_type = task_name.replace('_', '-') task_cfg = registry_entry.get('task_types', {}).get( task_type, {}) ensure_task_type(task_type, taskobj, task_cfg, scheduler) def ensure_task_type(task_type, swhtask, task_config, scheduler): """Ensure a task-type is known by the scheduler Args: task_type (str): the type of the task to check/insert (correspond to the 'type' field in the db) swhtask (SWHTask): the SWHTask instance the task-type correspond to task_config (dict): a dict with specific/overloaded values for the task-type to be created scheduler: the scheduler object used to access the scheduler db """ for suffix, defaults in DEFAULT_TASK_TYPE.items(): if task_type.endswith('-' + suffix): task_type_dict = defaults.copy() break else: task_type_dict = DEFAULT_TASK_TYPE['*'].copy() task_type_dict['type'] = task_type task_type_dict['backend_name'] = swhtask.name if swhtask.__doc__: task_type_dict['description'] = swhtask.__doc__.splitlines()[0] task_type_dict.update(task_config) current_task_type = scheduler.get_task_type(task_type) if current_task_type: # check some stuff if current_task_type['backend_name'] != task_type_dict['backend_name']: logger.warning('Existing task type %s for lister %s has a ' 'different backend name than current ' 'code version provides (%s vs. %s)', task_type, lister, current_task_type['backend_name'], task_type_dict['backend_name'], ) else: logger.info('Create task type %s in scheduler', task_type) logger.debug(' %s', task_type_dict) scheduler.create_task_type(task_type_dict) @lister.command(name='run', context_settings=CONTEXT_SETTINGS, help='Trigger a full listing run for a particular forge ' 'instance. The output of this listing results in ' '"oneshot" tasks in the scheduler db with a priority ' 'defined by the user') @click.option('--lister', '-l', help='Lister to run', type=click.Choice(SUPPORTED_LISTERS)) @click.option('--priority', '-p', default='high', type=click.Choice(['high', 'medium', 'low']), help='Task priority for the listed repositories to ingest') @click.argument('options', nargs=-1) @click.pass_context def run(ctx, lister, priority, options): from swh.scheduler.cli.utils import parse_options config = deepcopy(ctx.obj['config']) if options: config.update(parse_options(options)[1]) config['priority'] = priority config['policy'] = 'oneshot' get_lister(lister, **config).run() if __name__ == '__main__': lister() diff --git a/swh/lister/core/indexing_lister.py b/swh/lister/core/indexing_lister.py index 834d5f9..7d4a38a 100644 --- a/swh/lister/core/indexing_lister.py +++ b/swh/lister/core/indexing_lister.py @@ -1,249 +1,249 @@ # Copyright (C) 2015-2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import logging from itertools import count import dateutil from sqlalchemy import func from .lister_transports import ListerHttpTransport from .lister_base import ListerBase logger = logging.getLogger(__name__) class IndexingLister(ListerBase): """Lister* intermediate class for any service that follows the pattern: - The service must report at least one stable unique identifier, known herein as the UID value, for every listed repository. - If the service splits the list of repositories into sublists, it must report at least one stable and sorted index identifier for every listed repository, known herein as the indexable value, which can be used as part of the service endpoint query to request a sublist beginning from that index. This might be the UID if the UID is monotonic. - Client sends a request to list repositories starting from a given index. - Client receives structured (json/xml/etc) response with information about a sequential series of repositories starting from that index and, if necessary/available, some indication of the URL or index for fetching the next series of repository data. See :class:`swh.lister.core.lister_base.ListerBase` for more details. This class cannot be instantiated. To create a new Lister for a source code listing service that follows the model described above, you must subclass this class and provide the required overrides in addition to any unmet implementation/override requirements of this class's base. (see parent class and member docstrings for details) Required Overrides:: def get_next_target_from_response """ flush_packet_db = 20 """Number of iterations in-between write flushes of lister repositories to db (see fn:`run`). """ default_min_bound = '' """Default initialization value for the minimum boundary index to use when undefined (see fn:`run`). """ @abc.abstractmethod def get_next_target_from_response(self, response): """Find the next server endpoint identifier given the entire response. Implementation of this method depends on the server API spec and the shape of the network response object returned by the transport_request method. Args: response (transport response): response page from the server Returns: index of next page, possibly extracted from a next href url """ pass # You probably don't need to override anything below this line. def filter_before_inject(self, models_list): """Overrides ListerBase.filter_before_inject Bounds query results by this Lister's set max_index. """ models_list = [ m for m in models_list if self.is_within_bounds(m['indexable'], None, self.max_index) ] return models_list def db_query_range(self, start, end): """Look in the db for a range of repositories with indexable values in the range [start, end] Args: start (model indexable type): start of desired indexable range end (model indexable type): end of desired indexable range Returns: a list of sqlalchemy.ext.declarative.declarative_base objects with indexable values within the given range """ retlist = self.db_session.query(self.MODEL) if start is not None: retlist = retlist.filter(self.MODEL.indexable >= start) if end is not None: retlist = retlist.filter(self.MODEL.indexable <= end) return retlist def db_partition_indices(self, partition_size): """Describe an index-space compartmentalization of the db table in equal sized chunks. This is used to describe min&max bounds for parallelizing fetch tasks. Args: partition_size (int): desired size to make each partition Returns: a list of tuples (begin, end) of indexable value that declare approximately equal-sized ranges of existing repos """ n = max(self.db_num_entries(), 10) partition_size = min(partition_size, n) n_partitions = n // partition_size min_index = self.db_first_index() max_index = self.db_last_index() if not min_index or not max_index: # Nothing to list return [] if isinstance(min_index, str): def format_bound(bound): return bound.isoformat() min_index = dateutil.parser.parse(min_index) max_index = dateutil.parser.parse(max_index) else: def format_bound(bound): return bound partition_width = (max_index - min_index) / n_partitions partitions = [ [ format_bound(min_index + i * partition_width), format_bound(min_index + (i+1) * partition_width), ] for i in range(n_partitions) ] # Remove bounds for lowest and highest partition partitions[0][0] = None partitions[-1][1] = None return [tuple(partition) for partition in partitions] def db_first_index(self): """Look in the db for the smallest indexable value Returns: the smallest indexable value of all repos in the db """ t = self.db_session.query(func.min(self.MODEL.indexable)).first() if t: return t[0] def db_last_index(self): """Look in the db for the largest indexable value Returns: the largest indexable value of all repos in the db """ t = self.db_session.query(func.max(self.MODEL.indexable)).first() if t: return t[0] def disable_deleted_repo_tasks(self, start, end, keep_these): """Disable tasks for repos that no longer exist between start and end. Args: start: beginning of range to disable end: end of range to disable keep_these (uid list): do not disable repos with uids in this list """ if end is None: end = self.db_last_index() if not self.is_within_bounds(end, None, self.max_index): end = self.max_index deleted_repos = self.winnow_models( self.db_query_range(start, end), self.MODEL.uid, keep_these ) tasks_to_disable = [repo.task_id for repo in deleted_repos if repo.task_id is not None] if tasks_to_disable: self.scheduler.disable_tasks(tasks_to_disable) for repo in deleted_repos: repo.task_id = None def run(self, min_bound=None, max_bound=None): """Main entry function. Sequentially fetches repository data from the service according to the basic outline in the class docstring, continually fetching sublists until either there is no next index reference given or the given next index is greater than the desired max_bound. Args: min_bound (indexable type): optional index to start from max_bound (indexable type): optional index to stop at Returns: nothing """ self.min_index = min_bound self.max_index = max_bound def ingest_indexes(): index = min_bound or self.default_min_bound for i in count(1): response, injected_repos = self.ingest_data(index) if not response and not injected_repos: logger.info('No response from api server, stopping') return next_index = self.get_next_target_from_response(response) # Determine if any repos were deleted, and disable their tasks. keep_these = list(injected_repos.keys()) self.disable_deleted_repo_tasks(index, next_index, keep_these) # termination condition if next_index is None or next_index == index: logger.info('stopping after index %s, no next link found', index) return index = next_index logger.debug('Index: %s', index) yield i for i in ingest_indexes(): if (i % self.flush_packet_db) == 0: logger.debug('Flushing updates at index %s', i) self.db_session.commit() self.db_session = self.mk_session() self.db_session.commit() self.db_session = self.mk_session() class IndexingHttpLister(ListerHttpTransport, IndexingLister): """Convenience class for ensuring right lookup and init order when combining IndexingLister and ListerHttpTransport.""" - def __init__(self, api_baseurl=None, override_config=None): + def __init__(self, url=None, override_config=None): IndexingLister.__init__(self, override_config=override_config) - ListerHttpTransport.__init__(self, api_baseurl=api_baseurl) + ListerHttpTransport.__init__(self, url=url) diff --git a/swh/lister/core/lister_transports.py b/swh/lister/core/lister_transports.py index ff0827c..55188fd 100644 --- a/swh/lister/core/lister_transports.py +++ b/swh/lister/core/lister_transports.py @@ -1,229 +1,229 @@ # Copyright (C) 2017-2018 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import random from datetime import datetime from email.utils import parsedate from pprint import pformat import logging import requests import xmltodict try: from swh.lister._version import __version__ except ImportError: __version__ = 'devel' from .abstractattribute import AbstractAttribute from .lister_base import FetchError logger = logging.getLogger(__name__) class ListerHttpTransport(abc.ABC): """Use the Requests library for making Lister endpoint requests. To be used in conjunction with ListerBase or a subclass of it. """ DEFAULT_URL = None PATH_TEMPLATE = AbstractAttribute('string containing a python string' ' format pattern that produces the API' ' endpoint path for listing stored' ' repositories when given an index.' ' eg. "/repositories?after=%s".' 'To be implemented in the API-specific' ' class inheriting this.') EXPECTED_STATUS_CODES = (200, 429, 403, 404) def request_headers(self): """Returns dictionary of any request headers needed by the server. MAY BE OVERRIDDEN if request headers are needed. """ return { 'User-Agent': 'Software Heritage lister (%s)' % self.lister_version } def request_instance_credentials(self): """Returns dictionary of any credentials configuration needed by the forge instance to list. The 'credentials' configuration is expected to be a dict of multiple levels. The first level is the lister's name, the second is the lister's instance name, which value is expected to be a list of credential structures (typically a couple username/password). For example: credentials: github: # github lister github: # has only one instance (so far) - username: some password: somekey - username: one password: onekey - ... gitlab: # gitlab lister riseup: # has many instances - username: someone password: ... - ... gitlab: - username: someone password: ... - ... Returns: list of credential dicts for the current lister. """ all_creds = self.config.get('credentials') if not all_creds: return [] lister_creds = all_creds.get(self.LISTER_NAME, {}) creds = lister_creds.get(self.instance, []) return creds def request_uri(self, identifier): """Get the full request URI given the transport_request identifier. MAY BE OVERRIDDEN if something more complex than the PATH_TEMPLATE is required. """ path = self.PATH_TEMPLATE % identifier - return self.api_baseurl + path + return self.url + path def request_params(self, identifier): """Get the full parameters passed to requests given the transport_request identifier. This uses credentials if any are provided (see request_instance_credentials). MAY BE OVERRIDDEN if something more complex than the request headers is needed. """ params = {} params['headers'] = self.request_headers() or {} creds = self.request_instance_credentials() if not creds: return params auth = random.choice(creds) if creds else None if auth: params['auth'] = (auth['username'], auth['password']) return params def transport_quota_check(self, response): """Implements ListerBase.transport_quota_check with standard 429 code check for HTTP with Requests library. MAY BE OVERRIDDEN if the server notifies about rate limits in a non-standard way that doesn't use HTTP 429 and the Retry-After response header. ( https://tools.ietf.org/html/rfc6585#section-4 ) """ if response.status_code == 429: # HTTP too many requests retry_after = response.headers.get('Retry-After', self.back_off()) try: # might be seconds return True, float(retry_after) except Exception: # might be http-date at_date = datetime(*parsedate(retry_after)[:6]) from_now = (at_date - datetime.today()).total_seconds() + 5 return True, max(0, from_now) else: # response ok self.reset_backoff() return False, 0 - def __init__(self, api_baseurl=None): - if not api_baseurl: - api_baseurl = self.config.get('api_baseurl') - if not api_baseurl: - api_baseurl = self.DEFAULT_URL - if not api_baseurl: - raise NameError('HTTP Lister Transport requires api_baseurl.') - self.api_baseurl = api_baseurl # eg. 'https://api.github.com' + def __init__(self, url=None): + if not url: + url = self.config.get('url') + if not url: + url = self.DEFAULT_URL + if not url: + raise NameError('HTTP Lister Transport requires an url.') + self.url = url # eg. 'https://api.github.com' self.session = requests.Session() self.lister_version = __version__ def _transport_action(self, identifier, method='get'): """Permit to ask information to the api prior to actually executing query. """ path = self.request_uri(identifier) params = self.request_params(identifier) try: if method == 'head': response = self.session.head(path, **params) else: response = self.session.get(path, **params) except requests.exceptions.ConnectionError as e: logger.warning('Failed to fetch %s: %s', path, e) raise FetchError(e) else: if response.status_code not in self.EXPECTED_STATUS_CODES: raise FetchError(response) return response def transport_head(self, identifier): """Retrieve head information on api. """ return self._transport_action(identifier, method='head') def transport_request(self, identifier): """Implements ListerBase.transport_request for HTTP using Requests. Retrieve get information on api. """ return self._transport_action(identifier) def transport_response_to_string(self, response): """Implements ListerBase.transport_response_to_string for HTTP given Requests responses. """ s = pformat(response.request.path_url) s += '\n#\n' + pformat(response.request.headers) s += '\n#\n' + pformat(response.status_code) s += '\n#\n' + pformat(response.headers) s += '\n#\n' try: # json? s += pformat(response.json()) except Exception: # not json try: # xml? s += pformat(xmltodict.parse(response.text)) except Exception: # not xml s += pformat(response.text) return s class ListerOnePageApiTransport(ListerHttpTransport): """Leverage requests library to retrieve basic html page and parse result. To be used in conjunction with ListerBase or a subclass of it. """ PAGE = AbstractAttribute("The server api's unique page to retrieve and " "parse for information") PATH_TEMPLATE = None # we do not use it - def __init__(self, api_baseurl=None): + def __init__(self, url=None): self.session = requests.Session() self.lister_version = __version__ def request_uri(self, _): """Get the full request URI given the transport_request identifier. """ return self.PAGE diff --git a/swh/lister/core/page_by_page_lister.py b/swh/lister/core/page_by_page_lister.py index 4895068..3d6d9c7 100644 --- a/swh/lister/core/page_by_page_lister.py +++ b/swh/lister/core/page_by_page_lister.py @@ -1,160 +1,160 @@ # Copyright (C) 2015-2018 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import logging from .lister_transports import ListerHttpTransport from .lister_base import ListerBase class PageByPageLister(ListerBase): """Lister* intermediate class for any service that follows the simple pagination page pattern. - Client sends a request to list repositories starting from a given page identifier. - Client receives structured (json/xml/etc) response with information about a sequential series of repositories (per page) starting from a given index. And, if available, some indication of the next page index for fetching the remaining repository data. See :class:`swh.lister.core.lister_base.ListerBase` for more details. This class cannot be instantiated. To create a new Lister for a source code listing service that follows the model described above, you must subclass this class. Then provide the required overrides in addition to any unmet implementation/override requirements of this class's base (see parent class and member docstrings for details). Required Overrides:: def get_next_target_from_response """ @abc.abstractmethod def get_next_target_from_response(self, response): """Find the next server endpoint page given the entire response. Implementation of this method depends on the server API spec and the shape of the network response object returned by the transport_request method. For example, some api can use the headers links to provide the next page. Args: response (transport response): response page from the server Returns: index of next page, possibly extracted from a next href url """ pass @abc.abstractmethod def get_pages_information(self): """Find the total number of pages. Implementation of this method depends on the server API spec and the shape of the network response object returned by the transport_request method. For example, some api can use dedicated headers: - x-total-pages to provide the total number of pages - x-total to provide the total number of repositories - x-per-page to provide the number of elements per page Returns: tuple (total number of repositories, total number of pages, per_page) """ pass # You probably don't need to override anything below this line. def do_additional_checks(self, models_list): """Potentially check for existence of repositories in models_list. This will be called only if check_existence is flipped on in the run method below. """ for m in models_list: sql_repo = self.db_query_equal('uid', m['uid']) if sql_repo: return False return models_list def run(self, min_bound=None, max_bound=None, check_existence=False): """Main entry function. Sequentially fetches repository data from the service according to the basic outline in the class docstring. Continually fetching sublists until either there is no next page reference given or the given next page is greater than the desired max_page. Args: min_bound: optional page to start from max_bound: optional page to stop at check_existence (bool): optional existence check (for incremental lister whose sort order is inverted) Returns: nothing """ page = min_bound or 0 loop_count = 0 self.min_page = min_bound self.max_page = max_bound while self.is_within_bounds(page, self.min_page, self.max_page): logging.info('listing repos starting at %s' % page) response, injected_repos = self.ingest_data(page, checks=check_existence) if not response and not injected_repos: logging.info('No response from api server, stopping') break elif not injected_repos: logging.info('Repositories already seen, stopping') break next_page = self.get_next_target_from_response(response) # termination condition if (next_page is None) or (next_page == page): logging.info('stopping after page %s, no next link found' % page) break else: page = next_page loop_count += 1 if loop_count == 20: logging.info('flushing updates') loop_count = 0 self.db_session.commit() self.db_session = self.mk_session() self.db_session.commit() self.db_session = self.mk_session() class PageByPageHttpLister(ListerHttpTransport, PageByPageLister): """Convenience class for ensuring right lookup and init order when combining PageByPageLister and ListerHttpTransport. """ - def __init__(self, api_baseurl=None, override_config=None): + def __init__(self, url=None, override_config=None): PageByPageLister.__init__(self, override_config=override_config) - ListerHttpTransport.__init__(self, api_baseurl=api_baseurl) + ListerHttpTransport.__init__(self, url=url) diff --git a/swh/lister/core/tests/test_lister.py b/swh/lister/core/tests/test_lister.py index b7ae9e5..dec68c6 100644 --- a/swh/lister/core/tests/test_lister.py +++ b/swh/lister/core/tests/test_lister.py @@ -1,340 +1,340 @@ # Copyright (C) 2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import time from unittest import TestCase from unittest.mock import Mock, patch import requests_mock from sqlalchemy import create_engine from swh.lister.core.abstractattribute import AbstractAttribute from swh.lister.tests.test_utils import init_db def noop(*args, **kwargs): pass class HttpListerTesterBase(abc.ABC): """Testing base class for listers. This contains methods for both :class:`HttpSimpleListerTester` and :class:`HttpListerTester`. See :class:`swh.lister.gitlab.tests.test_lister` for an example of how to customize for a specific listing service. """ Lister = AbstractAttribute('The lister class to test') lister_subdir = AbstractAttribute('bitbucket, github, etc.') good_api_response_file = AbstractAttribute('Example good response body') LISTER_NAME = 'fake-lister' # May need to override this if the headers are used for something def response_headers(self, request): return {} # May need to override this if the server uses non-standard rate limiting # method. # Please keep the requested retry delay reasonably low. def mock_rate_quota(self, n, request, context): self.rate_limit += 1 context.status_code = 429 context.headers['Retry-After'] = '1' return '{"error":"dummy"}' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.rate_limit = 1 self.response = None self.fl = None self.helper = None if self.__class__ != HttpListerTesterBase: self.run = TestCase.run.__get__(self, self.__class__) else: self.run = noop def mock_limit_n_response(self, n, request, context): self.fl.reset_backoff() if self.rate_limit <= n: return self.mock_rate_quota(n, request, context) else: return self.mock_response(request, context) def mock_limit_twice_response(self, request, context): return self.mock_limit_n_response(2, request, context) def get_api_response(self, identifier): fl = self.get_fl() if self.response is None: self.response = fl.safely_issue_request(identifier) return self.response def get_fl(self, override_config=None): """Retrieve an instance of fake lister (fl). """ if override_config or self.fl is None: - self.fl = self.Lister(api_baseurl='https://fakeurl', + self.fl = self.Lister(url='https://fakeurl', override_config=override_config) self.fl.INITIAL_BACKOFF = 1 self.fl.reset_backoff() return self.fl def disable_scheduler(self, fl): fl.schedule_missing_tasks = Mock(return_value=None) def disable_db(self, fl): fl.winnow_models = Mock(return_value=[]) fl.db_inject_repo = Mock(return_value=fl.MODEL()) fl.disable_deleted_repo_tasks = Mock(return_value=None) def init_db(self, db, model): engine = create_engine(db.url()) model.metadata.create_all(engine) @requests_mock.Mocker() def test_is_within_bounds(self, http_mocker): fl = self.get_fl() self.assertFalse(fl.is_within_bounds(1, 2, 3)) self.assertTrue(fl.is_within_bounds(2, 1, 3)) self.assertTrue(fl.is_within_bounds(1, 1, 1)) self.assertTrue(fl.is_within_bounds(1, None, None)) self.assertTrue(fl.is_within_bounds(1, None, 2)) self.assertTrue(fl.is_within_bounds(1, 0, None)) self.assertTrue(fl.is_within_bounds("b", "a", "c")) self.assertFalse(fl.is_within_bounds("a", "b", "c")) self.assertTrue(fl.is_within_bounds("a", None, "c")) self.assertTrue(fl.is_within_bounds("a", None, None)) self.assertTrue(fl.is_within_bounds("b", "a", None)) self.assertFalse(fl.is_within_bounds("a", "b", None)) self.assertTrue(fl.is_within_bounds("aa:02", "aa:01", "aa:03")) self.assertFalse(fl.is_within_bounds("aa:12", None, "aa:03")) with self.assertRaises(TypeError): fl.is_within_bounds(1.0, "b", None) with self.assertRaises(TypeError): fl.is_within_bounds("A:B", "A::B", None) class HttpListerTester(HttpListerTesterBase, abc.ABC): """Base testing class for subclass of :class:`swh.lister.core.indexing_lister.IndexingHttpLister` See :class:`swh.lister.github.tests.test_gh_lister` for an example of how to customize for a specific listing service. """ last_index = AbstractAttribute('Last index in good_api_response') first_index = AbstractAttribute('First index in good_api_response') bad_api_response_file = AbstractAttribute('Example bad response body') entries_per_page = AbstractAttribute('Number of results in good response') test_re = AbstractAttribute('Compiled regex matching the server url. Must' ' capture the index value.') convert_type = str """static method used to convert the "request_index" to its right type (for indexing listers for example, this is in accordance with the model's "indexable" column). """ def mock_response(self, request, context): self.fl.reset_backoff() self.rate_limit = 1 context.status_code = 200 custom_headers = self.response_headers(request) context.headers.update(custom_headers) req_index = self.request_index(request) if req_index == self.first_index: response_file = self.good_api_response_file else: response_file = self.bad_api_response_file with open('swh/lister/%s/tests/%s' % (self.lister_subdir, response_file), 'r', encoding='utf-8') as r: return r.read() def request_index(self, request): m = self.test_re.search(request.path_url) if m and (len(m.groups()) > 0): return self.convert_type(m.group(1)) @requests_mock.Mocker() def test_fetch_multiple_pages_yesdb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) db = init_db() fl = self.get_fl(override_config={ 'lister': { 'cls': 'local', 'args': {'db': db.url()} } }) self.init_db(db, fl.MODEL) self.disable_scheduler(fl) fl.run(min_bound=self.first_index) self.assertEqual(fl.db_last_index(), self.last_index) partitions = fl.db_partition_indices(5) self.assertGreater(len(partitions), 0) for k in partitions: self.assertLessEqual(len(k), 5) self.assertGreater(len(k), 0) @requests_mock.Mocker() def test_fetch_none_nodb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() self.disable_scheduler(fl) self.disable_db(fl) fl.run(min_bound=1, max_bound=1) # stores no results # FIXME: Determine what this method tries to test and add checks to # actually test @requests_mock.Mocker() def test_fetch_one_nodb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() self.disable_scheduler(fl) self.disable_db(fl) fl.run(min_bound=self.first_index, max_bound=self.first_index) # FIXME: Determine what this method tries to test and add checks to # actually test @requests_mock.Mocker() def test_fetch_multiple_pages_nodb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() self.disable_scheduler(fl) self.disable_db(fl) fl.run(min_bound=self.first_index) # FIXME: Determine what this method tries to test and add checks to # actually test @requests_mock.Mocker() def test_repos_list(self, http_mocker): """Test the number of repos listed by the lister """ http_mocker.get(self.test_re, text=self.mock_response) li = self.get_fl().transport_response_simplified( self.get_api_response(self.first_index) ) self.assertIsInstance(li, list) self.assertEqual(len(li), self.entries_per_page) @requests_mock.Mocker() def test_model_map(self, http_mocker): """Check if all the keys of model are present in the model created by the `transport_response_simplified` """ http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() li = fl.transport_response_simplified( self.get_api_response(self.first_index)) di = li[0] self.assertIsInstance(di, dict) pubs = [k for k in vars(fl.MODEL).keys() if not k.startswith('_')] for k in pubs: if k not in ['last_seen', 'task_id', 'id']: self.assertIn(k, di) @requests_mock.Mocker() def test_api_request(self, http_mocker): """Test API request for rate limit handling """ http_mocker.get(self.test_re, text=self.mock_limit_twice_response) with patch.object(time, 'sleep', wraps=time.sleep) as sleepmock: self.get_api_response(self.first_index) self.assertEqual(sleepmock.call_count, 2) class HttpSimpleListerTester(HttpListerTesterBase, abc.ABC): """Base testing class for subclass of :class:`swh.lister.core.simple)_lister.SimpleLister` See :class:`swh.lister.pypi.tests.test_lister` for an example of how to customize for a specific listing service. """ entries = AbstractAttribute('Number of results in good response') PAGE = AbstractAttribute("The server api's unique page to retrieve and " "parse for information") def get_fl(self, override_config=None): """Retrieve an instance of fake lister (fl). """ if override_config or self.fl is None: self.fl = self.Lister( override_config=override_config) self.fl.INITIAL_BACKOFF = 1 self.fl.reset_backoff() return self.fl def mock_response(self, request, context): self.fl.reset_backoff() self.rate_limit = 1 context.status_code = 200 custom_headers = self.response_headers(request) context.headers.update(custom_headers) response_file = self.good_api_response_file with open('swh/lister/%s/tests/%s' % (self.lister_subdir, response_file), 'r', encoding='utf-8') as r: return r.read() @requests_mock.Mocker() def test_api_request(self, http_mocker): """Test API request for rate limit handling """ http_mocker.get(self.PAGE, text=self.mock_limit_twice_response) with patch.object(time, 'sleep', wraps=time.sleep) as sleepmock: self.get_api_response(0) self.assertEqual(sleepmock.call_count, 2) @requests_mock.Mocker() def test_model_map(self, http_mocker): """Check if all the keys of model are present in the model created by the `transport_response_simplified` """ http_mocker.get(self.PAGE, text=self.mock_response) fl = self.get_fl() li = fl.list_packages(self.get_api_response(0)) li = fl.transport_response_simplified(li) di = li[0] self.assertIsInstance(di, dict) pubs = [k for k in vars(fl.MODEL).keys() if not k.startswith('_')] for k in pubs: if k not in ['last_seen', 'task_id', 'id']: self.assertIn(k, di) @requests_mock.Mocker() def test_repos_list(self, http_mocker): """Test the number of packages listed by the lister """ http_mocker.get(self.PAGE, text=self.mock_response) li = self.get_fl().list_packages( self.get_api_response(0) ) self.assertIsInstance(li, list) self.assertEqual(len(li), self.entries) diff --git a/swh/lister/debian/lister.py b/swh/lister/debian/lister.py index 44b766e..8837b17 100644 --- a/swh/lister/debian/lister.py +++ b/swh/lister/debian/lister.py @@ -1,237 +1,237 @@ # Copyright (C) 2017 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import bz2 from collections import defaultdict import datetime import gzip import lzma import logging from debian.deb822 import Sources from sqlalchemy.orm import joinedload, load_only from sqlalchemy.schema import CreateTable, DropTable from swh.storage.schemata.distribution import ( AreaSnapshot, Distribution, DistributionSnapshot, Package, TempPackage, ) from swh.lister.core.lister_base import ListerBase, FetchError from swh.lister.core.lister_transports import ListerHttpTransport decompressors = { 'gz': lambda f: gzip.GzipFile(fileobj=f), 'bz2': bz2.BZ2File, 'xz': lzma.LZMAFile, } class DebianLister(ListerHttpTransport, ListerBase): MODEL = Package PATH_TEMPLATE = None LISTER_NAME = 'debian' instance = 'debian' def __init__(self, override_config=None): - ListerHttpTransport.__init__(self, api_baseurl="bogus") + ListerHttpTransport.__init__(self, url="notused") ListerBase.__init__(self, override_config=override_config) def transport_request(self, identifier): """Subvert ListerHttpTransport.transport_request, to try several index URIs in turn. The Debian repository format supports several compression algorithms across the ages, so we try several URIs. Once we have found a working URI, we break and set `self.decompressor` to the one that matched. Returns: a requests Response object. Raises: FetchError: when all the URIs failed to be retrieved. """ response = None compression = None for uri, compression in self.area.index_uris(): response = super().transport_request(uri) if response.status_code == 200: break else: raise FetchError( "Could not retrieve index for %s" % self.area ) self.decompressor = decompressors.get(compression) return response def request_uri(self, identifier): # In the overridden transport_request, we pass # ListerBase.transport_request() the full URI as identifier, so we # need to return it here. return identifier def request_params(self, identifier): # Enable streaming to allow wrapping the response in the decompressor # in transport_response_simplified. params = super().request_params(identifier) params['stream'] = True return params def transport_response_simplified(self, response): """Decompress and parse the package index fetched in `transport_request`. For each package, we "pivot" the file list entries (Files, Checksums-Sha1, Checksums-Sha256), to return a files dict mapping filenames to their checksums. """ if self.decompressor: data = self.decompressor(response.raw) else: data = response.raw for src_pkg in Sources.iter_paragraphs(data.readlines()): files = defaultdict(dict) for field in src_pkg._multivalued_fields: if field.startswith('checksums-'): sum_name = field[len('checksums-'):] else: sum_name = 'md5sum' if field in src_pkg: for entry in src_pkg[field]: name = entry['name'] files[name]['name'] = entry['name'] files[name]['size'] = int(entry['size'], 10) files[name][sum_name] = entry[sum_name] yield { 'name': src_pkg['Package'], 'version': src_pkg['Version'], 'directory': src_pkg['Directory'], 'files': files, } def inject_repo_data_into_db(self, models_list): """Generate the Package entries that didn't previously exist. Contrary to ListerBase, we don't actually insert the data in database. `schedule_missing_tasks` does it once we have the origin and task identifiers. """ by_name_version = {} temp_packages = [] area_id = self.area.id for model in models_list: name = model['name'] version = model['version'] temp_packages.append({ 'area_id': area_id, 'name': name, 'version': version, }) by_name_version[name, version] = model # Add all the listed packages to a temporary table self.db_session.execute(CreateTable(TempPackage.__table__)) self.db_session.bulk_insert_mappings(TempPackage, temp_packages) def exists_tmp_pkg(db_session, model): return ( db_session.query(model) .filter(Package.area_id == TempPackage.area_id) .filter(Package.name == TempPackage.name) .filter(Package.version == TempPackage.version) .exists() ) # Filter out the packages that already exist in the main Package table new_packages = self.db_session\ .query(TempPackage)\ .options(load_only('name', 'version'))\ .filter(~exists_tmp_pkg(self.db_session, Package))\ .all() self.old_area_packages = self.db_session.query(Package).filter( exists_tmp_pkg(self.db_session, TempPackage) ).all() self.db_session.execute(DropTable(TempPackage.__table__)) added_packages = [] for package in new_packages: model = by_name_version[package.name, package.version] added_packages.append(Package(area=self.area, **model)) self.db_session.add_all(added_packages) return added_packages def schedule_missing_tasks(self, models_list, added_packages): """We create tasks at the end of the full snapshot processing""" return def create_tasks_for_snapshot(self, snapshot): tasks = [ snapshot.task_for_package(name, versions) for name, versions in snapshot.get_packages().items() ] return self.scheduler.create_tasks(tasks) def run(self, distribution, date=None): """Run the lister for a given (distribution, area) tuple. Args: distribution (str): name of the distribution (e.g. "Debian") date (datetime.datetime): date the snapshot is taken (defaults to now) """ distribution = self.db_session\ .query(Distribution)\ .options(joinedload(Distribution.areas))\ .filter(Distribution.name == distribution)\ .one_or_none() if not distribution: raise ValueError("Distribution %s is not registered" % distribution) if not distribution.type == 'deb': raise ValueError("Distribution %s is not a Debian derivative" % distribution) date = date or datetime.datetime.now(tz=datetime.timezone.utc) logging.debug('Creating snapshot for distribution %s on date %s' % (distribution, date)) snapshot = DistributionSnapshot(date=date, distribution=distribution) self.db_session.add(snapshot) for area in distribution.areas: if not area.active: continue self.area = area logging.debug('Processing area %s' % area) _, new_area_packages = self.ingest_data(None) area_snapshot = AreaSnapshot(snapshot=snapshot, area=area) self.db_session.add(area_snapshot) area_snapshot.packages.extend(new_area_packages) area_snapshot.packages.extend(self.old_area_packages) self.create_tasks_for_snapshot(snapshot) self.db_session.commit() return True diff --git a/swh/lister/github/tasks.py b/swh/lister/github/tasks.py index c94db27..555dc0a 100644 --- a/swh/lister/github/tasks.py +++ b/swh/lister/github/tasks.py @@ -1,57 +1,53 @@ # Copyright (C) 2017-2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import random from celery import group from swh.scheduler.celery_backend.config import app from swh.lister.github.lister import GitHubLister GROUP_SPLIT = 10000 -def new_lister(api_baseurl='https://api.github.com', **kw): - return GitHubLister(api_baseurl=api_baseurl, **kw) - - @app.task(name=__name__ + '.IncrementalGitHubLister') def list_github_incremental(**lister_args): 'Incremental update of GitHub' - lister = new_lister(**lister_args) + lister = GitHubLister(**lister_args) lister.run(min_bound=lister.db_last_index(), max_bound=None) @app.task(name=__name__ + '.RangeGitHubLister') def _range_github_lister(start, end, **lister_args): - lister = new_lister(**lister_args) + lister = GitHubLister(**lister_args) lister.run(min_bound=start, max_bound=end) @app.task(name=__name__ + '.FullGitHubRelister', bind=True) def list_github_full(self, split=None, **lister_args): """Full update of GitHub It's not to be called for an initial listing. """ - lister = new_lister(**lister_args) + lister = GitHubLister(**lister_args) ranges = lister.db_partition_indices(split or GROUP_SPLIT) if not ranges: self.log.info('Nothing to list') return random.shuffle(ranges) promise = group(_range_github_lister.s(minv, maxv, **lister_args) for minv, maxv in ranges)() self.log.debug('%s OK (spawned %s subtasks)' % (self.name, len(ranges))) try: promise.save() # so that we can restore the GroupResult in tests except (NotImplementedError, AttributeError): self.log.info('Unable to call save_group with current result backend.') return promise.id @app.task(name=__name__ + '.ping') def _ping(): return 'OK' diff --git a/swh/lister/github/tests/test_tasks.py b/swh/lister/github/tests/test_tasks.py index 9bd30c1..c652404 100644 --- a/swh/lister/github/tests/test_tasks.py +++ b/swh/lister/github/tests/test_tasks.py @@ -1,90 +1,90 @@ from time import sleep from celery.result import GroupResult from unittest.mock import patch def test_ping(swh_app, celery_session_worker): res = swh_app.send_task( 'swh.lister.github.tasks.ping') assert res res.wait() assert res.successful() assert res.result == 'OK' @patch('swh.lister.github.tasks.GitHubLister') def test_incremental(lister, swh_app, celery_session_worker): # setup the mocked GitHubLister lister.return_value = lister lister.db_last_index.return_value = 42 lister.run.return_value = None res = swh_app.send_task( 'swh.lister.github.tasks.IncrementalGitHubLister') assert res res.wait() assert res.successful() - lister.assert_called_once_with(api_baseurl='https://api.github.com') + lister.assert_called_once_with() lister.db_last_index.assert_called_once_with() lister.run.assert_called_once_with(min_bound=42, max_bound=None) @patch('swh.lister.github.tasks.GitHubLister') def test_range(lister, swh_app, celery_session_worker): # setup the mocked GitHubLister lister.return_value = lister lister.run.return_value = None res = swh_app.send_task( 'swh.lister.github.tasks.RangeGitHubLister', kwargs=dict(start=12, end=42)) assert res res.wait() assert res.successful() - lister.assert_called_once_with(api_baseurl='https://api.github.com') + lister.assert_called_once_with() lister.db_last_index.assert_not_called() lister.run.assert_called_once_with(min_bound=12, max_bound=42) @patch('swh.lister.github.tasks.GitHubLister') def test_relister(lister, swh_app, celery_session_worker): # setup the mocked GitHubLister lister.return_value = lister lister.run.return_value = None lister.db_partition_indices.return_value = [ (i, i+9) for i in range(0, 50, 10)] res = swh_app.send_task( 'swh.lister.github.tasks.FullGitHubRelister') assert res res.wait() assert res.successful() # retrieve the GroupResult for this task and wait for all the subtasks # to complete promise_id = res.result assert promise_id promise = GroupResult.restore(promise_id, app=swh_app) for i in range(5): if promise.ready(): break sleep(1) - lister.assert_called_with(api_baseurl='https://api.github.com') + lister.assert_called_with() # one by the FullGitHubRelister task # + 5 for the RangeGitHubLister subtasks assert lister.call_count == 6 lister.db_last_index.assert_not_called() lister.db_partition_indices.assert_called_once_with(10000) # lister.run should have been called once per partition interval for i in range(5): # XXX inconsistent behavior: max_bound is INCLUDED here assert (dict(min_bound=10*i, max_bound=10*i + 9),) \ in lister.run.call_args_list diff --git a/swh/lister/gitlab/lister.py b/swh/lister/gitlab/lister.py index 4dc46f7..60b5320 100644 --- a/swh/lister/gitlab/lister.py +++ b/swh/lister/gitlab/lister.py @@ -1,82 +1,81 @@ # Copyright (C) 2018-2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import time from urllib3.util import parse_url from ..core.page_by_page_lister import PageByPageHttpLister from .models import GitLabModel class GitLabLister(PageByPageHttpLister): # Template path expecting an integer that represents the page id PATH_TEMPLATE = '/projects?page=%d&order_by=id' DEFAULT_URL = 'https://gitlab.com/api/v4/' MODEL = GitLabModel LISTER_NAME = 'gitlab' - def __init__(self, api_baseurl=None, instance=None, + def __init__(self, url=None, instance=None, override_config=None, sort='asc', per_page=20): - super().__init__(api_baseurl=api_baseurl, - override_config=override_config) + super().__init__(url=url, override_config=override_config) if instance is None: - instance = parse_url(self.api_baseurl).host + instance = parse_url(self.url).host self.instance = instance self.PATH_TEMPLATE = '%s&sort=%s&per_page=%s' % ( self.PATH_TEMPLATE, sort, per_page) def uid(self, repo): return '%s/%s' % (self.instance, repo['path_with_namespace']) def get_model_from_repo(self, repo): return { 'instance': self.instance, 'uid': self.uid(repo), 'name': repo['name'], 'full_name': repo['path_with_namespace'], 'html_url': repo['web_url'], 'origin_url': repo['http_url_to_repo'], 'origin_type': 'git', } def transport_quota_check(self, response): """Deal with rate limit if any. """ # not all gitlab instance have rate limit if 'RateLimit-Remaining' in response.headers: reqs_remaining = int(response.headers['RateLimit-Remaining']) if response.status_code == 403 and reqs_remaining == 0: reset_at = int(response.headers['RateLimit-Reset']) delay = min(reset_at - time.time(), 3600) return True, delay return False, 0 def _get_int(self, headers, key): _val = headers.get(key) if _val: return int(_val) def get_next_target_from_response(self, response): """Determine the next page identifier. """ return self._get_int(response.headers, 'x-next-page') def get_pages_information(self): """Determine pages information. """ response = self.transport_head(identifier=1) if not response.ok: raise ValueError( 'Problem during information fetch: %s' % response.status_code) h = response.headers return (self._get_int(h, 'x-total'), self._get_int(h, 'x-total-pages'), self._get_int(h, 'x-per-page')) def transport_response_simplified(self, response): repos = response.json() return [self.get_model_from_repo(repo) for repo in repos] diff --git a/swh/lister/gitlab/tasks.py b/swh/lister/gitlab/tasks.py index 30c4241..30cab41 100644 --- a/swh/lister/gitlab/tasks.py +++ b/swh/lister/gitlab/tasks.py @@ -1,59 +1,52 @@ # Copyright (C) 2018 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import random from celery import group from swh.scheduler.celery_backend.config import app from .. import utils from .lister import GitLabLister NBPAGES = 10 -def new_lister(api_baseurl='https://gitlab.com/api/v4', - instance=None, sort='asc', per_page=20): - return GitLabLister( - api_baseurl=api_baseurl, instance=instance, sort=sort, - per_page=per_page) - - @app.task(name=__name__ + '.IncrementalGitLabLister') def list_gitlab_incremental(**lister_args): """Incremental update of a GitLab instance""" lister_args['sort'] = 'desc' - lister = new_lister(**lister_args) + lister = GitLabLister(**lister_args) total_pages = lister.get_pages_information()[1] # stopping as soon as existing origins for that instance are detected lister.run(min_bound=1, max_bound=total_pages, check_existence=True) @app.task(name=__name__ + '.RangeGitLabLister') def _range_gitlab_lister(start, end, **lister_args): - lister = new_lister(**lister_args) + lister = GitLabLister(**lister_args) lister.run(min_bound=start, max_bound=end) @app.task(name=__name__ + '.FullGitLabRelister', bind=True) def list_gitlab_full(self, **lister_args): """Full update of a GitLab instance""" - lister = new_lister(**lister_args) + lister = GitLabLister(**lister_args) _, total_pages, _ = lister.get_pages_information() ranges = list(utils.split_range(total_pages, NBPAGES)) random.shuffle(ranges) promise = group(_range_gitlab_lister.s(minv, maxv, **lister_args) for minv, maxv in ranges)() self.log.debug('%s OK (spawned %s subtasks)' % (self.name, len(ranges))) try: promise.save() except (NotImplementedError, AttributeError): self.log.info('Unable to call save_group with current result backend.') return promise.id @app.task(name=__name__ + '.ping') def _ping(): return 'OK' diff --git a/swh/lister/gitlab/tests/test_tasks.py b/swh/lister/gitlab/tests/test_tasks.py index f8d0a81..56332a1 100644 --- a/swh/lister/gitlab/tests/test_tasks.py +++ b/swh/lister/gitlab/tests/test_tasks.py @@ -1,150 +1,142 @@ from time import sleep from celery.result import GroupResult from unittest.mock import patch def test_ping(swh_app, celery_session_worker): res = swh_app.send_task( 'swh.lister.gitlab.tasks.ping') assert res res.wait() assert res.successful() assert res.result == 'OK' @patch('swh.lister.gitlab.tasks.GitLabLister') def test_incremental(lister, swh_app, celery_session_worker): # setup the mocked GitlabLister lister.return_value = lister lister.run.return_value = None lister.get_pages_information.return_value = (None, 10, None) res = swh_app.send_task( 'swh.lister.gitlab.tasks.IncrementalGitLabLister') assert res res.wait() assert res.successful() - lister.assert_called_once_with( - api_baseurl='https://gitlab.com/api/v4', - instance=None, sort='desc', per_page=20) + lister.assert_called_once_with(sort='desc') lister.db_last_index.assert_not_called() lister.get_pages_information.assert_called_once_with() lister.run.assert_called_once_with( min_bound=1, max_bound=10, check_existence=True) @patch('swh.lister.gitlab.tasks.GitLabLister') def test_range(lister, swh_app, celery_session_worker): # setup the mocked GitlabLister lister.return_value = lister lister.run.return_value = None res = swh_app.send_task( 'swh.lister.gitlab.tasks.RangeGitLabLister', kwargs=dict(start=12, end=42)) assert res res.wait() assert res.successful() - lister.assert_called_once_with( - api_baseurl='https://gitlab.com/api/v4', - instance=None, sort='asc', per_page=20) + lister.assert_called_once_with() lister.db_last_index.assert_not_called() lister.run.assert_called_once_with(min_bound=12, max_bound=42) @patch('swh.lister.gitlab.tasks.GitLabLister') def test_relister(lister, swh_app, celery_session_worker): # setup the mocked GitlabLister lister.return_value = lister lister.run.return_value = None lister.get_pages_information.return_value = (None, 85, None) lister.db_partition_indices.return_value = [ (i, i+9) for i in range(0, 80, 10)] + [(80, 85)] res = swh_app.send_task( 'swh.lister.gitlab.tasks.FullGitLabRelister') assert res res.wait() assert res.successful() # retrieve the GroupResult for this task and wait for all the subtasks # to complete promise_id = res.result assert promise_id promise = GroupResult.restore(promise_id, app=swh_app) for i in range(5): if promise.ready(): break sleep(1) - lister.assert_called_with( - api_baseurl='https://gitlab.com/api/v4', - instance=None, sort='asc', per_page=20) + lister.assert_called_with() # one by the FullGitlabRelister task # + 9 for the RangeGitlabLister subtasks assert lister.call_count == 10 lister.db_last_index.assert_not_called() lister.db_partition_indices.assert_not_called() lister.get_pages_information.assert_called_once_with() # lister.run should have been called once per partition interval for i in range(8): # XXX inconsistent behavior: max_bound is EXCLUDED here assert (dict(min_bound=10*i, max_bound=10*i + 10),) \ in lister.run.call_args_list assert (dict(min_bound=80, max_bound=85),) \ in lister.run.call_args_list @patch('swh.lister.gitlab.tasks.GitLabLister') def test_relister_instance(lister, swh_app, celery_session_worker): # setup the mocked GitlabLister lister.return_value = lister lister.run.return_value = None lister.get_pages_information.return_value = (None, 85, None) lister.db_partition_indices.return_value = [ (i, i+9) for i in range(0, 80, 10)] + [(80, 85)] res = swh_app.send_task( 'swh.lister.gitlab.tasks.FullGitLabRelister', - kwargs=dict(api_baseurl='https://0xacab.org/api/v4')) + kwargs=dict(url='https://0xacab.org/api/v4')) assert res res.wait() assert res.successful() # retrieve the GroupResult for this task and wait for all the subtasks # to complete promise_id = res.result assert promise_id promise = GroupResult.restore(promise_id, app=swh_app) for i in range(5): if promise.ready(): break sleep(1) - lister.assert_called_with( - api_baseurl='https://0xacab.org/api/v4', - instance=None, sort='asc', per_page=20) + lister.assert_called_with(url='https://0xacab.org/api/v4') # one by the FullGitlabRelister task # + 9 for the RangeGitlabLister subtasks assert lister.call_count == 10 lister.db_last_index.assert_not_called() lister.db_partition_indices.assert_not_called() lister.get_pages_information.assert_called_once_with() # lister.run should have been called once per partition interval for i in range(8): # XXX inconsistent behavior: max_bound is EXCLUDED here assert (dict(min_bound=10*i, max_bound=10*i + 10),) \ in lister.run.call_args_list assert (dict(min_bound=80, max_bound=85),) \ in lister.run.call_args_list diff --git a/swh/lister/npm/lister.py b/swh/lister/npm/lister.py index c7e9d29..0672f7c 100644 --- a/swh/lister/npm/lister.py +++ b/swh/lister/npm/lister.py @@ -1,157 +1,156 @@ # Copyright (C) 2018-2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from urllib.parse import quote from swh.lister.core.indexing_lister import IndexingHttpLister from swh.lister.npm.models import NpmModel from swh.scheduler.utils import create_task_dict class NpmListerBase(IndexingHttpLister): """List packages available in the npm registry in a paginated way """ MODEL = NpmModel LISTER_NAME = 'npm' instance = 'npm' - def __init__(self, api_baseurl='https://replicate.npmjs.com', + def __init__(self, url='https://replicate.npmjs.com', per_page=1000, override_config=None): - super().__init__(api_baseurl=api_baseurl, - override_config=override_config) + super().__init__(url=url, override_config=override_config) self.per_page = per_page + 1 self.PATH_TEMPLATE += '&limit=%s' % self.per_page @property def ADDITIONAL_CONFIG(self): """(Override) Add extra configuration """ default_config = super().ADDITIONAL_CONFIG default_config['loading_task_policy'] = ('str', 'recurring') return default_config def get_model_from_repo(self, repo_name): """(Override) Transform from npm package name to model """ package_url, package_metadata_url = self._compute_urls(repo_name) return { 'uid': repo_name, 'indexable': repo_name, 'name': repo_name, 'full_name': repo_name, 'html_url': package_metadata_url, 'origin_url': package_url, 'origin_type': 'npm', } def task_dict(self, origin_type, origin_url, **kwargs): """(Override) Return task dict for loading a npm package into the archive This is overridden from the lister_base as more information is needed for the ingestion task creation. """ task_type = 'load-%s' % origin_type task_policy = self.config['loading_task_policy'] package_name = kwargs.get('name') package_metadata_url = kwargs.get('html_url') return create_task_dict(task_type, task_policy, package_name, origin_url, package_metadata_url=package_metadata_url) def request_headers(self): """(Override) Set requests headers to send when querying the npm registry """ return {'User-Agent': 'Software Heritage npm lister', 'Accept': 'application/json'} def _compute_urls(self, repo_name): """Return a tuple (package_url, package_metadata_url) """ return ( 'https://www.npmjs.com/package/%s' % repo_name, # package metadata url needs to be escaped otherwise some requests # may fail (for instance when a package name contains '/') - '%s/%s' % (self.api_baseurl, quote(repo_name, safe='')) + '%s/%s' % (self.url, quote(repo_name, safe='')) ) def string_pattern_check(self, inner, lower, upper=None): """ (Override) Inhibit the effect of that method as packages indices correspond to package names and thus do not respect any kind of fixed length string pattern """ pass class NpmLister(NpmListerBase): """List all packages available in the npm registry in a paginated way """ PATH_TEMPLATE = '/_all_docs?startkey="%s"' def get_next_target_from_response(self, response): """(Override) Get next npm package name to continue the listing """ repos = response.json()['rows'] return repos[-1]['id'] if len(repos) == self.per_page else None def transport_response_simplified(self, response): """(Override) Transform npm registry response to list for model manipulation """ repos = response.json()['rows'] if len(repos) == self.per_page: repos = repos[:-1] return [self.get_model_from_repo(repo['id']) for repo in repos] class NpmIncrementalLister(NpmListerBase): """List packages in the npm registry, updated since a specific update_seq value of the underlying CouchDB database, in a paginated way """ PATH_TEMPLATE = '/_changes?since=%s' @property def CONFIG_BASE_FILENAME(self): # noqa: N802 return 'lister_npm_incremental' def get_next_target_from_response(self, response): """(Override) Get next npm package name to continue the listing """ repos = response.json()['results'] return repos[-1]['seq'] if len(repos) == self.per_page else None def transport_response_simplified(self, response): """(Override) Transform npm registry response to list for model manipulation """ repos = response.json()['results'] if len(repos) == self.per_page: repos = repos[:-1] return [self.get_model_from_repo(repo['id']) for repo in repos] def filter_before_inject(self, models_list): """(Override) Filter out documents in the CouchDB database not related to a npm package """ models_filtered = [] for model in models_list: package_name = model['name'] # document related to CouchDB internals if package_name.startswith('_design/'): continue models_filtered.append(model) return models_filtered def disable_deleted_repo_tasks(self, start, end, keep_these): """(Override) Disable the processing performed by that method as it is not relevant in this incremental lister context and it raises and exception due to a different index type (int instead of str) """ pass diff --git a/swh/lister/npm/tasks.py b/swh/lister/npm/tasks.py index 18c8374..8d1c369 100644 --- a/swh/lister/npm/tasks.py +++ b/swh/lister/npm/tasks.py @@ -1,62 +1,62 @@ # Copyright (C) 2018 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime from contextlib import contextmanager from swh.scheduler.celery_backend.config import app from swh.lister.npm.lister import NpmLister, NpmIncrementalLister from swh.lister.npm.models import NpmVisitModel @contextmanager def save_registry_state(lister): params = {'headers': lister.request_headers()} - registry_state = lister.session.get(lister.api_baseurl, **params) + registry_state = lister.session.get(lister.url, **params) registry_state = registry_state.json() keys = ('doc_count', 'doc_del_count', 'update_seq', 'purge_seq', 'disk_size', 'data_size', 'committed_update_seq', 'compacted_seq') state = {key: registry_state[key] for key in keys} state['visit_date'] = datetime.now() yield npm_visit = NpmVisitModel(**state) lister.db_session.add(npm_visit) lister.db_session.commit() def get_last_update_seq(lister): """Get latest ``update_seq`` value for listing only updated packages. """ query = lister.db_session.query(NpmVisitModel.update_seq) row = query.order_by(NpmVisitModel.uid.desc()).first() if not row: raise ValueError('No npm registry listing previously performed ! ' 'This is required prior to the execution of an ' 'incremental listing.') return row[0] @app.task(name=__name__ + '.NpmListerTask') def list_npm_full(**lister_args): 'Full lister for the npm (javascript) registry' lister = NpmLister(**lister_args) with save_registry_state(lister): lister.run() @app.task(name=__name__ + '.NpmIncrementalListerTask') def list_npm_incremental(**lister_args): 'Incremental lister for the npm (javascript) registry' lister = NpmIncrementalLister(**lister_args) update_seq_start = get_last_update_seq(lister) with save_registry_state(lister): lister.run(min_bound=update_seq_start) @app.task(name=__name__ + '.ping') def _ping(): return 'OK' diff --git a/swh/lister/phabricator/lister.py b/swh/lister/phabricator/lister.py index 68affff..7f60c14 100644 --- a/swh/lister/phabricator/lister.py +++ b/swh/lister/phabricator/lister.py @@ -1,154 +1,153 @@ # Copyright (C) 2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import random import urllib.parse from swh.lister.core.indexing_lister import IndexingHttpLister from swh.lister.phabricator.models import PhabricatorModel from collections import defaultdict logger = logging.getLogger(__name__) class PhabricatorLister(IndexingHttpLister): PATH_TEMPLATE = '?order=oldest&attachments[uris]=1&after=%s' DEFAULT_URL = 'https://forge.softwareheritage.org/api/diffusion.repository.search' # noqa MODEL = PhabricatorModel LISTER_NAME = 'phabricator' - def __init__(self, api_baseurl=None, instance=None, override_config=None): - super().__init__(api_baseurl=api_baseurl, - override_config=override_config) + def __init__(self, url=None, instance=None, override_config=None): + super().__init__(url=url, override_config=override_config) if not instance: - instance = urllib.parse.urlparse(self.api_baseurl).hostname + instance = urllib.parse.urlparse(self.url).hostname self.instance = instance @property def default_min_bound(self): """Starting boundary when `min_bound` is not defined (db empty). This is used within the fn:`run` call. """ return self._bootstrap_repositories_listing() def request_params(self, identifier): """Override the default params behavior to retrieve the api token Credentials are stored as: credentials: phabricator: : - username: password: """ creds = self.request_instance_credentials() if not creds: raise ValueError( 'Phabricator forge needs authentication credential to list.') api_token = random.choice(creds)['password'] return {'headers': self.request_headers() or {}, 'params': {'api.token': api_token}} def request_headers(self): """ (Override) Set requests headers to send when querying the Phabricator API """ return {'User-Agent': 'Software Heritage phabricator lister', 'Accept': 'application/json'} def get_model_from_repo(self, repo): url = get_repo_url(repo['attachments']['uris']['uris']) if url is None: return None return { 'uid': url, 'indexable': repo['id'], 'name': repo['fields']['shortName'], 'full_name': repo['fields']['name'], 'html_url': url, 'origin_url': url, 'origin_type': repo['fields']['vcs'], 'instance': self.instance, } def get_next_target_from_response(self, response): body = response.json()['result']['cursor'] if body['after'] != 'null': return body['after'] return None def transport_response_simplified(self, response): repos = response.json() if repos['result'] is None: raise ValueError( 'Problem during information fetch: %s' % repos['error_code']) repos = repos['result']['data'] return [self.get_model_from_repo(repo) for repo in repos] def filter_before_inject(self, models_list): """ (Overrides) IndexingLister.filter_before_inject Bounds query results by this Lister's set max_index. """ models_list = [m for m in models_list if m is not None] return super().filter_before_inject(models_list) def _bootstrap_repositories_listing(self): """ Method called when no min_bound value has been provided to the lister. Its purpose is to: 1. get the first repository data hosted on the Phabricator instance 2. inject them into the lister database 3. return the first repository index to start the listing after that value Returns: int: The first repository index """ params = '&order=oldest&limit=1' response = self.safely_issue_request(params) models_list = self.transport_response_simplified(response) self.max_index = models_list[0]['indexable'] models_list = self.filter_before_inject(models_list) injected = self.inject_repo_data_into_db(models_list) self.schedule_missing_tasks(models_list, injected) return self.max_index def get_repo_url(attachments): """ Return url for a hosted repository from its uris attachments according to the following priority lists: * protocol: https > http * identifier: shortname > callsign > id """ processed_urls = defaultdict(dict) for uri in attachments: protocol = uri['fields']['builtin']['protocol'] url = uri['fields']['uri']['effective'] identifier = uri['fields']['builtin']['identifier'] if protocol in ('http', 'https'): processed_urls[protocol][identifier] = url elif protocol is None: for protocol in ('https', 'http'): if url.startswith(protocol): processed_urls[protocol]['undefined'] = url break for protocol in ['https', 'http']: for identifier in ['shortname', 'callsign', 'id', 'undefined']: if (protocol in processed_urls and identifier in processed_urls[protocol]): return processed_urls[protocol][identifier] return None diff --git a/swh/lister/phabricator/tests/test_lister.py b/swh/lister/phabricator/tests/test_lister.py index 78cf006..f52e560 100644 --- a/swh/lister/phabricator/tests/test_lister.py +++ b/swh/lister/phabricator/tests/test_lister.py @@ -1,61 +1,60 @@ # Copyright (C) 2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re import json import unittest from swh.lister.core.tests.test_lister import HttpListerTester from swh.lister.phabricator.lister import PhabricatorLister from swh.lister.phabricator.lister import get_repo_url class PhabricatorListerTester(HttpListerTester, unittest.TestCase): Lister = PhabricatorLister test_re = re.compile(r'\&after=([^?&]+)') lister_subdir = 'phabricator' good_api_response_file = 'api_response.json' good_api_response_undefined_protocol = 'api_response_undefined_'\ 'protocol.json' bad_api_response_file = 'api_empty_response.json' first_index = 1 last_index = 12 entries_per_page = 10 convert_type = int def get_fl(self, override_config=None): """(Override) Retrieve an instance of fake lister (fl). """ if override_config or self.fl is None: credentials = {'phabricator': {'fake': [ {'password': 'toto'} ]}} override_config = dict(credentials=credentials, **(override_config or {})) - self.fl = self.Lister( - api_baseurl='https://fakeurl', instance='fake', - override_config=override_config) + self.fl = self.Lister(url='https://fakeurl', instance='fake', + override_config=override_config) self.fl.INITIAL_BACKOFF = 1 self.fl.reset_backoff() return self.fl def test_get_repo_url(self): f = open('swh/lister/%s/tests/%s' % (self.lister_subdir, self.good_api_response_file)) api_response = json.load(f) repos = api_response['result']['data'] for repo in repos: self.assertEqual( 'https://forge.softwareheritage.org/source/%s.git' % (repo['fields']['shortName']), get_repo_url(repo['attachments']['uris']['uris'])) f = open('swh/lister/%s/tests/%s' % (self.lister_subdir, self.good_api_response_undefined_protocol)) repo = json.load(f) self.assertEqual( 'https://svn.blender.org/svnroot/bf-blender/', get_repo_url(repo['attachments']['uris']['uris'])) diff --git a/swh/lister/tests/test_cli.py b/swh/lister/tests/test_cli.py index a526384..6039ea4 100644 --- a/swh/lister/tests/test_cli.py +++ b/swh/lister/tests/test_cli.py @@ -1,141 +1,141 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import pytest import traceback from datetime import timedelta import yaml from swh.core.utils import numfile_sortkey as sortkey from swh.scheduler import get_scheduler from swh.scheduler.tests.conftest import DUMP_FILES from swh.lister.core.lister_base import ListerBase from swh.lister.cli import lister as cli, get_lister, SUPPORTED_LISTERS from .test_utils import init_db from click.testing import CliRunner @pytest.fixture def swh_scheduler_config(request, postgresql_proc, postgresql): scheduler_config = { 'db': 'postgresql://{user}@{host}:{port}/{dbname}'.format( host=postgresql_proc.host, port=postgresql_proc.port, user='postgres', dbname='tests') } all_dump_files = sorted(glob.glob(DUMP_FILES), key=sortkey) cursor = postgresql.cursor() for fname in all_dump_files: with open(fname) as fobj: cursor.execute(fobj.read()) postgresql.commit() return scheduler_config def test_get_lister_wrong_input(): """Unsupported lister should raise""" with pytest.raises(ValueError) as e: get_lister('unknown', 'db-url') assert "Invalid lister" in str(e.value) def test_get_lister(): """Instantiating a supported lister should be ok """ db_url = init_db().url() for lister_name in SUPPORTED_LISTERS: lst = get_lister(lister_name, db_url) assert isinstance(lst, ListerBase) def test_get_lister_override(): """Overriding the lister configuration should populate its config """ db_url = init_db().url() listers = { - 'gitlab': ('api_baseurl', 'https://gitlab.uni/api/v4/'), - 'phabricator': ( - 'api_baseurl', - 'https://somewhere.org/api/diffusion.repository.search'), + 'gitlab': 'https://other.gitlab.uni/api/v4/', + 'phabricator': 'https://somewhere.org/api/diffusion.repository.search', + 'cgit': 'https://some.where/cgit', } # check the override ends up defined in the lister - for lister_name, (url_key, url_value) in listers.items(): + for lister_name, url in listers.items(): lst = get_lister( lister_name, db_url, **{ - url_key: url_value, + 'url': url, 'priority': 'high', 'policy': 'oneshot', }) - assert getattr(lst, url_key) == url_value + assert lst.url == url assert lst.config['priority'] == 'high' assert lst.config['policy'] == 'oneshot' # check the default urls are used and not the override (since it's not # passed) - for lister_name, (url_key, url_value) in listers.items(): + for lister_name, url in listers.items(): lst = get_lister(lister_name, db_url) # no override so this does not end up in lister's configuration - assert url_key not in lst.config + assert 'url' not in lst.config assert 'priority' not in lst.config assert 'oneshot' not in lst.config + assert lst.url == lst.DEFAULT_URL def test_task_types(swh_scheduler_config, tmp_path): db_url = init_db().url() configfile = tmp_path / 'config.yml' configfile.write_text(yaml.dump({'scheduler': { 'cls': 'local', 'args': swh_scheduler_config}})) runner = CliRunner() result = runner.invoke(cli, [ '--db-url', db_url, '--config-file', configfile.as_posix(), 'register-task-types']) assert result.exit_code == 0, traceback.print_exception(*result.exc_info) scheduler = get_scheduler(cls='local', args=swh_scheduler_config) all_tasks = [ 'list-bitbucket-full', 'list-bitbucket-incremental', 'list-cran', 'list-cgit', 'list-debian-distribution', 'list-gitlab-full', 'list-gitlab-incremental', 'list-github-full', 'list-github-incremental', 'list-gnu-full', 'list-npm-full', 'list-npm-incremental', 'list-phabricator-full', 'list-packagist', 'list-pypi', ] for task in all_tasks: task_type_desc = scheduler.get_task_type(task) assert task_type_desc assert task_type_desc['type'] == task assert task_type_desc['backoff_factor'] == 1 if task == 'list-npm-full': delay = timedelta(days=7) # overloaded in the plugin registry elif task.endswith('-full'): delay = timedelta(days=90) # default value for 'full' lister tasks else: delay = timedelta(days=1) # default value for other lister tasks assert task_type_desc['default_interval'] == delay, task