diff --git a/.gitignore b/.gitignore index 2961003..0fe4dd9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,14 @@ *.pyc *.sw? *~ /.coverage /.coverage.* .eggs/ __pycache__ build/ dist/ *.egg-info version.txt swh/lister/_version.py .tox/ +.mypy_cache/ diff --git a/MANIFEST.in b/MANIFEST.in index 62515f4..50909d7 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,9 +1,10 @@ include Makefile include README include requirements.txt include requirements-swh.txt include requirements-test.txt include version.txt include swh/lister/cran/list_all_packages.R recursive-include swh/lister/*/tests/ *.json *.html *.txt *.* * recursive-include swh/lister/*/tests/data/ *.* * +recursive-include swh py.typed diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..34ca4f4 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,39 @@ +[mypy] +namespace_packages = True +warn_unused_ignores = True + +# support for sqlalchemy magic: see https://github.com/dropbox/sqlalchemy-stubs +plugins = sqlmypy + + +# 3rd party libraries without stubs (yet) + +[mypy-bs4.*] +ignore_missing_imports = True + +[mypy-celery.*] +ignore_missing_imports = True + +[mypy-debian.*] +ignore_missing_imports = True + +[mypy-iso8601.*] +ignore_missing_imports = True + +[mypy-pkg_resources.*] +ignore_missing_imports = True + +[mypy-pytest.*] +ignore_missing_imports = True + +[mypy-requests_mock.*] +ignore_missing_imports = True + +[mypy-testing.postgresql.*] +ignore_missing_imports = True + +[mypy-urllib3.util.*] +ignore_missing_imports = True + +[mypy-xmltodict.*] +ignore_missing_imports = True diff --git a/requirements-test.txt b/requirements-test.txt index 71f2b3d..7cc4593 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,4 +1,5 @@ pytest pytest-postgresql requests_mock testing.postgresql +sqlalchemy-stubs diff --git a/swh/__init__.py b/swh/__init__.py index 69e3be5..f14e196 100644 --- a/swh/__init__.py +++ b/swh/__init__.py @@ -1 +1,4 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) +from pkgutil import extend_path +from typing import Iterable + +__path__ = extend_path(__path__, __name__) # type: Iterable[str] diff --git a/swh/lister/bitbucket/lister.py b/swh/lister/bitbucket/lister.py index 559117c..0c78af0 100644 --- a/swh/lister/bitbucket/lister.py +++ b/swh/lister/bitbucket/lister.py @@ -1,80 +1,81 @@ # Copyright (C) 2017-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import iso8601 from datetime import datetime, timezone +from typing import Any from urllib import parse from swh.lister.bitbucket.models import BitBucketModel from swh.lister.core.indexing_lister import IndexingHttpLister logger = logging.getLogger(__name__) class BitBucketLister(IndexingHttpLister): PATH_TEMPLATE = '/repositories?after=%s' MODEL = BitBucketModel LISTER_NAME = 'bitbucket' DEFAULT_URL = 'https://api.bitbucket.org/2.0' instance = 'bitbucket' - default_min_bound = datetime.fromtimestamp(0, timezone.utc) + default_min_bound = datetime.fromtimestamp(0, timezone.utc) # type: Any def __init__(self, url=None, override_config=None, per_page=100): super().__init__(url=url, override_config=override_config) per_page = self.config.get('per_page', per_page) self.PATH_TEMPLATE = '%s&pagelen=%s' % ( self.PATH_TEMPLATE, per_page) def get_model_from_repo(self, repo): return { 'uid': repo['uuid'], 'indexable': iso8601.parse_date(repo['created_on']), 'name': repo['name'], 'full_name': repo['full_name'], 'html_url': repo['links']['html']['href'], 'origin_url': repo['links']['clone'][0]['href'], 'origin_type': repo['scm'], } def get_next_target_from_response(self, response): """This will read the 'next' link from the api response if any and return it as a datetime. Args: response (Response): requests' response from api call Returns: next date as a datetime """ body = response.json() next_ = body.get('next') if next_ is not None: next_ = parse.urlparse(next_) return iso8601.parse_date(parse.parse_qs(next_.query)['after'][0]) def transport_response_simplified(self, response): repos = response.json()['values'] return [self.get_model_from_repo(repo) for repo in repos] def request_uri(self, identifier): identifier = parse.quote(identifier.isoformat()) return super().request_uri(identifier or '1970-01-01') def is_within_bounds(self, inner, lower=None, upper=None): # values are expected to be datetimes if lower is None and upper is None: ret = True elif lower is None: ret = inner <= upper elif upper is None: ret = inner >= lower else: ret = lower <= inner <= upper return ret diff --git a/swh/lister/bitbucket/tests/test_lister.py b/swh/lister/bitbucket/tests/test_lister.py index ce03b32..9e2378a 100644 --- a/swh/lister/bitbucket/tests/test_lister.py +++ b/swh/lister/bitbucket/tests/test_lister.py @@ -1,102 +1,101 @@ # Copyright (C) 2017-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re import unittest from datetime import timedelta - from urllib.parse import unquote import iso8601 import requests_mock from swh.lister.bitbucket.lister import BitBucketLister from swh.lister.core.tests.test_lister import HttpListerTester -def convert_type(req_index): +def _convert_type(req_index): """Convert the req_index to its right type according to the model's "indexable" column. """ return iso8601.parse_date(unquote(req_index)) class BitBucketListerTester(HttpListerTester, unittest.TestCase): Lister = BitBucketLister test_re = re.compile(r'/repositories\?after=([^?&]+)') lister_subdir = 'bitbucket' good_api_response_file = 'data/https_api.bitbucket.org/response.json' bad_api_response_file = 'data/https_api.bitbucket.org/empty_response.json' - first_index = convert_type('2008-07-12T07:44:01.476818+00:00') - last_index = convert_type('2008-07-19T06:16:43.044743+00:00') + first_index = _convert_type('2008-07-12T07:44:01.476818+00:00') + last_index = _convert_type('2008-07-19T06:16:43.044743+00:00') entries_per_page = 10 - convert_type = staticmethod(convert_type) + convert_type = _convert_type def request_index(self, request): """(Override) This is needed to emulate the listing bootstrap when no min_bound is provided to run """ m = self.test_re.search(request.path_url) - idx = convert_type(m.group(1)) + idx = _convert_type(m.group(1)) if idx == self.Lister.default_min_bound: idx = self.first_index return idx @requests_mock.Mocker() def test_fetch_none_nodb(self, http_mocker): """Overridden because index is not an integer nor a string """ http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() self.disable_scheduler(fl) self.disable_db(fl) # stores no results fl.run(min_bound=self.first_index - timedelta(days=3), max_bound=self.first_index) def test_is_within_bounds(self): fl = self.get_fl() self.assertTrue(fl.is_within_bounds( iso8601.parse_date('2008-07-15'), self.first_index, self.last_index)) self.assertFalse(fl.is_within_bounds( iso8601.parse_date('2008-07-20'), self.first_index, self.last_index)) self.assertFalse(fl.is_within_bounds( iso8601.parse_date('2008-07-11'), self.first_index, self.last_index)) def test_lister_bitbucket(swh_listers, requests_mock_datadir): """Simple bitbucket listing should create scheduled tasks """ lister = swh_listers['bitbucket'] lister.run() r = lister.scheduler.search_tasks(task_type='load-hg') assert len(r) == 10 for row in r: assert row['type'] == 'load-hg' # arguments check args = row['arguments']['args'] assert len(args) == 1 url = args[0] assert url.startswith('https://bitbucket.org') # kwargs kwargs = row['arguments']['kwargs'] assert kwargs == {} assert row['policy'] == 'recurring' assert row['priority'] is None diff --git a/swh/lister/core/abstractattribute.py b/swh/lister/core/abstractattribute.py index 3adabaf..fdb4219 100644 --- a/swh/lister/core/abstractattribute.py +++ b/swh/lister/core/abstractattribute.py @@ -1,26 +1,27 @@ # Copyright (C) 2017 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information class AbstractAttribute: """AbstractAttributes in a base class must be overridden by the subclass. It's like the :func:`abc.abstractmethod` decorator, but for things that are explicitly attributes/properties, not methods, without the need for empty method def boilerplate. Like abc.abstractmethod, the class containing AbstractAttributes must inherit from :class:`abc.ABC` or use the :class:`abc.ABCMeta` metaclass. Usage example:: import abc class ClassContainingAnAbstractAttribute(abc.ABC): - foo = AbstractAttribute('descriptive docstring for foo') + foo: Union[AbstractAttribute, Any] = \ + AbstractAttribute('docstring for foo') """ __isabstractmethod__ = True def __init__(self, docstring=None): if docstring is not None: self.__doc__ = 'AbstractAttribute: ' + docstring diff --git a/swh/lister/core/lister_base.py b/swh/lister/core/lister_base.py index eb4c039..a9f1e02 100644 --- a/swh/lister/core/lister_base.py +++ b/swh/lister/core/lister_base.py @@ -1,518 +1,521 @@ # Copyright (C) 2015-2018 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import datetime import gzip import json import logging import os import re import time from sqlalchemy import create_engine, func from sqlalchemy.orm import sessionmaker +from typing import Any, Type, Union from swh.core import config from swh.scheduler import get_scheduler, utils from .abstractattribute import AbstractAttribute logger = logging.getLogger(__name__) def utcnow(): return datetime.datetime.now(tz=datetime.timezone.utc) class FetchError(RuntimeError): def __init__(self, response): self.response = response def __str__(self): return repr(self.response) class ListerBase(abc.ABC, config.SWHConfig): """Lister core base class. Generally a source code hosting service provides an API endpoint for listing the set of stored repositories. A Lister is the discovery service responsible for finding this list, all at once or sequentially by parts, and queueing local tasks to fetch and ingest the referenced repositories. The core method in this class is ingest_data. Any subclasses should be calling this method one or more times to fetch and ingest data from API endpoints. See swh.lister.core.lister_base.IndexingLister for example usage. This class cannot be instantiated. Any instantiable Lister descending from ListerBase must provide at least the required overrides. (see member docstrings for details): Required Overrides: MODEL def transport_request def transport_response_to_string def transport_response_simplified def transport_quota_check Optional Overrides: def filter_before_inject def is_within_bounds """ - MODEL = AbstractAttribute('Subclass type (not instance)' - ' of swh.lister.core.models.ModelBase' - ' customized for a specific service.') - LISTER_NAME = AbstractAttribute("Lister's name") + MODEL = AbstractAttribute( + 'Subclass type (not instance) of swh.lister.core.models.ModelBase ' + 'customized for a specific service.' + ) # type: Union[AbstractAttribute, Type[Any]] + LISTER_NAME = AbstractAttribute( + "Lister's name") # type: Union[AbstractAttribute, str] def transport_request(self, identifier): """Given a target endpoint identifier to query, try once to request it. Implementation of this method determines the network request protocol. Args: identifier (string): unique identifier for an endpoint query. e.g. If the service indexes lists of repositories by date and time of creation, this might be that as a formatted string. Or it might be an integer UID. Or it might be nothing. It depends on what the service needs. Returns: the entire request response Raises: Will catch internal transport-dependent connection exceptions and raise swh.lister.core.lister_base.FetchError instead. Other non-connection exceptions should propagate unchanged. """ pass def transport_response_to_string(self, response): """Convert the server response into a formatted string for logging. Implementation of this method depends on the shape of the network response object returned by the transport_request method. Args: response: the server response Returns: a pretty string of the response """ pass def transport_response_simplified(self, response): """Convert the server response into list of a dict for each repo in the response, mapping columns in the lister's MODEL class to repo data. Implementation of this method depends on the server API spec and the shape of the network response object returned by the transport_request method. Args: response: response object from the server. Returns: list of repo MODEL dicts ( eg. [{'uid': r['id'], etc.} for r in response.json()] ) """ pass def transport_quota_check(self, response): """Check server response to see if we're hitting request rate limits. Implementation of this method depends on the server communication protocol and API spec and the shape of the network response object returned by the transport_request method. Args: response (session response): complete API query response Returns: 1) must retry request? True/False 2) seconds to delay if True """ pass def filter_before_inject(self, models_list): """Filter models_list entries prior to injection in the db. This is ran directly after `transport_response_simplified`. Default implementation is to have no filtering. Args: models_list: list of dicts returned by transport_response_simplified. Returns: models_list with entries changed according to custom logic. """ return models_list def do_additional_checks(self, models_list): """Execute some additional checks on the model list (after the filtering). Default implementation is to run no check at all and to return the input as is. Args: models_list: list of dicts returned by transport_response_simplified. Returns: models_list with entries if checks ok, False otherwise """ return models_list def is_within_bounds(self, inner, lower=None, upper=None): """See if a sortable value is inside the range [lower,upper]. MAY BE OVERRIDDEN, for example if the server indexable* key is technically sortable but not automatically so. * - ( see: swh.lister.core.indexing_lister.IndexingLister ) Args: inner (sortable type): the value being checked lower (sortable type): optional lower bound upper (sortable type): optional upper bound Returns: whether inner is confined by the optional lower and upper bounds """ try: if lower is None and upper is None: return True elif lower is None: ret = inner <= upper elif upper is None: ret = inner >= lower else: ret = lower <= inner <= upper self.string_pattern_check(inner, lower, upper) except Exception as e: logger.error(str(e) + ': %s, %s, %s' % (('inner=%s%s' % (type(inner), inner)), ('lower=%s%s' % (type(lower), lower)), ('upper=%s%s' % (type(upper), upper))) ) raise return ret # You probably don't need to override anything below this line. DEFAULT_CONFIG = { 'scheduler': ('dict', { 'cls': 'remote', 'args': { 'url': 'http://localhost:5008/' }, }), 'lister': ('dict', { 'cls': 'local', 'args': { 'db': 'postgresql:///lister', }, }), } @property def CONFIG_BASE_FILENAME(self): # noqa: N802 return 'lister_%s' % self.LISTER_NAME @property def ADDITIONAL_CONFIG(self): # noqa: N802 return { 'credentials': ('dict', {}), 'cache_responses': ('bool', False), 'cache_dir': ('str', '~/.cache/swh/lister/%s' % self.LISTER_NAME), } INITIAL_BACKOFF = 10 MAX_RETRIES = 7 CONN_SLEEP = 10 def __init__(self, override_config=None): self.backoff = self.INITIAL_BACKOFF logger.debug('Loading config from %s' % self.CONFIG_BASE_FILENAME) self.config = self.parse_config_file( base_filename=self.CONFIG_BASE_FILENAME, additional_configs=[self.ADDITIONAL_CONFIG] ) self.config['cache_dir'] = os.path.expanduser(self.config['cache_dir']) if self.config['cache_responses']: config.prepare_folders(self.config, 'cache_dir') if override_config: self.config.update(override_config) logger.debug('%s CONFIG=%s' % (self, self.config)) self.scheduler = get_scheduler(**self.config['scheduler']) self.db_engine = create_engine(self.config['lister']['args']['db']) self.mk_session = sessionmaker(bind=self.db_engine) self.db_session = self.mk_session() def reset_backoff(self): """Reset exponential backoff timeout to initial level.""" self.backoff = self.INITIAL_BACKOFF def back_off(self): """Get next exponential backoff timeout.""" ret = self.backoff self.backoff *= 10 return ret def safely_issue_request(self, identifier): """Make network request with retries, rate quotas, and response logs. Protocol is handled by the implementation of the transport_request method. Args: identifier: resource identifier Returns: server response """ retries_left = self.MAX_RETRIES do_cache = self.config['cache_responses'] r = None while retries_left > 0: try: r = self.transport_request(identifier) except FetchError: # network-level connection error, try again logger.warning( 'connection error on %s: sleep for %d seconds' % (identifier, self.CONN_SLEEP)) time.sleep(self.CONN_SLEEP) retries_left -= 1 continue if do_cache: self.save_response(r) # detect throttling must_retry, delay = self.transport_quota_check(r) if must_retry: logger.warning( 'rate limited on %s: sleep for %f seconds' % (identifier, delay)) time.sleep(delay) else: # request ok break retries_left -= 1 if not retries_left: logger.warning( 'giving up on %s: max retries exceeded' % identifier) return r def db_query_equal(self, key, value): """Look in the db for a row with key == value Args: key: column key to look at value: value to look for in that column Returns: sqlalchemy.ext.declarative.declarative_base object with the given key == value """ if isinstance(key, str): key = self.MODEL.__dict__[key] return self.db_session.query(self.MODEL) \ .filter(key == value).first() def winnow_models(self, mlist, key, to_remove): """Given a list of models, remove any with matching some member of a list of values. Args: mlist (list of model rows): the initial list of models key (column): the column to filter on to_remove (list): if anything in mlist has column equal to one of the values in to_remove, it will be removed from the result Returns: A list of model rows starting from mlist minus any matching rows """ if isinstance(key, str): key = self.MODEL.__dict__[key] if to_remove: return mlist.filter(~key.in_(to_remove)).all() else: return mlist.all() def db_num_entries(self): """Return the known number of entries in the lister db""" return self.db_session.query(func.count('*')).select_from(self.MODEL) \ .scalar() def db_inject_repo(self, model_dict): """Add/update a new repo to the db and mark it last_seen now. Args: model_dict: dictionary mapping model keys to values Returns: new or updated sqlalchemy.ext.declarative.declarative_base object associated with the injection """ sql_repo = self.db_query_equal('uid', model_dict['uid']) if not sql_repo: sql_repo = self.MODEL(**model_dict) self.db_session.add(sql_repo) else: for k in model_dict: setattr(sql_repo, k, model_dict[k]) sql_repo.last_seen = utcnow() return sql_repo def task_dict(self, origin_type, origin_url, **kwargs): """Return special dict format for the tasks list Args: origin_type (string) origin_url (string) Returns: the same information in a different form """ logger.debug('origin-url: %s, type: %s', origin_url, origin_type) _type = 'load-%s' % origin_type _policy = kwargs.get('policy', 'recurring') priority = kwargs.get('priority') kw = {'priority': priority} if priority else {} return utils.create_task_dict(_type, _policy, origin_url, **kw) def string_pattern_check(self, a, b, c=None): """When comparing indexable types in is_within_bounds, complex strings may not be allowed to differ in basic structure. If they do, it could be a sign of not understanding the data well. For instance, an ISO 8601 time string cannot be compared against its urlencoded equivalent, but this is an easy mistake to accidentally make. This method acts as a friendly sanity check. Args: a (string): inner component of the is_within_bounds method b (string): lower component of the is_within_bounds method c (string): upper component of the is_within_bounds method Returns: nothing Raises: TypeError if strings a, b, and c don't conform to the same basic pattern. """ if isinstance(a, str): a_pattern = re.sub('[a-zA-Z0-9]', '[a-zA-Z0-9]', re.escape(a)) if (isinstance(b, str) and (re.match(a_pattern, b) is None) or isinstance(c, str) and (re.match(a_pattern, c) is None)): logger.debug(a_pattern) raise TypeError('incomparable string patterns detected') def inject_repo_data_into_db(self, models_list): """Inject data into the db. Args: models_list: list of dicts mapping keys from the db model for each repo to be injected Returns: dict of uid:sql_repo pairs """ injected_repos = {} for m in models_list: injected_repos[m['uid']] = self.db_inject_repo(m) return injected_repos def schedule_missing_tasks(self, models_list, injected_repos): """Find any newly created db entries that do not have been scheduled yet. Args: models_list ([Model]): List of dicts mapping keys in the db model for each repo injected_repos ([dict]): Dict of uid:sql_repo pairs that have just been created Returns: Nothing. Modifies injected_repos. """ tasks = {} def _task_key(m): return '%s-%s' % ( m['type'], json.dumps(m['arguments'], sort_keys=True) ) for m in models_list: ir = injected_repos[m['uid']] if not ir.task_id: # Patching the model instance to add the policy/priority task # scheduling if 'policy' in self.config: m['policy'] = self.config['policy'] if 'priority' in self.config: m['priority'] = self.config['priority'] task_dict = self.task_dict(**m) tasks[_task_key(task_dict)] = (ir, m, task_dict) new_tasks = self.scheduler.create_tasks( (task_dicts for (_, _, task_dicts) in tasks.values())) for task in new_tasks: ir, m, _ = tasks[_task_key(task)] ir.task_id = task['id'] def ingest_data(self, identifier, checks=False): """The core data fetch sequence. Request server endpoint. Simplify and filter response list of repositories. Inject repo information into local db. Queue loader tasks for linked repositories. Args: identifier: Resource identifier. checks (bool): Additional checks required """ # Request (partial?) list of repositories info response = self.safely_issue_request(identifier) if not response: return response, [] models_list = self.transport_response_simplified(response) models_list = self.filter_before_inject(models_list) if checks: models_list = self.do_additional_checks(models_list) if not models_list: return response, [] # inject into local db injected = self.inject_repo_data_into_db(models_list) # queue workers self.schedule_missing_tasks(models_list, injected) return response, injected def save_response(self, response): """Log the response from a server request to a cache dir. Args: response: full server response cache_dir: system path for cache dir Returns: nothing """ datepath = utcnow().isoformat() fname = os.path.join( self.config['cache_dir'], datepath + '.gz', ) with gzip.open(fname, 'w') as f: f.write(bytes( self.transport_response_to_string(response), 'UTF-8' )) diff --git a/swh/lister/core/lister_transports.py b/swh/lister/core/lister_transports.py index 6f814ef..f7a62c4 100644 --- a/swh/lister/core/lister_transports.py +++ b/swh/lister/core/lister_transports.py @@ -1,232 +1,235 @@ # Copyright (C) 2017-2018 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import random from datetime import datetime from email.utils import parsedate from pprint import pformat import logging import requests import xmltodict +from typing import Optional, Union + try: from swh.lister._version import __version__ except ImportError: __version__ = 'devel' from .abstractattribute import AbstractAttribute from .lister_base import FetchError logger = logging.getLogger(__name__) class ListerHttpTransport(abc.ABC): """Use the Requests library for making Lister endpoint requests. To be used in conjunction with ListerBase or a subclass of it. """ - DEFAULT_URL = None - PATH_TEMPLATE = AbstractAttribute('string containing a python string' - ' format pattern that produces the API' - ' endpoint path for listing stored' - ' repositories when given an index.' - ' eg. "/repositories?after=%s".' - 'To be implemented in the API-specific' - ' class inheriting this.') + DEFAULT_URL = None # type: Optional[str] + PATH_TEMPLATE = \ + AbstractAttribute( + 'string containing a python string format pattern that produces' + ' the API endpoint path for listing stored repositories when given' + ' an index, e.g., "/repositories?after=%s". To be implemented in' + ' the API-specific class inheriting this.' + ) # type: Union[AbstractAttribute, Optional[str]] EXPECTED_STATUS_CODES = (200, 429, 403, 404) def request_headers(self): """Returns dictionary of any request headers needed by the server. MAY BE OVERRIDDEN if request headers are needed. """ return { 'User-Agent': 'Software Heritage lister (%s)' % self.lister_version } def request_instance_credentials(self): """Returns dictionary of any credentials configuration needed by the forge instance to list. The 'credentials' configuration is expected to be a dict of multiple levels. The first level is the lister's name, the second is the lister's instance name, which value is expected to be a list of credential structures (typically a couple username/password). For example: credentials: github: # github lister github: # has only one instance (so far) - username: some password: somekey - username: one password: onekey - ... gitlab: # gitlab lister riseup: # has many instances - username: someone password: ... - ... gitlab: - username: someone password: ... - ... Returns: list of credential dicts for the current lister. """ all_creds = self.config.get('credentials') if not all_creds: return [] lister_creds = all_creds.get(self.LISTER_NAME, {}) creds = lister_creds.get(self.instance, []) return creds def request_uri(self, identifier): """Get the full request URI given the transport_request identifier. MAY BE OVERRIDDEN if something more complex than the PATH_TEMPLATE is required. """ path = self.PATH_TEMPLATE % identifier return self.url + path def request_params(self, identifier): """Get the full parameters passed to requests given the transport_request identifier. This uses credentials if any are provided (see request_instance_credentials). MAY BE OVERRIDDEN if something more complex than the request headers is needed. """ params = {} params['headers'] = self.request_headers() or {} creds = self.request_instance_credentials() if not creds: return params auth = random.choice(creds) if creds else None if auth: params['auth'] = (auth['username'], auth['password']) return params def transport_quota_check(self, response): """Implements ListerBase.transport_quota_check with standard 429 code check for HTTP with Requests library. MAY BE OVERRIDDEN if the server notifies about rate limits in a non-standard way that doesn't use HTTP 429 and the Retry-After response header. ( https://tools.ietf.org/html/rfc6585#section-4 ) """ if response.status_code == 429: # HTTP too many requests retry_after = response.headers.get('Retry-After', self.back_off()) try: # might be seconds return True, float(retry_after) except Exception: # might be http-date at_date = datetime(*parsedate(retry_after)[:6]) from_now = (at_date - datetime.today()).total_seconds() + 5 return True, max(0, from_now) else: # response ok self.reset_backoff() return False, 0 def __init__(self, url=None): if not url: url = self.config.get('url') if not url: url = self.DEFAULT_URL if not url: raise NameError('HTTP Lister Transport requires an url.') self.url = url # eg. 'https://api.github.com' self.session = requests.Session() self.lister_version = __version__ def _transport_action(self, identifier, method='get'): """Permit to ask information to the api prior to actually executing query. """ path = self.request_uri(identifier) params = self.request_params(identifier) logger.debug('path: %s', path) logger.debug('params: %s', params) logger.debug('method: %s', method) try: if method == 'head': response = self.session.head(path, **params) else: response = self.session.get(path, **params) except requests.exceptions.ConnectionError as e: logger.warning('Failed to fetch %s: %s', path, e) raise FetchError(e) else: if response.status_code not in self.EXPECTED_STATUS_CODES: raise FetchError(response) return response def transport_head(self, identifier): """Retrieve head information on api. """ return self._transport_action(identifier, method='head') def transport_request(self, identifier): """Implements ListerBase.transport_request for HTTP using Requests. Retrieve get information on api. """ return self._transport_action(identifier) def transport_response_to_string(self, response): """Implements ListerBase.transport_response_to_string for HTTP given Requests responses. """ s = pformat(response.request.path_url) s += '\n#\n' + pformat(response.request.headers) s += '\n#\n' + pformat(response.status_code) s += '\n#\n' + pformat(response.headers) s += '\n#\n' try: # json? s += pformat(response.json()) except Exception: # not json try: # xml? s += pformat(xmltodict.parse(response.text)) except Exception: # not xml s += pformat(response.text) return s class ListerOnePageApiTransport(ListerHttpTransport): """Leverage requests library to retrieve basic html page and parse result. To be used in conjunction with ListerBase or a subclass of it. """ - PAGE = AbstractAttribute("The server api's unique page to retrieve and " - "parse for information") + PAGE = AbstractAttribute( + "URL of the API's unique page to retrieve and parse " + "for information") # type: Union[AbstractAttribute, str] PATH_TEMPLATE = None # we do not use it def __init__(self, url=None): self.session = requests.Session() self.lister_version = __version__ def request_uri(self, _): """Get the full request URI given the transport_request identifier. """ return self.PAGE diff --git a/swh/lister/core/models.py b/swh/lister/core/models.py index 9fbd157..62ab0b7 100644 --- a/swh/lister/core/models.py +++ b/swh/lister/core/models.py @@ -1,72 +1,79 @@ # Copyright (C) 2015-2017 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc from datetime import datetime import logging from sqlalchemy import Column, DateTime, Integer, String from sqlalchemy.ext.declarative import DeclarativeMeta +from typing import Type, Union from .abstractattribute import AbstractAttribute from swh.storage.schemata.distribution import SQLBase logger = logging.getLogger(__name__) class ABCSQLMeta(abc.ABCMeta, DeclarativeMeta): pass class ModelBase(SQLBase, metaclass=ABCSQLMeta): """a common repository""" __abstract__ = True - __tablename__ = AbstractAttribute + __tablename__ = \ + AbstractAttribute # type: Union[Type[AbstractAttribute], str] - uid = AbstractAttribute('Column(, primary_key=True)') + uid = AbstractAttribute( + 'Column(, primary_key=True)' + ) # type: Union[AbstractAttribute, Column] name = Column(String, index=True) full_name = Column(String, index=True) html_url = Column(String) origin_url = Column(String) origin_type = Column(String) last_seen = Column(DateTime, nullable=False) task_id = Column(Integer) def __init__(self, **kw): kw['last_seen'] = datetime.now() super().__init__(**kw) class IndexingModelBase(ModelBase, metaclass=ABCSQLMeta): __abstract__ = True - __tablename__ = AbstractAttribute + __tablename__ = \ + AbstractAttribute # type: Union[Type[AbstractAttribute], str] # The value used for sorting, segmenting, or api query paging, # because uids aren't always sequential. - indexable = AbstractAttribute('Column(, index=True)') + indexable = AbstractAttribute( + 'Column(, index=True)' + ) # type: Union[AbstractAttribute, Column] def initialize(db_engine, drop_tables=False, **kwargs): """Default database initialization function for a lister. Typically called from the lister's initialization hook. Args: models (list): list of SQLAlchemy tables/models to drop/create. db_engine (): the SQLAlchemy DB engine. drop_tables (bool): if True, tables will be dropped before (re)creating them. """ if drop_tables: logger.info('Dropping tables') SQLBase.metadata.drop_all(db_engine, checkfirst=True) logger.info('Creating tables') SQLBase.metadata.create_all(db_engine, checkfirst=True) diff --git a/swh/lister/core/tests/test_abstractattribute.py b/swh/lister/core/tests/test_abstractattribute.py index bfadca6..8190d01 100644 --- a/swh/lister/core/tests/test_abstractattribute.py +++ b/swh/lister/core/tests/test_abstractattribute.py @@ -1,64 +1,66 @@ # Copyright (C) 2017 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import unittest +from typing import Any + from swh.lister.core.abstractattribute import AbstractAttribute class BaseClass(abc.ABC): - v1 = AbstractAttribute - v2 = AbstractAttribute() - v3 = AbstractAttribute('changed docstring') + v1 = AbstractAttribute # type: Any + v2 = AbstractAttribute() # type: Any + v3 = AbstractAttribute('changed docstring') # type: Any v4 = 'qux' class BadSubclass1(BaseClass): pass class BadSubclass2(BaseClass): v1 = 'foo' v2 = 'bar' class BadSubclass3(BaseClass): v2 = 'bar' v3 = 'baz' class GoodSubclass(BaseClass): v1 = 'foo' v2 = 'bar' v3 = 'baz' class TestAbstractAttributes(unittest.TestCase): def test_aa(self): with self.assertRaises(TypeError): BaseClass() with self.assertRaises(TypeError): BadSubclass1() with self.assertRaises(TypeError): BadSubclass2() with self.assertRaises(TypeError): BadSubclass3() self.assertIsInstance(GoodSubclass(), GoodSubclass) gsc = GoodSubclass() self.assertEqual(gsc.v1, 'foo') self.assertEqual(gsc.v2, 'bar') self.assertEqual(gsc.v3, 'baz') self.assertEqual(gsc.v4, 'qux') def test_aa_docstrings(self): self.assertEqual(BaseClass.v1.__doc__, AbstractAttribute.__doc__) self.assertEqual(BaseClass.v2.__doc__, AbstractAttribute.__doc__) self.assertEqual(BaseClass.v3.__doc__, 'AbstractAttribute: changed docstring') diff --git a/swh/lister/core/tests/test_lister.py b/swh/lister/core/tests/test_lister.py index ac5103e..f19bdd2 100644 --- a/swh/lister/core/tests/test_lister.py +++ b/swh/lister/core/tests/test_lister.py @@ -1,412 +1,427 @@ # Copyright (C) 2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import datetime import time from unittest import TestCase from unittest.mock import Mock, patch import requests_mock from sqlalchemy import create_engine +from typing import Any, Callable, Optional, Pattern, Type, Union from swh.lister.core.abstractattribute import AbstractAttribute from swh.lister.tests.test_utils import init_db def noop(*args, **kwargs): pass class HttpListerTesterBase(abc.ABC): """Testing base class for listers. This contains methods for both :class:`HttpSimpleListerTester` and :class:`HttpListerTester`. See :class:`swh.lister.gitlab.tests.test_lister` for an example of how to customize for a specific listing service. """ - Lister = AbstractAttribute('The lister class to test') - lister_subdir = AbstractAttribute('bitbucket, github, etc.') - good_api_response_file = AbstractAttribute('Example good response body') + Lister = AbstractAttribute( + 'Lister class to test') # type: Union[AbstractAttribute, Type[Any]] + lister_subdir = AbstractAttribute( + 'bitbucket, github, etc.') # type: Union[AbstractAttribute, str] + good_api_response_file = AbstractAttribute( + 'Example good response body') # type: Union[AbstractAttribute, str] LISTER_NAME = 'fake-lister' # May need to override this if the headers are used for something def response_headers(self, request): return {} # May need to override this if the server uses non-standard rate limiting # method. # Please keep the requested retry delay reasonably low. def mock_rate_quota(self, n, request, context): self.rate_limit += 1 context.status_code = 429 context.headers['Retry-After'] = '1' return '{"error":"dummy"}' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.rate_limit = 1 self.response = None self.fl = None self.helper = None self.scheduler_tasks = [] if self.__class__ != HttpListerTesterBase: self.run = TestCase.run.__get__(self, self.__class__) else: self.run = noop def mock_limit_n_response(self, n, request, context): self.fl.reset_backoff() if self.rate_limit <= n: return self.mock_rate_quota(n, request, context) else: return self.mock_response(request, context) def mock_limit_twice_response(self, request, context): return self.mock_limit_n_response(2, request, context) def get_api_response(self, identifier): fl = self.get_fl() if self.response is None: self.response = fl.safely_issue_request(identifier) return self.response def get_fl(self, override_config=None): """Retrieve an instance of fake lister (fl). """ if override_config or self.fl is None: self.fl = self.Lister(url='https://fakeurl', override_config=override_config) self.fl.INITIAL_BACKOFF = 1 self.fl.reset_backoff() self.scheduler_tasks = [] return self.fl def disable_scheduler(self, fl): fl.schedule_missing_tasks = Mock(return_value=None) def mock_scheduler(self, fl): def _create_tasks(tasks): task_id = 0 current_nb_tasks = len(self.scheduler_tasks) if current_nb_tasks > 0: task_id = self.scheduler_tasks[-1]['id'] + 1 for task in tasks: scheduler_task = dict(task) scheduler_task.update({ 'status': 'next_run_not_scheduled', 'retries_left': 0, 'priority': None, 'id': task_id, 'current_interval': datetime.timedelta(days=64) }) self.scheduler_tasks.append(scheduler_task) task_id = task_id + 1 return self.scheduler_tasks[current_nb_tasks:] def _disable_tasks(task_ids): for task_id in task_ids: self.scheduler_tasks[task_id]['status'] = 'disabled' fl.scheduler.create_tasks = Mock(wraps=_create_tasks) fl.scheduler.disable_tasks = Mock(wraps=_disable_tasks) def disable_db(self, fl): fl.winnow_models = Mock(return_value=[]) fl.db_inject_repo = Mock(return_value=fl.MODEL()) fl.disable_deleted_repo_tasks = Mock(return_value=None) def init_db(self, db, model): engine = create_engine(db.url()) model.metadata.create_all(engine) @requests_mock.Mocker() def test_is_within_bounds(self, http_mocker): fl = self.get_fl() self.assertFalse(fl.is_within_bounds(1, 2, 3)) self.assertTrue(fl.is_within_bounds(2, 1, 3)) self.assertTrue(fl.is_within_bounds(1, 1, 1)) self.assertTrue(fl.is_within_bounds(1, None, None)) self.assertTrue(fl.is_within_bounds(1, None, 2)) self.assertTrue(fl.is_within_bounds(1, 0, None)) self.assertTrue(fl.is_within_bounds("b", "a", "c")) self.assertFalse(fl.is_within_bounds("a", "b", "c")) self.assertTrue(fl.is_within_bounds("a", None, "c")) self.assertTrue(fl.is_within_bounds("a", None, None)) self.assertTrue(fl.is_within_bounds("b", "a", None)) self.assertFalse(fl.is_within_bounds("a", "b", None)) self.assertTrue(fl.is_within_bounds("aa:02", "aa:01", "aa:03")) self.assertFalse(fl.is_within_bounds("aa:12", None, "aa:03")) with self.assertRaises(TypeError): fl.is_within_bounds(1.0, "b", None) with self.assertRaises(TypeError): fl.is_within_bounds("A:B", "A::B", None) class HttpListerTester(HttpListerTesterBase, abc.ABC): """Base testing class for subclass of :class:`swh.lister.core.indexing_lister.IndexingHttpLister` See :class:`swh.lister.github.tests.test_gh_lister` for an example of how to customize for a specific listing service. """ - last_index = AbstractAttribute('Last index in good_api_response') - first_index = AbstractAttribute('First index in good_api_response') - bad_api_response_file = AbstractAttribute('Example bad response body') - entries_per_page = AbstractAttribute('Number of results in good response') - test_re = AbstractAttribute('Compiled regex matching the server url. Must' - ' capture the index value.') - convert_type = str + last_index = AbstractAttribute( + 'Last index ' + 'in good_api_response') # type: Union[AbstractAttribute, int] + first_index = AbstractAttribute( + 'First index in ' + ' good_api_response') # type: Union[AbstractAttribute, Optional[int]] + bad_api_response_file = AbstractAttribute( + 'Example bad response body') # type: Union[AbstractAttribute, str] + entries_per_page = AbstractAttribute( + 'Number of results in ' + 'good response') # type: Union[AbstractAttribute, int] + test_re = AbstractAttribute( + 'Compiled regex matching the server url. Must capture the ' + 'index value.') # type: Union[AbstractAttribute, Pattern] + convert_type = str # type: Callable[..., Any] """static method used to convert the "request_index" to its right type (for indexing listers for example, this is in accordance with the model's "indexable" column). """ def mock_response(self, request, context): self.fl.reset_backoff() self.rate_limit = 1 context.status_code = 200 custom_headers = self.response_headers(request) context.headers.update(custom_headers) req_index = self.request_index(request) if req_index == self.first_index: response_file = self.good_api_response_file else: response_file = self.bad_api_response_file with open('swh/lister/%s/tests/%s' % (self.lister_subdir, response_file), 'r', encoding='utf-8') as r: return r.read() def request_index(self, request): m = self.test_re.search(request.path_url) if m and (len(m.groups()) > 0): return self.convert_type(m.group(1)) def create_fl_with_db(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) db = init_db() fl = self.get_fl(override_config={ 'lister': { 'cls': 'local', 'args': {'db': db.url()} } }) fl.db = db self.init_db(db, fl.MODEL) self.mock_scheduler(fl) return fl @requests_mock.Mocker() def test_fetch_no_bounds_yesdb(self, http_mocker): fl = self.create_fl_with_db(http_mocker) fl.run() self.assertEqual(fl.db_last_index(), self.last_index) ingested_repos = list(fl.db_query_range(self.first_index, self.last_index)) self.assertEqual(len(ingested_repos), self.entries_per_page) @requests_mock.Mocker() def test_fetch_multiple_pages_yesdb(self, http_mocker): fl = self.create_fl_with_db(http_mocker) fl.run(min_bound=self.first_index) self.assertEqual(fl.db_last_index(), self.last_index) partitions = fl.db_partition_indices(5) self.assertGreater(len(partitions), 0) for k in partitions: self.assertLessEqual(len(k), 5) self.assertGreater(len(k), 0) @requests_mock.Mocker() def test_fetch_none_nodb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() self.disable_scheduler(fl) self.disable_db(fl) fl.run(min_bound=1, max_bound=1) # stores no results # FIXME: Determine what this method tries to test and add checks to # actually test @requests_mock.Mocker() def test_fetch_one_nodb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() self.disable_scheduler(fl) self.disable_db(fl) fl.run(min_bound=self.first_index, max_bound=self.first_index) # FIXME: Determine what this method tries to test and add checks to # actually test @requests_mock.Mocker() def test_fetch_multiple_pages_nodb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() self.disable_scheduler(fl) self.disable_db(fl) fl.run(min_bound=self.first_index) # FIXME: Determine what this method tries to test and add checks to # actually test @requests_mock.Mocker() def test_repos_list(self, http_mocker): """Test the number of repos listed by the lister """ http_mocker.get(self.test_re, text=self.mock_response) li = self.get_fl().transport_response_simplified( self.get_api_response(self.first_index) ) self.assertIsInstance(li, list) self.assertEqual(len(li), self.entries_per_page) @requests_mock.Mocker() def test_model_map(self, http_mocker): """Check if all the keys of model are present in the model created by the `transport_response_simplified` """ http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() li = fl.transport_response_simplified( self.get_api_response(self.first_index)) di = li[0] self.assertIsInstance(di, dict) pubs = [k for k in vars(fl.MODEL).keys() if not k.startswith('_')] for k in pubs: if k not in ['last_seen', 'task_id', 'id']: self.assertIn(k, di) @requests_mock.Mocker() def test_api_request(self, http_mocker): """Test API request for rate limit handling """ http_mocker.get(self.test_re, text=self.mock_limit_twice_response) with patch.object(time, 'sleep', wraps=time.sleep) as sleepmock: self.get_api_response(self.first_index) self.assertEqual(sleepmock.call_count, 2) def scheduled_tasks_test(self, next_api_response_file, next_last_index, http_mocker): """Check that no loading tasks get disabled when processing a new page of repositories returned by a forge API """ fl = self.create_fl_with_db(http_mocker) # process first page of repositories listing fl.run() # process second page of repositories listing prev_last_index = self.last_index self.first_index = self.last_index self.last_index = next_last_index self.good_api_response_file = next_api_response_file fl.run(min_bound=prev_last_index) # check expected number of ingested repos and loading tasks ingested_repos = list(fl.db_query_range(0, self.last_index)) self.assertEqual(len(ingested_repos), len(self.scheduler_tasks)) self.assertEqual(len(ingested_repos), 2 * self.entries_per_page) # check tasks are not disabled for task in self.scheduler_tasks: self.assertTrue(task['status'] != 'disabled') class HttpSimpleListerTester(HttpListerTesterBase, abc.ABC): """Base testing class for subclass of :class:`swh.lister.core.simple)_lister.SimpleLister` See :class:`swh.lister.pypi.tests.test_lister` for an example of how to customize for a specific listing service. """ - entries = AbstractAttribute('Number of results in good response') - PAGE = AbstractAttribute("The server api's unique page to retrieve and " - "parse for information") + entries = AbstractAttribute( + 'Number of results ' + 'in good response') # type: Union[AbstractAttribute, int] + PAGE = AbstractAttribute( + "URL of the server api's unique page to retrieve and " + "parse for information") # type: Union[AbstractAttribute, str] def get_fl(self, override_config=None): """Retrieve an instance of fake lister (fl). """ if override_config or self.fl is None: self.fl = self.Lister( override_config=override_config) self.fl.INITIAL_BACKOFF = 1 self.fl.reset_backoff() return self.fl def mock_response(self, request, context): self.fl.reset_backoff() self.rate_limit = 1 context.status_code = 200 custom_headers = self.response_headers(request) context.headers.update(custom_headers) response_file = self.good_api_response_file with open('swh/lister/%s/tests/%s' % (self.lister_subdir, response_file), 'r', encoding='utf-8') as r: return r.read() @requests_mock.Mocker() def test_api_request(self, http_mocker): """Test API request for rate limit handling """ http_mocker.get(self.PAGE, text=self.mock_limit_twice_response) with patch.object(time, 'sleep', wraps=time.sleep) as sleepmock: self.get_api_response(0) self.assertEqual(sleepmock.call_count, 2) @requests_mock.Mocker() def test_model_map(self, http_mocker): """Check if all the keys of model are present in the model created by the `transport_response_simplified` """ http_mocker.get(self.PAGE, text=self.mock_response) fl = self.get_fl() li = fl.list_packages(self.get_api_response(0)) li = fl.transport_response_simplified(li) di = li[0] self.assertIsInstance(di, dict) pubs = [k for k in vars(fl.MODEL).keys() if not k.startswith('_')] for k in pubs: if k not in ['last_seen', 'task_id', 'id']: self.assertIn(k, di) @requests_mock.Mocker() def test_repos_list(self, http_mocker): """Test the number of packages listed by the lister """ http_mocker.get(self.PAGE, text=self.mock_response) li = self.get_fl().list_packages( self.get_api_response(0) ) self.assertIsInstance(li, list) self.assertEqual(len(li), self.entries) diff --git a/swh/lister/cran/lister.py b/swh/lister/cran/lister.py index 77539f4..3f632e6 100644 --- a/swh/lister/cran/lister.py +++ b/swh/lister/cran/lister.py @@ -1,132 +1,132 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json import logging import pkg_resources import subprocess from typing import List, Mapping from swh.lister.cran.models import CRANModel from swh.lister.core.simple_lister import SimpleLister from swh.scheduler.utils import create_task_dict logger = logging.getLogger(__name__) def read_cran_data() -> List[Mapping[str, str]]: """Execute r script to read cran listing. """ filepath = pkg_resources.resource_filename('swh.lister.cran', 'list_all_packages.R') logger.debug('script list-all-packages.R path: %s', filepath) response = subprocess.run( filepath, stdout=subprocess.PIPE, shell=False, encoding='utf-8') return json.loads(response.stdout) def compute_package_url(repo: Mapping[str, str]) -> str: """Compute the package url from the repo dict. Args: repo: dict with key 'Package', 'Version' Returns: the package url """ return 'https://cran.r-project.org/src/contrib' \ - '/%(Package)s_%(Version)s.tar.gz' % repo + '/%(Package)s_%(Version)s.tar.gz'.format(repo) class CRANLister(SimpleLister): MODEL = CRANModel LISTER_NAME = 'cran' instance = 'cran' def task_dict(self, origin_type, origin_url, **kwargs): """Return task format dict. This creates tasks with args and kwargs set, for example:: args: ['package', 'https://cran.r-project.org/...', 'version'] kwargs: {} """ policy = kwargs.get('policy', 'oneshot') package = kwargs.get('name') version = kwargs.get('version') return create_task_dict( 'load-%s' % origin_type, policy, package, origin_url, version, retries_left=3, ) def safely_issue_request(self, identifier): """Bypass the implementation. It's now the `list_packages` which returns data. As an implementation detail, we cannot change simply the base SimpleLister yet as other implementation still uses it. This shall be part of another refactoring pass. """ return None def list_packages(self, *args) -> List[Mapping[str, str]]: """Runs R script which uses inbuilt API to return a json response containing data about the R packages. Returns: List of Dict about r packages. For example: .. code-block:: python [ { 'Package': 'A3', 'Version': '1.0.0', 'Title': 'Accurate, Adaptable, and Accessible Error Metrics for Predictive\nModels', 'Description': 'Supplies tools for tabulating and analyzing the results of predictive models. The methods employed are ... ' }, { 'Package': 'abbyyR', 'Version': '0.5.4', 'Title': 'Access to Abbyy OCR (OCR) API', 'Description': 'Get text from images of text using Abbyy Cloud Optical Character\n ...' }, ... ] """ return read_cran_data() def get_model_from_repo( self, repo: Mapping[str, str]) -> Mapping[str, str]: """Transform from repository representation to model """ logger.debug('repo: %s', repo) project_url = compute_package_url(repo) package = repo['Package'] return { 'uid': package, 'name': package, 'full_name': repo['Title'], 'version': repo['Version'], 'html_url': project_url, 'origin_url': project_url, 'origin_type': 'tar', } diff --git a/swh/lister/github/lister.py b/swh/lister/github/lister.py index 9d22172..63462c1 100644 --- a/swh/lister/github/lister.py +++ b/swh/lister/github/lister.py @@ -1,66 +1,68 @@ # Copyright (C) 2017-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re import time +from typing import Any + from swh.lister.core.indexing_lister import IndexingHttpLister from swh.lister.github.models import GitHubModel class GitHubLister(IndexingHttpLister): PATH_TEMPLATE = '/repositories?since=%d' MODEL = GitHubModel DEFAULT_URL = 'https://api.github.com' API_URL_INDEX_RE = re.compile(r'^.*/repositories\?since=(\d+)') LISTER_NAME = 'github' instance = 'github' # There is only 1 instance of such lister - default_min_bound = 0 + default_min_bound = 0 # type: Any def get_model_from_repo(self, repo): return { 'uid': repo['id'], 'indexable': repo['id'], 'name': repo['name'], 'full_name': repo['full_name'], 'html_url': repo['html_url'], 'origin_url': repo['html_url'], 'origin_type': 'git', 'fork': repo['fork'], } def transport_quota_check(self, response): x_rate_limit_remaining = response.headers.get('X-RateLimit-Remaining') if not x_rate_limit_remaining: return False, 0 reqs_remaining = int(x_rate_limit_remaining) if response.status_code == 403 and reqs_remaining == 0: reset_at = int(response.headers['X-RateLimit-Reset']) delay = min(reset_at - time.time(), 3600) return True, delay return False, 0 def get_next_target_from_response(self, response): if 'next' in response.links: next_url = response.links['next']['url'] return int(self.API_URL_INDEX_RE.match(next_url).group(1)) def transport_response_simplified(self, response): repos = response.json() return [self.get_model_from_repo(repo) for repo in repos] def request_headers(self): return {'Accept': 'application/vnd.github.v3+json'} def disable_deleted_repo_tasks(self, index, next_index, keep_these): """ (Overrides) Fix provided index value to avoid erroneously disabling some scheduler tasks """ # Next listed repository ids are strictly greater than the 'since' # parameter, so increment the index to avoid disabling the latest # created task when processing a new repositories page returned by # the Github API return super().disable_deleted_repo_tasks(index + 1, next_index, keep_these) diff --git a/swh/lister/phabricator/lister.py b/swh/lister/phabricator/lister.py index 28828c6..fee968b 100644 --- a/swh/lister/phabricator/lister.py +++ b/swh/lister/phabricator/lister.py @@ -1,180 +1,181 @@ # Copyright (C) 2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import random import urllib.parse from collections import defaultdict from sqlalchemy import func from swh.lister.core.indexing_lister import IndexingHttpLister from swh.lister.phabricator.models import PhabricatorModel logger = logging.getLogger(__name__) class PhabricatorLister(IndexingHttpLister): PATH_TEMPLATE = '?order=oldest&attachments[uris]=1&after=%s' - DEFAULT_URL = 'https://forge.softwareheritage.org/api/diffusion.repository.search' # noqa + DEFAULT_URL = \ + 'https://forge.softwareheritage.org/api/diffusion.repository.search' MODEL = PhabricatorModel LISTER_NAME = 'phabricator' def __init__(self, url=None, instance=None, override_config=None): super().__init__(url=url, override_config=override_config) if not instance: instance = urllib.parse.urlparse(self.url).hostname self.instance = instance def request_params(self, identifier): """Override the default params behavior to retrieve the api token Credentials are stored as: credentials: phabricator: : - username: password: """ creds = self.request_instance_credentials() if not creds: raise ValueError( 'Phabricator forge needs authentication credential to list.') api_token = random.choice(creds)['password'] return {'headers': self.request_headers() or {}, 'params': {'api.token': api_token}} def request_headers(self): """ (Override) Set requests headers to send when querying the Phabricator API """ return {'User-Agent': 'Software Heritage phabricator lister', 'Accept': 'application/json'} def get_model_from_repo(self, repo): url = get_repo_url(repo['attachments']['uris']['uris']) if url is None: return None return { 'uid': url, 'indexable': repo['id'], 'name': repo['fields']['shortName'], 'full_name': repo['fields']['name'], 'html_url': url, 'origin_url': url, 'origin_type': repo['fields']['vcs'], 'instance': self.instance, } def get_next_target_from_response(self, response): body = response.json()['result']['cursor'] if body['after'] and body['after'] != 'null': return int(body['after']) def transport_response_simplified(self, response): repos = response.json() if repos['result'] is None: raise ValueError( 'Problem during information fetch: %s' % repos['error_code']) repos = repos['result']['data'] return [self.get_model_from_repo(repo) for repo in repos] def filter_before_inject(self, models_list): """ (Overrides) IndexingLister.filter_before_inject Bounds query results by this Lister's set max_index. """ models_list = [m for m in models_list if m is not None] return super().filter_before_inject(models_list) def disable_deleted_repo_tasks(self, index, next_index, keep_these): """ (Overrides) Fix provided index value to avoid: - database query error - erroneously disabling some scheduler tasks """ # First call to the Phabricator API uses an empty 'after' parameter, # so set the index to 0 to avoid database query error if index == '': index = 0 # Next listed repository ids are strictly greater than the 'after' # parameter, so increment the index to avoid disabling the latest # created task when processing a new repositories page returned by # the Phabricator API else: index = index + 1 return super().disable_deleted_repo_tasks(index, next_index, keep_these) def db_first_index(self): """ (Overrides) Filter results by Phabricator instance Returns: the smallest indexable value of all repos in the db """ t = self.db_session.query(func.min(self.MODEL.indexable)) t = t.filter(self.MODEL.instance == self.instance).first() if t: return t[0] def db_last_index(self): """ (Overrides) Filter results by Phabricator instance Returns: the largest indexable value of all instance repos in the db """ t = self.db_session.query(func.max(self.MODEL.indexable)) t = t.filter(self.MODEL.instance == self.instance).first() if t: return t[0] def db_query_range(self, start, end): """ (Overrides) Filter the results by the Phabricator instance to avoid disabling loading tasks for repositories hosted on a different instance. Returns: a list of sqlalchemy.ext.declarative.declarative_base objects with indexable values within the given range for the instance """ retlist = super().db_query_range(start, end) return retlist.filter(self.MODEL.instance == self.instance) def get_repo_url(attachments): """ Return url for a hosted repository from its uris attachments according to the following priority lists: * protocol: https > http * identifier: shortname > callsign > id """ processed_urls = defaultdict(dict) for uri in attachments: protocol = uri['fields']['builtin']['protocol'] url = uri['fields']['uri']['effective'] identifier = uri['fields']['builtin']['identifier'] if protocol in ('http', 'https'): processed_urls[protocol][identifier] = url elif protocol is None: for protocol in ('https', 'http'): if url.startswith(protocol): processed_urls[protocol]['undefined'] = url break for protocol in ['https', 'http']: for identifier in ['shortname', 'callsign', 'id', 'undefined']: if (protocol in processed_urls and identifier in processed_urls[protocol]): return processed_urls[protocol][identifier] return None diff --git a/swh/lister/py.typed b/swh/lister/py.typed new file mode 100644 index 0000000..1242d43 --- /dev/null +++ b/swh/lister/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. diff --git a/tox.ini b/tox.ini index 2ca724a..731c462 100644 --- a/tox.ini +++ b/tox.ini @@ -1,27 +1,35 @@ [tox] -envlist=flake8,py3 +envlist=flake8,mypy,py3 [testenv:py3] deps = swh.core[http] >= 0.0.61 .[testing] pytest-cov commands = pytest --cov={envsitepackagesdir}/swh/lister/ --cov-branch \ {envsitepackagesdir}/swh/lister/ {posargs} [testenv:py3-dev] deps = swh.core[http] >= 0.0.61 .[testing] pytest-cov ipdb commands = pytest {envsitepackagesdir}/swh/lister/ {posargs} [testenv:flake8] skip_install = true deps = flake8 commands = {envpython} -m flake8 + +[testenv:mypy] +skip_install = true +deps = + .[testing] + mypy +commands = + mypy swh