diff --git a/swh/lister/bitbucket/lister.py b/swh/lister/bitbucket/lister.py
index a480532..5378983 100644
--- a/swh/lister/bitbucket/lister.py
+++ b/swh/lister/bitbucket/lister.py
@@ -1,82 +1,87 @@
 # Copyright (C) 2017-2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import logging
 import iso8601

 from datetime import datetime, timezone
-from typing import Any
+from typing import Any, Dict, List, Optional, Union
 from urllib import parse
-
+from requests import Response

 from swh.lister.bitbucket.models import BitBucketModel
 from swh.lister.core.indexing_lister import IndexingHttpLister

 logger = logging.getLogger(__name__)


 class BitBucketLister(IndexingHttpLister):
     PATH_TEMPLATE = '/repositories?after=%s'
     MODEL = BitBucketModel
     LISTER_NAME = 'bitbucket'
     DEFAULT_URL = 'https://api.bitbucket.org/2.0'
     instance = 'bitbucket'
     default_min_bound = datetime.fromtimestamp(0, timezone.utc)  # type: Any

-    def __init__(self, url=None, override_config=None, per_page=100):
+    def __init__(self, url: Optional[str] = None,
+                 override_config=None, per_page: int = 100) -> None:
         super().__init__(url=url, override_config=override_config)
         per_page = self.config.get('per_page', per_page)

         self.PATH_TEMPLATE = '%s&pagelen=%s' % (
             self.PATH_TEMPLATE, per_page)

-    def get_model_from_repo(self, repo):
+    def get_model_from_repo(self, repo: Dict) -> Dict[str, Any]:
         return {
             'uid': repo['uuid'],
             'indexable': iso8601.parse_date(repo['created_on']),
             'name': repo['name'],
             'full_name': repo['full_name'],
             'html_url': repo['links']['html']['href'],
             'origin_url': repo['links']['clone'][0]['href'],
             'origin_type': repo['scm'],
         }

-    def get_next_target_from_response(self, response):
+    def get_next_target_from_response(self, response: Response
+                                      ) -> Union[None, datetime]:
         """This will read the 'next' link from the api response if any
         and return it as a datetime.

         Args:
             response (Response): requests' response from api call

         Returns:
             next date as a datetime

         """
         body = response.json()
         next_ = body.get('next')
         if next_ is not None:
             next_ = parse.urlparse(next_)
             return iso8601.parse_date(
                 parse.parse_qs(next_.query)['after'][0])
+        return None

-    def transport_response_simplified(self, response):
+    def transport_response_simplified(self, response: Response
+                                      ) -> List[Dict[str, Any]]:
         repos = response.json()['values']
         return [self.get_model_from_repo(repo) for repo in repos]

-    def request_uri(self, identifier):
-        identifier = parse.quote(identifier.isoformat())
-        return super().request_uri(identifier or '1970-01-01')
+    def request_uri(self, identifier: datetime) -> str:
+        identifier_str = parse.quote(identifier.isoformat())
+        return super().request_uri(identifier_str or '1970-01-01')

-    def is_within_bounds(self, inner, lower=None, upper=None):
+    def is_within_bounds(self, inner: int, lower: Optional[int] = None,
+                         upper: Optional[int] = None) -> bool:
         # values are expected to be datetimes
         if lower is None and upper is None:
             ret = True
         elif lower is None:
-            ret = inner <= upper
+            ret = inner <= upper  # type: ignore
         elif upper is None:
             ret = inner >= lower
         else:
             ret = lower <= inner <= upper
         return ret
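[Review note] The typed get_next_target_from_response() above recovers the
next page index by decoding the 'after' query parameter of the API's 'next'
link. A minimal sketch of that round-trip outside the lister (the sample URL
is illustrative, not a real API response):

    import iso8601
    from urllib import parse

    # Shape of a 'next' link returned by the /2.0/repositories endpoint.
    next_url = ('https://api.bitbucket.org/2.0/repositories'
                '?pagelen=100&after=2017-01-01T00%3A00%3A00%2B00%3A00')

    query = parse.urlparse(next_url).query
    after = parse.parse_qs(query)['after'][0]  # '2017-01-01T00:00:00+00:00'
    next_index = iso8601.parse_date(after)     # timezone-aware datetime

    # request_uri() then quotes the datetime back into the path, matching
    # PATH_TEMPLATE = '/repositories?after=%s' plus the pagelen suffix.
    assert parse.quote(next_index.isoformat()).startswith('2017-01-01T00')
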
diff --git a/swh/lister/cgit/lister.py b/swh/lister/cgit/lister.py
index d770cbd..1f5545c 100644
--- a/swh/lister/cgit/lister.py
+++ b/swh/lister/cgit/lister.py
@@ -1,148 +1,150 @@
 # Copyright (C) 2019 the Software Heritage developers
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import re
 import logging

 from urllib.parse import urlparse, urljoin

 from bs4 import BeautifulSoup
 from requests import Session
+# from requests.structures import CaseInsensitiveDict
 from requests.adapters import HTTPAdapter
-
+from typing import Any, Dict, Generator, Union

 from .models import CGitModel

 from swh.core.utils import grouper
 from swh.lister import USER_AGENT
 from swh.lister.core.lister_base import ListerBase

 logger = logging.getLogger(__name__)


 class CGitLister(ListerBase):
     """Lister class for CGit repositories.

     This lister will retrieve the list of published git repositories by
     parsing the HTML page(s) of the index retrieved at `url`.

     For each found git repository, a query is made at the given url found
     in this index to gather published "Clone" URLs to be used as origin
     URL for that git repo.

     If several "Clone" urls are provided, prefer the http/https one, if
     any, otherwise fall back to the first one.

     A loader task is created for each git repository::

         Task:
             Type: load-git
             Policy: recurring
             Args:
                 <git_clonable_url>

     Example::

         Task:
             Type: load-git
             Policy: recurring
             Args:
                 'https://git.savannah.gnu.org/git/elisp-es.git'

     """
     MODEL = CGitModel
     DEFAULT_URL = 'https://git.savannah.gnu.org/cgit/'
     LISTER_NAME = 'cgit'
     url_prefix_present = True

-    def __init__(self, url=None, instance=None, override_config=None):
+    def __init__(self, url=None, instance=None,
+                 override_config=None):
         """Lister class for CGit repositories.

         Args:
-            url (str): main URL of the CGit instance, i.e. url of the index
+            url : main URL of the CGit instance, i.e. url of the index
                 of published git repositories on this instance.
-            instance (str): Name of cgit instance. Defaults to url's hostname
+            instance : Name of cgit instance. Defaults to url's hostname
                 if unset.
""" super().__init__(override_config=override_config) if url is None: url = self.config.get('url', self.DEFAULT_URL) self.url = url if not instance: instance = urlparse(url).hostname self.instance = instance self.session = Session() self.session.mount(self.url, HTTPAdapter(max_retries=3)) self.session.headers = { 'User-Agent': USER_AGENT, } - def run(self): + def run(self) -> Dict[str, str]: status = 'uneventful' total = 0 for repos in grouper(self.get_repos(), 10): models = list(filter(None, (self.build_model(repo) for repo in repos))) injected_repos = self.inject_repo_data_into_db(models) self.schedule_missing_tasks(models, injected_repos) self.db_session.commit() total += len(injected_repos) logger.debug('Scheduled %s tasks for %s', total, self.url) status = 'eventful' return {'status': status} - def get_repos(self): + def get_repos(self) -> Generator: """Generate git 'project' URLs found on the current CGit server """ next_page = self.url while next_page: bs_idx = self.get_and_parse(next_page) for tr in bs_idx.find( 'div', {"class": "content"}).find_all( "tr", {"class": ""}): yield urljoin(self.url, tr.find('a')['href']) try: pager = bs_idx.find('ul', {'class': 'pager'}) current_page = pager.find('a', {'class': 'current'}) if current_page: next_page = current_page.parent.next_sibling.a['href'] next_page = urljoin(self.url, next_page) except (AttributeError, KeyError): # no pager, or no next page next_page = None - def build_model(self, repo_url): + def build_model(self, repo_url: str) -> Union[None, Dict[str, Any]]: """Given the URL of a git repo project page on a CGit server, return the repo description (dict) suitable for insertion in the db. """ bs = self.get_and_parse(repo_url) urls = [x['href'] for x in bs.find_all('a', {'rel': 'vcs-git'})] if not urls: - return + return None # look for the http/https url, if any, and use it as origin_url for url in urls: if urlparse(url).scheme in ('http', 'https'): origin_url = url break else: # otherwise, choose the first one origin_url = urls[0] return {'uid': repo_url, 'name': bs.find('a', title=re.compile('.+'))['title'], 'origin_type': 'git', 'instance': self.instance, 'origin_url': origin_url, } - def get_and_parse(self, url): + def get_and_parse(self, url: str) -> BeautifulSoup: "Get the given url and parse the retrieved HTML using BeautifulSoup" return BeautifulSoup(self.session.get(url).text, features='html.parser') diff --git a/swh/lister/core/indexing_lister.py b/swh/lister/core/indexing_lister.py index 2e7f300..8611e7f 100644 --- a/swh/lister/core/indexing_lister.py +++ b/swh/lister/core/indexing_lister.py @@ -1,259 +1,269 @@ # Copyright (C) 2015-2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import logging from itertools import count import dateutil from sqlalchemy import func from .lister_transports import ListerHttpTransport from .lister_base import ListerBase +from requests import Response +from typing import Any, Dict, List, Tuple, Optional + logger = logging.getLogger(__name__) class IndexingLister(ListerBase): """Lister* intermediate class for any service that follows the pattern: - The service must report at least one stable unique identifier, known herein as the UID value, for every listed repository. 
- If the service splits the list of repositories into sublists, it must report at least one stable and sorted index identifier for every listed repository, known herein as the indexable value, which can be used as part of the service endpoint query to request a sublist beginning from that index. This might be the UID if the UID is monotonic. - Client sends a request to list repositories starting from a given index. - Client receives structured (json/xml/etc) response with information about a sequential series of repositories starting from that index and, if necessary/available, some indication of the URL or index for fetching the next series of repository data. See :class:`swh.lister.core.lister_base.ListerBase` for more details. This class cannot be instantiated. To create a new Lister for a source code listing service that follows the model described above, you must subclass this class and provide the required overrides in addition to any unmet implementation/override requirements of this class's base. (see parent class and member docstrings for details) Required Overrides:: def get_next_target_from_response """ flush_packet_db = 20 """Number of iterations in-between write flushes of lister repositories to db (see fn:`run`). """ default_min_bound = '' """Default initialization value for the minimum boundary index to use when undefined (see fn:`run`). """ @abc.abstractmethod - def get_next_target_from_response(self, response): + def get_next_target_from_response(self, response: Response): """Find the next server endpoint identifier given the entire response. Implementation of this method depends on the server API spec and the shape of the network response object returned by the transport_request method. Args: response (transport response): response page from the server Returns: index of next page, possibly extracted from a next href url """ pass # You probably don't need to override anything below this line. - def filter_before_inject(self, models_list): + def filter_before_inject( + self, models_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Overrides ListerBase.filter_before_inject Bounds query results by this Lister's set max_index. """ models_list = [ m for m in models_list if self.is_within_bounds(m['indexable'], None, self.max_index) ] return models_list def db_query_range(self, start, end): """Look in the db for a range of repositories with indexable values in the range [start, end] Args: start (model indexable type): start of desired indexable range end (model indexable type): end of desired indexable range Returns: a list of sqlalchemy.ext.declarative.declarative_base objects with indexable values within the given range """ retlist = self.db_session.query(self.MODEL) if start is not None: retlist = retlist.filter(self.MODEL.indexable >= start) if end is not None: retlist = retlist.filter(self.MODEL.indexable <= end) return retlist - def db_partition_indices(self, partition_size): + def db_partition_indices( + self, partition_size: int + ) -> List[Tuple[Optional[int], Optional[int]]]: """Describe an index-space compartmentalization of the db table in equal sized chunks. This is used to describe min&max bounds for parallelizing fetch tasks. 
        Args:
            partition_size (int): desired size to make each partition

        Returns:
            a list of tuples (begin, end) of indexable value that
                declare approximately equal-sized ranges of existing
                repos
        """
        n = max(self.db_num_entries(), 10)
        partition_size = min(partition_size, n)
        n_partitions = n // partition_size

        min_index = self.db_first_index()
        max_index = self.db_last_index()

        if min_index is None or max_index is None:
            # Nothing to list
            return []

        if isinstance(min_index, str):
            def format_bound(bound):
                return bound.isoformat()
            min_index = dateutil.parser.parse(min_index)
            max_index = dateutil.parser.parse(max_index)
        elif isinstance(max_index - min_index, int):
            def format_bound(bound):
                return int(bound)
        else:
            def format_bound(bound):
                return bound

        partition_width = (max_index - min_index) / n_partitions

        # Generate n_partitions + 1 bounds for n_partitions partitions
        bounds = [
            format_bound(min_index + i * partition_width)
            for i in range(n_partitions + 1)
        ]

        # Trim duplicate bounds
        bounds.append(None)
        bounds = [cur
                  for cur, next in zip(bounds[:-1], bounds[1:])
                  if cur != next]

        # Remove bounds for lowest and highest partition
        bounds[0] = bounds[-1] = None

        return list(zip(bounds[:-1], bounds[1:]))

    def db_first_index(self):
        """Look in the db for the smallest indexable value

        Returns:
            the smallest indexable value of all repos in the db
        """
        t = self.db_session.query(func.min(self.MODEL.indexable)).first()
        if t:
            return t[0]
+        return None

    def db_last_index(self):
        """Look in the db for the largest indexable value

        Returns:
            the largest indexable value of all repos in the db
        """
        t = self.db_session.query(func.max(self.MODEL.indexable)).first()
        if t:
            return t[0]
+        return None

-    def disable_deleted_repo_tasks(self, start, end, keep_these):
+    def disable_deleted_repo_tasks(
+            self, start, end, keep_these):
        """Disable tasks for repos that no longer exist between start and
        end.

        Args:
            start: beginning of range to disable
            end: end of range to disable
            keep_these (uid list): do not disable repos with uids in this
                list
        """
        if end is None:
            end = self.db_last_index()

        if not self.is_within_bounds(end, None, self.max_index):
            end = self.max_index

        deleted_repos = self.winnow_models(
            self.db_query_range(start, end), self.MODEL.uid, keep_these
        )
        tasks_to_disable = [repo.task_id for repo in deleted_repos
                            if repo.task_id is not None]
        if tasks_to_disable:
            self.scheduler.disable_tasks(tasks_to_disable)
        for repo in deleted_repos:
            repo.task_id = None

    def run(self, min_bound=None, max_bound=None):
        """Main entry function. Sequentially fetches repository data from the
        service according to the basic outline in the class docstring,
        continually fetching sublists until either there is no next index
        reference given or the given next index is greater than the desired
        max_bound.

        Args:
            min_bound (indexable type): optional index to start from
            max_bound (indexable type): optional index to stop at

        Returns:
            nothing
        """
        status = 'uneventful'

        self.min_index = min_bound
        self.max_index = max_bound

        def ingest_indexes():
            index = min_bound or self.default_min_bound
            for i in count(1):
                response, injected_repos = self.ingest_data(index)
                if not response and not injected_repos:
                    logger.info('No response from api server, stopping')
                    return

                next_index = self.get_next_target_from_response(response)

                # Determine if any repos were deleted, and disable their tasks.
keep_these = list(injected_repos.keys()) self.disable_deleted_repo_tasks(index, next_index, keep_these) # termination condition if next_index is None or next_index == index: logger.info('stopping after index %s, no next link found', index) return index = next_index logger.debug('Index: %s', index) yield i for i in ingest_indexes(): if (i % self.flush_packet_db) == 0: logger.debug('Flushing updates at index %s', i) self.db_session.commit() self.db_session = self.mk_session() status = 'eventful' self.db_session.commit() self.db_session = self.mk_session() return {'status': status} class IndexingHttpLister(ListerHttpTransport, IndexingLister): """Convenience class for ensuring right lookup and init order when combining IndexingLister and ListerHttpTransport.""" + def __init__(self, url=None, override_config=None): IndexingLister.__init__(self, override_config=override_config) ListerHttpTransport.__init__(self, url=url) diff --git a/swh/lister/core/lister_base.py b/swh/lister/core/lister_base.py index a92aa93..5c02642 100644 --- a/swh/lister/core/lister_base.py +++ b/swh/lister/core/lister_base.py @@ -1,529 +1,535 @@ # Copyright (C) 2015-2020 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import datetime import gzip import json import logging import os import re import time from sqlalchemy import create_engine, func from sqlalchemy.orm import sessionmaker -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, List, Type, Union, Optional from swh.core import config from swh.core.utils import grouper from swh.scheduler import get_scheduler, utils from .abstractattribute import AbstractAttribute +from requests import Response logger = logging.getLogger(__name__) def utcnow(): return datetime.datetime.now(tz=datetime.timezone.utc) class FetchError(RuntimeError): def __init__(self, response): self.response = response def __str__(self): return repr(self.response) class ListerBase(abc.ABC, config.SWHConfig): """Lister core base class. Generally a source code hosting service provides an API endpoint for listing the set of stored repositories. A Lister is the discovery service responsible for finding this list, all at once or sequentially by parts, and queueing local tasks to fetch and ingest the referenced repositories. The core method in this class is ingest_data. Any subclasses should be calling this method one or more times to fetch and ingest data from API endpoints. See swh.lister.core.lister_base.IndexingLister for example usage. This class cannot be instantiated. Any instantiable Lister descending from ListerBase must provide at least the required overrides. (see member docstrings for details): Required Overrides: MODEL def transport_request def transport_response_to_string def transport_response_simplified def transport_quota_check Optional Overrides: def filter_before_inject def is_within_bounds """ MODEL = AbstractAttribute( 'Subclass type (not instance) of swh.lister.core.models.ModelBase ' 'customized for a specific service.' ) # type: Union[AbstractAttribute, Type[Any]] LISTER_NAME = AbstractAttribute( "Lister's name") # type: Union[AbstractAttribute, str] def transport_request(self, identifier): """Given a target endpoint identifier to query, try once to request it. Implementation of this method determines the network request protocol. Args: identifier (string): unique identifier for an endpoint query. e.g. 
               If the service indexes lists of repositories by date and
               time of creation, this might be that as a formatted string.
               Or it might be an integer UID. Or it might be nothing. It
               depends on what the service needs.

        Returns:
            the entire request response

        Raises:
            Will catch internal transport-dependent connection exceptions and
            raise swh.lister.core.lister_base.FetchError instead. Other
            non-connection exceptions should propagate unchanged.
        """
        pass

    def transport_response_to_string(self, response):
        """Convert the server response into a formatted string for logging.

        Implementation of this method depends on the shape of the network
        response object returned by the transport_request method.

        Args:
            response: the server response

        Returns:
            a pretty string of the response
        """
        pass

    def transport_response_simplified(self, response):
        """Convert the server response into list of a dict for each repo in
        the response, mapping columns in the lister's MODEL class to repo
        data.

        Implementation of this method depends on the server API spec and the
        shape of the network response object returned by the
        transport_request method.

        Args:
            response: response object from the server.

        Returns:
            list of repo MODEL dicts
                ( eg. [{'uid': r['id'], etc.} for r in response.json()] )
        """
        pass

    def transport_quota_check(self, response):
        """Check server response to see if we're hitting request rate limits.

        Implementation of this method depends on the server communication
        protocol and API spec and the shape of the network response object
        returned by the transport_request method.

        Args:
            response (session response): complete API query response

        Returns:
            1) must retry request? True/False
            2) seconds to delay if True
        """
        pass

-    def filter_before_inject(self, models_list):
+    def filter_before_inject(
+            self, models_list: List[Dict]) -> List[Dict]:
        """Filter models_list entries prior to injection in the db.
        This is run directly after `transport_response_simplified`.

        Default implementation is to have no filtering.

        Args:
            models_list: list of dicts returned by
                transport_response_simplified.

        Returns:
            models_list with entries changed according to custom logic.
        """
        return models_list

-    def do_additional_checks(self, models_list):
+    def do_additional_checks(
+            self, models_list: List[Dict]) -> List[Dict]:
        """Execute some additional checks on the model list (after the
        filtering).

        Default implementation is to run no check at all and to return the
        input as is.

        Args:
            models_list: list of dicts returned by
                transport_response_simplified.

        Returns:
            models_list with entries if checks ok, False otherwise

        """
        return models_list

-    def is_within_bounds(self, inner, lower=None, upper=None):
+    def is_within_bounds(
+            self, inner: int,
+            lower: Optional[int] = None, upper: Optional[int] = None) -> bool:
        """See if a sortable value is inside the range [lower,upper].

        MAY BE OVERRIDDEN, for example if the server indexable* key is
        technically sortable but not automatically so.
* - ( see: swh.lister.core.indexing_lister.IndexingLister ) Args: inner (sortable type): the value being checked lower (sortable type): optional lower bound upper (sortable type): optional upper bound Returns: whether inner is confined by the optional lower and upper bounds """ try: if lower is None and upper is None: return True elif lower is None: - ret = inner <= upper + ret = inner <= upper # type: ignore elif upper is None: ret = inner >= lower else: ret = lower <= inner <= upper self.string_pattern_check(inner, lower, upper) except Exception as e: logger.error(str(e) + ': %s, %s, %s' % (('inner=%s%s' % (type(inner), inner)), ('lower=%s%s' % (type(lower), lower)), ('upper=%s%s' % (type(upper), upper))) ) raise return ret # You probably don't need to override anything below this line. DEFAULT_CONFIG = { 'scheduler': ('dict', { 'cls': 'remote', 'args': { 'url': 'http://localhost:5008/' }, }), 'lister': ('dict', { 'cls': 'local', 'args': { 'db': 'postgresql:///lister', }, }), } @property def CONFIG_BASE_FILENAME(self): # noqa: N802 return 'lister_%s' % self.LISTER_NAME @property def ADDITIONAL_CONFIG(self): # noqa: N802 return { 'credentials': ('dict', {}), 'cache_responses': ('bool', False), 'cache_dir': ('str', '~/.cache/swh/lister/%s' % self.LISTER_NAME), } INITIAL_BACKOFF = 10 MAX_RETRIES = 7 CONN_SLEEP = 10 def __init__(self, override_config=None): self.backoff = self.INITIAL_BACKOFF logger.debug('Loading config from %s' % self.CONFIG_BASE_FILENAME) self.config = self.parse_config_file( base_filename=self.CONFIG_BASE_FILENAME, additional_configs=[self.ADDITIONAL_CONFIG] ) self.config['cache_dir'] = os.path.expanduser(self.config['cache_dir']) if self.config['cache_responses']: config.prepare_folders(self.config, 'cache_dir') if override_config: self.config.update(override_config) logger.debug('%s CONFIG=%s' % (self, self.config)) self.scheduler = get_scheduler(**self.config['scheduler']) self.db_engine = create_engine(self.config['lister']['args']['db']) self.mk_session = sessionmaker(bind=self.db_engine) self.db_session = self.mk_session() def reset_backoff(self): """Reset exponential backoff timeout to initial level.""" self.backoff = self.INITIAL_BACKOFF - def back_off(self): + def back_off(self) -> int: """Get next exponential backoff timeout.""" ret = self.backoff self.backoff *= 10 return ret - def safely_issue_request(self, identifier): + def safely_issue_request(self, identifier: int) -> Optional[Response]: """Make network request with retries, rate quotas, and response logs. Protocol is handled by the implementation of the transport_request method. 
Args: identifier: resource identifier Returns: server response """ retries_left = self.MAX_RETRIES do_cache = self.config['cache_responses'] r = None while retries_left > 0: try: r = self.transport_request(identifier) except FetchError: # network-level connection error, try again logger.warning( 'connection error on %s: sleep for %d seconds' % (identifier, self.CONN_SLEEP)) time.sleep(self.CONN_SLEEP) retries_left -= 1 continue if do_cache: self.save_response(r) # detect throttling must_retry, delay = self.transport_quota_check(r) if must_retry: logger.warning( 'rate limited on %s: sleep for %f seconds' % (identifier, delay)) time.sleep(delay) else: # request ok break retries_left -= 1 if not retries_left: logger.warning( 'giving up on %s: max retries exceeded' % identifier) return r - def db_query_equal(self, key, value): + def db_query_equal(self, key: Any, value: Any): """Look in the db for a row with key == value Args: key: column key to look at value: value to look for in that column Returns: sqlalchemy.ext.declarative.declarative_base object with the given key == value """ if isinstance(key, str): key = self.MODEL.__dict__[key] return self.db_session.query(self.MODEL) \ .filter(key == value).first() def winnow_models(self, mlist, key, to_remove): """Given a list of models, remove any with matching some member of a list of values. Args: mlist (list of model rows): the initial list of models key (column): the column to filter on to_remove (list): if anything in mlist has column equal to one of the values in to_remove, it will be removed from the result Returns: A list of model rows starting from mlist minus any matching rows """ if isinstance(key, str): key = self.MODEL.__dict__[key] if to_remove: return mlist.filter(~key.in_(to_remove)).all() else: return mlist.all() def db_num_entries(self): """Return the known number of entries in the lister db""" return self.db_session.query(func.count('*')).select_from(self.MODEL) \ .scalar() def db_inject_repo(self, model_dict): """Add/update a new repo to the db and mark it last_seen now. Args: model_dict: dictionary mapping model keys to values Returns: new or updated sqlalchemy.ext.declarative.declarative_base object associated with the injection """ sql_repo = self.db_query_equal('uid', model_dict['uid']) if not sql_repo: sql_repo = self.MODEL(**model_dict) self.db_session.add(sql_repo) else: for k in model_dict: setattr(sql_repo, k, model_dict[k]) sql_repo.last_seen = utcnow() return sql_repo def task_dict(self, origin_type: str, origin_url: str, **kwargs) -> Dict[str, Any]: """Return special dict format for the tasks list Args: origin_type (string) origin_url (string) Returns: the same information in a different form """ logger.debug('origin-url: %s, type: %s', origin_url, origin_type) _type = 'load-%s' % origin_type _policy = kwargs.get('policy', 'recurring') priority = kwargs.get('priority') kw = {'priority': priority} if priority else {} return utils.create_task_dict(_type, _policy, url=origin_url, **kw) def string_pattern_check(self, a, b, c=None): """When comparing indexable types in is_within_bounds, complex strings may not be allowed to differ in basic structure. If they do, it could be a sign of not understanding the data well. For instance, an ISO 8601 time string cannot be compared against its urlencoded equivalent, but this is an easy mistake to accidentally make. This method acts as a friendly sanity check. 
        Args:
            a (string): inner component of the is_within_bounds method
            b (string): lower component of the is_within_bounds method
            c (string): upper component of the is_within_bounds method

        Returns:
            nothing

        Raises:
            TypeError if strings a, b, and c don't conform to the same basic
            pattern.
        """
        if isinstance(a, str):
            a_pattern = re.sub('[a-zA-Z0-9]', '[a-zA-Z0-9]', re.escape(a))
            if (isinstance(b, str) and (re.match(a_pattern, b) is None)
-                or isinstance(c, str) and (re.match(a_pattern, c) is None)):
+                or isinstance(c, str) and
+                    (re.match(a_pattern, c) is None)):
                logger.debug(a_pattern)
                raise TypeError('incomparable string patterns detected')

    def inject_repo_data_into_db(self, models_list: List[Dict]) -> Dict:
        """Inject data into the db.

        Args:
            models_list: list of dicts mapping keys from the db model for
                each repo to be injected

        Returns:
            dict of uid:sql_repo pairs
        """
        injected_repos = {}
        for m in models_list:
            injected_repos[m['uid']] = self.db_inject_repo(m)
        return injected_repos

    def schedule_missing_tasks(
            self, models_list: List[Dict], injected_repos: Dict) -> None:
        """Schedule any newly created db entries that have not been
           scheduled yet.

        Args:
            models_list: List of dicts mapping keys in the db model for each
                repo
            injected_repos: Dict of uid:sql_repo pairs that have just been
                created

        Returns:
            Nothing. (Note that it Modifies injected_repos to set the new
            task_id).
        """
        tasks = {}

        def _task_key(m):
            return '%s-%s' % (
                m['type'], json.dumps(m['arguments'], sort_keys=True)
            )

        for m in models_list:
            ir = injected_repos[m['uid']]
            if not ir.task_id:
                # Patching the model instance to add the policy/priority task
                # scheduling
                if 'policy' in self.config:
                    m['policy'] = self.config['policy']
                if 'priority' in self.config:
                    m['priority'] = self.config['priority']
                task_dict = self.task_dict(**m)
                tasks[_task_key(task_dict)] = (ir, m, task_dict)

        gen_tasks = (task_dicts for (_, _, task_dicts) in tasks.values())
        for grouped_tasks in grouper(gen_tasks, n=1000):
            new_tasks = self.scheduler.create_tasks(list(grouped_tasks))
            for task in new_tasks:
                ir, m, _ = tasks[_task_key(task)]
                ir.task_id = task['id']

-    def ingest_data(self, identifier, checks=False):
+    def ingest_data(self, identifier: int, checks: bool = False):
        """The core data fetch sequence. Request server endpoint. Simplify and
        filter response list of repositories. Inject repo information into
        local db. Queue loader tasks for linked repositories.

        Args:
            identifier: Resource identifier.
            checks (bool): Additional checks required
        """
        # Request (partial?) list of repositories info
        response = self.safely_issue_request(identifier)
        if not response:
            return response, []
        models_list = self.transport_response_simplified(response)
        models_list = self.filter_before_inject(models_list)
        if checks:
            models_list = self.do_additional_checks(models_list)
            if not models_list:
                return response, []
        # inject into local db
        injected = self.inject_repo_data_into_db(models_list)
        # queue workers
        self.schedule_missing_tasks(models_list, injected)
        return response, injected

    def save_response(self, response):
        """Log the response from a server request to a cache dir.

        Args:
            response: full server response

        Returns:
            nothing
        """
        datepath = utcnow().isoformat()

        fname = os.path.join(
            self.config['cache_dir'],
            datepath + '.gz',
        )

        with gzip.open(fname, 'w') as f:
            f.write(bytes(
                self.transport_response_to_string(response),
                'UTF-8'
            ))
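[Review note] back_off()'s new int annotation matches its behavior: each call
returns the current timeout and multiplies it by 10. A self-contained sketch
of the retry budget this gives safely_issue_request():

    INITIAL_BACKOFF = 10

    class Backoff:
        """Mirrors ListerBase.reset_backoff()/back_off() shown above."""
        def __init__(self) -> None:
            self.backoff = INITIAL_BACKOFF

        def back_off(self) -> int:
            ret = self.backoff
            self.backoff *= 10
            return ret

    b = Backoff()
    # Successive fallback delays when no Retry-After header is usable.
    assert [b.back_off() for _ in range(3)] == [10, 100, 1000]
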
diff --git a/swh/lister/core/lister_transports.py b/swh/lister/core/lister_transports.py
index a857167..8c7ceb3 100644
--- a/swh/lister/core/lister_transports.py
+++ b/swh/lister/core/lister_transports.py
@@ -1,232 +1,235 @@
 # Copyright (C) 2017-2018 the Software Heritage developers
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import abc
 import random
 from datetime import datetime
 from email.utils import parsedate
 from pprint import pformat
 import logging

 import requests
 import xmltodict

-from typing import Optional, Union
+from typing import Optional, Union, Dict, Any
+from requests import Response

 from swh.lister import USER_AGENT_TEMPLATE, __version__

 from .abstractattribute import AbstractAttribute
 from .lister_base import FetchError

 logger = logging.getLogger(__name__)


 class ListerHttpTransport(abc.ABC):
     """Use the Requests library for making Lister endpoint requests.

     To be used in conjunction with ListerBase or a subclass of it.
     """
     DEFAULT_URL = None  # type: Optional[str]
     PATH_TEMPLATE = \
         AbstractAttribute(
             'string containing a python string format pattern that produces'
             ' the API endpoint path for listing stored repositories when given'
             ' an index, e.g., "/repositories?after=%s". To be implemented in'
             ' the API-specific class inheriting this.'
         )  # type: Union[AbstractAttribute, Optional[str]]

     EXPECTED_STATUS_CODES = (200, 429, 403, 404)

-    def request_headers(self):
+    def request_headers(self) -> Dict[str, Any]:
         """Returns dictionary of any request headers needed by the server.

         MAY BE OVERRIDDEN if request headers are needed.
         """
         return {
             'User-Agent': USER_AGENT_TEMPLATE % self.lister_version
         }

     def request_instance_credentials(self):
         """Returns dictionary of any credentials configuration needed by the
         forge instance to list.

         The 'credentials' configuration is expected to be a dict of multiple
         levels. The first level is the lister's name, the second is the
         lister's instance name, which value is expected to be a list of
         credential structures (typically a couple username/password).

         For example::

             credentials:
               github:  # github lister
                 github:  # has only one instance (so far)
                   - username: some
                     password: somekey
                   - username: one
                     password: onekey
                   - ...
               gitlab:  # gitlab lister
                 riseup:  # has many instances
                   - username: someone
                     password: ...
                   - ...
                 gitlab:
                   - username: someone
                     password: ...
                   - ...

         Returns:
             list of credential dicts for the current lister.
         """
         all_creds = self.config.get('credentials')
         if not all_creds:
             return []
         lister_creds = all_creds.get(self.LISTER_NAME, {})
         creds = lister_creds.get(self.instance, [])
         return creds

     def request_uri(self, identifier):
         """Get the full request URI given the transport_request identifier.

         MAY BE OVERRIDDEN if something more complex than the PATH_TEMPLATE is
         required.
         """
         path = self.PATH_TEMPLATE % identifier
         return self.url + path

-    def request_params(self, identifier):
+    def request_params(self, identifier: int) -> Dict[str, Any]:
         """Get the full parameters passed to requests given the
         transport_request identifier.

         This uses credentials if any are provided (see
         request_instance_credentials).

         MAY BE OVERRIDDEN if something more complex than the request headers
         is needed.
""" params = {} params['headers'] = self.request_headers() or {} creds = self.request_instance_credentials() if not creds: return params auth = random.choice(creds) if creds else None if auth: - params['auth'] = (auth['username'], auth['password']) + params['auth'] = (auth['username'], # type: ignore + auth['password']) return params def transport_quota_check(self, response): """Implements ListerBase.transport_quota_check with standard 429 code check for HTTP with Requests library. MAY BE OVERRIDDEN if the server notifies about rate limits in a non-standard way that doesn't use HTTP 429 and the Retry-After response header. ( https://tools.ietf.org/html/rfc6585#section-4 ) """ if response.status_code == 429: # HTTP too many requests retry_after = response.headers.get('Retry-After', self.back_off()) try: # might be seconds return True, float(retry_after) except Exception: # might be http-date at_date = datetime(*parsedate(retry_after)[:6]) from_now = (at_date - datetime.today()).total_seconds() + 5 return True, max(0, from_now) else: # response ok self.reset_backoff() return False, 0 def __init__(self, url=None): if not url: url = self.config.get('url') if not url: url = self.DEFAULT_URL if not url: raise NameError('HTTP Lister Transport requires an url.') self.url = url # eg. 'https://api.github.com' self.session = requests.Session() self.lister_version = __version__ - def _transport_action(self, identifier, method='get'): + def _transport_action( + self, identifier: int, method: str = 'get') -> Response: """Permit to ask information to the api prior to actually executing query. """ path = self.request_uri(identifier) params = self.request_params(identifier) logger.debug('path: %s', path) logger.debug('params: %s', params) logger.debug('method: %s', method) try: if method == 'head': response = self.session.head(path, **params) else: response = self.session.get(path, **params) except requests.exceptions.ConnectionError as e: logger.warning('Failed to fetch %s: %s', path, e) raise FetchError(e) else: if response.status_code not in self.EXPECTED_STATUS_CODES: raise FetchError(response) return response - def transport_head(self, identifier): + def transport_head(self, identifier: int) -> Response: """Retrieve head information on api. """ return self._transport_action(identifier, method='head') - def transport_request(self, identifier): + def transport_request(self, identifier: int) -> Response: """Implements ListerBase.transport_request for HTTP using Requests. Retrieve get information on api. """ return self._transport_action(identifier) - def transport_response_to_string(self, response): + def transport_response_to_string(self, response: Response) -> str: """Implements ListerBase.transport_response_to_string for HTTP given Requests responses. """ s = pformat(response.request.path_url) s += '\n#\n' + pformat(response.request.headers) s += '\n#\n' + pformat(response.status_code) s += '\n#\n' + pformat(response.headers) s += '\n#\n' try: # json? s += pformat(response.json()) except Exception: # not json try: # xml? s += pformat(xmltodict.parse(response.text)) except Exception: # not xml s += pformat(response.text) return s class ListerOnePageApiTransport(ListerHttpTransport): """Leverage requests library to retrieve basic html page and parse result. To be used in conjunction with ListerBase or a subclass of it. 
""" PAGE = AbstractAttribute( "URL of the API's unique page to retrieve and parse " "for information") # type: Union[AbstractAttribute, str] PATH_TEMPLATE = None # we do not use it def __init__(self, url=None): self.session = requests.Session() self.lister_version = __version__ def request_uri(self, _): """Get the full request URI given the transport_request identifier. """ return self.PAGE diff --git a/swh/lister/debian/lister.py b/swh/lister/debian/lister.py index ef9853b..b5c4c50 100644 --- a/swh/lister/debian/lister.py +++ b/swh/lister/debian/lister.py @@ -1,255 +1,256 @@ # Copyright (C) 2017-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import bz2 from collections import defaultdict import datetime import gzip import lzma import logging from debian.deb822 import Sources from sqlalchemy.orm import joinedload, load_only from sqlalchemy.schema import CreateTable, DropTable -from typing import Mapping, Optional +from typing import Mapping, Optional, Dict, Any +from requests import Response from swh.lister.debian.models import ( AreaSnapshot, Distribution, DistributionSnapshot, Package, TempPackage, ) from swh.lister.core.lister_base import ListerBase, FetchError from swh.lister.core.lister_transports import ListerHttpTransport decompressors = { 'gz': lambda f: gzip.GzipFile(fileobj=f), 'bz2': bz2.BZ2File, 'xz': lzma.LZMAFile, } logger = logging.getLogger(__name__) class DebianLister(ListerHttpTransport, ListerBase): MODEL = Package PATH_TEMPLATE = None LISTER_NAME = 'debian' instance = 'debian' def __init__(self, distribution: str = 'Debian', date: Optional[datetime.datetime] = None, override_config: Mapping = {}): """Initialize the debian lister for a given distribution at a given date. Args: distribution: name of the distribution (e.g. "Debian") date: date the snapshot is taken (defaults to now if empty) override_config: Override configuration (which takes precedence over the parameters if provided) """ ListerHttpTransport.__init__(self, url="notused") ListerBase.__init__(self, override_config=override_config) self.distribution = override_config.get('distribution', distribution) self.date = override_config.get('date', date) or datetime.datetime.now( tz=datetime.timezone.utc) - def transport_request(self, identifier): + def transport_request(self, identifier) -> Response: """Subvert ListerHttpTransport.transport_request, to try several index URIs in turn. The Debian repository format supports several compression algorithms across the ages, so we try several URIs. Once we have found a working URI, we break and set `self.decompressor` to the one that matched. Returns: a requests Response object. Raises: FetchError: when all the URIs failed to be retrieved. """ response = None compression = None for uri, compression in self.area.index_uris(): response = super().transport_request(uri) if response.status_code == 200: break else: raise FetchError( "Could not retrieve index for %s" % self.area ) self.decompressor = decompressors.get(compression) return response def request_uri(self, identifier): # In the overridden transport_request, we pass # ListerBase.transport_request() the full URI as identifier, so we # need to return it here. 
return identifier - def request_params(self, identifier): + def request_params(self, identifier) -> Dict[str, Any]: # Enable streaming to allow wrapping the response in the decompressor # in transport_response_simplified. params = super().request_params(identifier) params['stream'] = True return params def transport_response_simplified(self, response): """Decompress and parse the package index fetched in `transport_request`. For each package, we "pivot" the file list entries (Files, Checksums-Sha1, Checksums-Sha256), to return a files dict mapping filenames to their checksums. """ if self.decompressor: data = self.decompressor(response.raw) else: data = response.raw for src_pkg in Sources.iter_paragraphs(data.readlines()): files = defaultdict(dict) for field in src_pkg._multivalued_fields: if field.startswith('checksums-'): sum_name = field[len('checksums-'):] else: sum_name = 'md5sum' if field in src_pkg: for entry in src_pkg[field]: name = entry['name'] files[name]['name'] = entry['name'] files[name]['size'] = int(entry['size'], 10) files[name][sum_name] = entry[sum_name] yield { 'name': src_pkg['Package'], 'version': src_pkg['Version'], 'directory': src_pkg['Directory'], 'files': files, } def inject_repo_data_into_db(self, models_list): """Generate the Package entries that didn't previously exist. Contrary to ListerBase, we don't actually insert the data in database. `schedule_missing_tasks` does it once we have the origin and task identifiers. """ by_name_version = {} temp_packages = [] area_id = self.area.id for model in models_list: name = model['name'] version = model['version'] temp_packages.append({ 'area_id': area_id, 'name': name, 'version': version, }) by_name_version[name, version] = model # Add all the listed packages to a temporary table self.db_session.execute(CreateTable(TempPackage.__table__)) self.db_session.bulk_insert_mappings(TempPackage, temp_packages) def exists_tmp_pkg(db_session, model): return ( db_session.query(model) .filter(Package.area_id == TempPackage.area_id) .filter(Package.name == TempPackage.name) .filter(Package.version == TempPackage.version) .exists() ) # Filter out the packages that already exist in the main Package table new_packages = self.db_session\ .query(TempPackage)\ .options(load_only('name', 'version'))\ .filter(~exists_tmp_pkg(self.db_session, Package))\ .all() self.old_area_packages = self.db_session.query(Package).filter( exists_tmp_pkg(self.db_session, TempPackage) ).all() self.db_session.execute(DropTable(TempPackage.__table__)) added_packages = [] for package in new_packages: model = by_name_version[package.name, package.version] added_packages.append(Package(area=self.area, **model)) self.db_session.add_all(added_packages) return added_packages def schedule_missing_tasks(self, models_list, added_packages): """We create tasks at the end of the full snapshot processing""" return def create_tasks_for_snapshot(self, snapshot): tasks = [ snapshot.task_for_package(name, versions) for name, versions in snapshot.get_packages().items() ] return self.scheduler.create_tasks(tasks) def run(self): """Run the lister for a given (distribution, area) tuple. 
""" distribution = self.db_session\ .query(Distribution)\ .options(joinedload(Distribution.areas))\ .filter(Distribution.name == self.distribution)\ .one_or_none() if not distribution: logger.error("Distribution %s is not registered" % self.distribution) return {'status': 'failed'} if not distribution.type == 'deb': logger.error("Distribution %s is not a Debian derivative" % distribution) return {'status': 'failed'} date = self.date logger.debug('Creating snapshot for distribution %s on date %s' % (distribution, date)) snapshot = DistributionSnapshot(date=date, distribution=distribution) self.db_session.add(snapshot) for area in distribution.areas: if not area.active: continue self.area = area logger.debug('Processing area %s' % area) _, new_area_packages = self.ingest_data(None) area_snapshot = AreaSnapshot(snapshot=snapshot, area=area) self.db_session.add(area_snapshot) area_snapshot.packages.extend(new_area_packages) area_snapshot.packages.extend(self.old_area_packages) self.create_tasks_for_snapshot(snapshot) self.db_session.commit() return {'status': 'eventful'} diff --git a/swh/lister/github/lister.py b/swh/lister/github/lister.py index 8916b39..066b884 100644 --- a/swh/lister/github/lister.py +++ b/swh/lister/github/lister.py @@ -1,72 +1,79 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re -from typing import Any +from typing import Any, Dict, List, Tuple, Optional from swh.lister.core.indexing_lister import IndexingHttpLister from swh.lister.github.models import GitHubModel +from requests import Response + class GitHubLister(IndexingHttpLister): PATH_TEMPLATE = '/repositories?since=%d' MODEL = GitHubModel DEFAULT_URL = 'https://api.github.com' API_URL_INDEX_RE = re.compile(r'^.*/repositories\?since=(\d+)') LISTER_NAME = 'github' instance = 'github' # There is only 1 instance of such lister default_min_bound = 0 # type: Any - def get_model_from_repo(self, repo): + def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]: return { 'uid': repo['id'], 'indexable': repo['id'], 'name': repo['name'], 'full_name': repo['full_name'], 'html_url': repo['html_url'], 'origin_url': repo['html_url'], 'origin_type': 'git', 'fork': repo['fork'], } - def transport_quota_check(self, response): + def transport_quota_check(self, response: Response) -> Tuple[bool, int]: x_rate_limit_remaining = response.headers.get('X-RateLimit-Remaining') if not x_rate_limit_remaining: return False, 0 reqs_remaining = int(x_rate_limit_remaining) if response.status_code == 403 and reqs_remaining == 0: delay = int(response.headers['Retry-After']) return True, delay return False, 0 - def get_next_target_from_response(self, response): + def get_next_target_from_response(self, + response: Response) -> Optional[int]: if 'next' in response.links: next_url = response.links['next']['url'] - return int(self.API_URL_INDEX_RE.match(next_url).group(1)) + return int( + self.API_URL_INDEX_RE.match(next_url).group(1)) # type: ignore + return None - def transport_response_simplified(self, response): + def transport_response_simplified(self, response: Response + ) -> List[Dict[str, Any]]: repos = response.json() return [self.get_model_from_repo(repo) for repo in repos if repo and 'id' in repo] - def request_headers(self): + def request_headers(self) -> Dict[str, Any]: """(Override) Set requests headers to send when 
diff --git a/swh/lister/github/lister.py b/swh/lister/github/lister.py
index 8916b39..066b884 100644
--- a/swh/lister/github/lister.py
+++ b/swh/lister/github/lister.py
@@ -1,72 +1,79 @@
 # Copyright (C) 2017-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import re

-from typing import Any
+from typing import Any, Dict, List, Tuple, Optional

 from swh.lister.core.indexing_lister import IndexingHttpLister
 from swh.lister.github.models import GitHubModel

+from requests import Response
+

 class GitHubLister(IndexingHttpLister):
     PATH_TEMPLATE = '/repositories?since=%d'
     MODEL = GitHubModel
     DEFAULT_URL = 'https://api.github.com'
     API_URL_INDEX_RE = re.compile(r'^.*/repositories\?since=(\d+)')
     LISTER_NAME = 'github'
     instance = 'github'  # There is only 1 instance of such lister
     default_min_bound = 0  # type: Any

-    def get_model_from_repo(self, repo):
+    def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]:
         return {
             'uid': repo['id'],
             'indexable': repo['id'],
             'name': repo['name'],
             'full_name': repo['full_name'],
             'html_url': repo['html_url'],
             'origin_url': repo['html_url'],
             'origin_type': 'git',
             'fork': repo['fork'],
         }

-    def transport_quota_check(self, response):
+    def transport_quota_check(self, response: Response) -> Tuple[bool, int]:
         x_rate_limit_remaining = response.headers.get('X-RateLimit-Remaining')
         if not x_rate_limit_remaining:
             return False, 0
         reqs_remaining = int(x_rate_limit_remaining)
         if response.status_code == 403 and reqs_remaining == 0:
             delay = int(response.headers['Retry-After'])
             return True, delay
         return False, 0

-    def get_next_target_from_response(self, response):
+    def get_next_target_from_response(self,
+                                      response: Response) -> Optional[int]:
         if 'next' in response.links:
             next_url = response.links['next']['url']
-            return int(self.API_URL_INDEX_RE.match(next_url).group(1))
+            return int(
+                self.API_URL_INDEX_RE.match(next_url).group(1))  # type: ignore
+        return None

-    def transport_response_simplified(self, response):
+    def transport_response_simplified(self, response: Response
+                                      ) -> List[Dict[str, Any]]:
         repos = response.json()
         return [self.get_model_from_repo(repo) for repo in repos
                 if repo and 'id' in repo]

-    def request_headers(self):
+    def request_headers(self) -> Dict[str, Any]:
         """(Override) Set requests headers to send when querying the
         GitHub API
         """
         headers = super().request_headers()
         headers['Accept'] = 'application/vnd.github.v3+json'
         return headers

-    def disable_deleted_repo_tasks(self, index, next_index, keep_these):
+    def disable_deleted_repo_tasks(self, index: int,
+                                   next_index: Optional[int],
+                                   keep_these: List[int]):
         """
         (Overrides) Fix provided index value to avoid erroneously disabling
         some scheduler tasks
         """
         # Next listed repository ids are strictly greater than the 'since'
         # parameter, so increment the index to avoid disabling the latest
         # created task when processing a new repositories page returned by
         # the Github API
         return super().disable_deleted_repo_tasks(index + 1,
                                                   next_index,
                                                   keep_these)
diff --git a/swh/lister/gitlab/lister.py b/swh/lister/gitlab/lister.py
index 6e01005..3d8d3d7 100644
--- a/swh/lister/gitlab/lister.py
+++ b/swh/lister/gitlab/lister.py
@@ -1,82 +1,91 @@
 # Copyright (C) 2018-2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 import time

 from urllib3.util import parse_url

 from ..core.page_by_page_lister import PageByPageHttpLister
 from .models import GitLabModel

+from typing import Any, Dict, List, Tuple, Union, MutableMapping, Optional
+from requests import Response
+

 class GitLabLister(PageByPageHttpLister):
     # Template path expecting an integer that represents the page id
     PATH_TEMPLATE = '/projects?page=%d&order_by=id'
     DEFAULT_URL = 'https://gitlab.com/api/v4/'
     MODEL = GitLabModel
     LISTER_NAME = 'gitlab'

     def __init__(self, url=None, instance=None,
                  override_config=None, sort='asc', per_page=20):
         super().__init__(url=url, override_config=override_config)
         if instance is None:
             instance = parse_url(self.url).host
         self.instance = instance
         self.PATH_TEMPLATE = '%s&sort=%s&per_page=%s' % (
             self.PATH_TEMPLATE, sort, per_page)

-    def uid(self, repo):
+    def uid(self, repo: Dict[str, Any]) -> str:
         return '%s/%s' % (self.instance, repo['path_with_namespace'])

-    def get_model_from_repo(self, repo):
+    def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]:
         return {
             'instance': self.instance,
             'uid': self.uid(repo),
             'name': repo['name'],
             'full_name': repo['path_with_namespace'],
             'html_url': repo['web_url'],
             'origin_url': repo['http_url_to_repo'],
             'origin_type': 'git',
         }

-    def transport_quota_check(self, response):
+    def transport_quota_check(self, response: Response
+                              ) -> Tuple[bool, Union[int, float]]:
         """Deal with rate limit if any.

         """
         # not all gitlab instance have rate limit
         if 'RateLimit-Remaining' in response.headers:
             reqs_remaining = int(response.headers['RateLimit-Remaining'])
             if response.status_code == 403 and reqs_remaining == 0:
                 reset_at = int(response.headers['RateLimit-Reset'])
                 delay = min(reset_at - time.time(), 3600)
                 return True, delay
         return False, 0

-    def _get_int(self, headers, key):
+    def _get_int(self, headers: MutableMapping[str, Any],
+                 key: str) -> Optional[int]:
         _val = headers.get(key)
         if _val:
             return int(_val)
+        return None

-    def get_next_target_from_response(self, response):
+    def get_next_target_from_response(
+            self, response: Response) -> Optional[int]:
         """Determine the next page identifier.

         """
         return self._get_int(response.headers, 'x-next-page')

-    def get_pages_information(self):
+    def get_pages_information(self) -> Tuple[Optional[int],
+                                             Optional[int], Optional[int]]:
         """Determine pages information.
""" response = self.transport_head(identifier=1) if not response.ok: raise ValueError( 'Problem during information fetch: %s' % response.status_code) h = response.headers return (self._get_int(h, 'x-total'), self._get_int(h, 'x-total-pages'), self._get_int(h, 'x-per-page')) - def transport_response_simplified(self, response): + def transport_response_simplified(self, response: Response + ) -> List[Dict[str, Any]]: repos = response.json() return [self.get_model_from_repo(repo) for repo in repos] diff --git a/swh/lister/gnu/lister.py b/swh/lister/gnu/lister.py index 3c00573..1f41b0a 100644 --- a/swh/lister/gnu/lister.py +++ b/swh/lister/gnu/lister.py @@ -1,111 +1,113 @@ # Copyright (C) 2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging from swh.scheduler import utils from swh.lister.core.simple_lister import SimpleLister from swh.lister.gnu.models import GNUModel from swh.lister.gnu.tree import GNUTree +from typing import Any, Dict, List +from requests import Response logger = logging.getLogger(__name__) class GNULister(SimpleLister): MODEL = GNUModel LISTER_NAME = 'gnu' instance = 'gnu' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.gnu_tree = GNUTree('https://ftp.gnu.org/tree.json.gz') def task_dict(self, origin_type, origin_url, **kwargs): """Return task format dict This is overridden from the lister_base as more information is needed for the ingestion task creation. This creates tasks with args and kwargs set, for example: .. code-block:: python args: kwargs: { 'url': 'https://ftp.gnu.org/gnu/3dldf/', 'artifacts': [{ 'url': 'https://...', 'time': '2003-12-09T21:43:20+00:00', 'length': 128, 'version': '1.0.1', 'filename': 'something-1.0.1.tar.gz', }, ... ] } """ artifacts = self.gnu_tree.artifacts[origin_url] assert origin_type == 'tar' return utils.create_task_dict( 'load-archive-files', kwargs.get('policy', 'oneshot'), url=origin_url, artifacts=artifacts, retries_left=3, ) - def safely_issue_request(self, identifier): + def safely_issue_request(self, identifier: int) -> None: """Bypass the implementation. It's now the GNUTree which deals with querying the gnu mirror. As an implementation detail, we cannot change simply the base SimpleLister as other implementation still uses it. This shall be part of another refactoring pass. """ return None - def list_packages(self, response): + def list_packages(self, response: Response) -> List[Dict[str, Any]]: """List the actual gnu origins (package name) with their name, url and associated tarballs. Args: response: Unused Returns: List of packages name, url, last modification time:: [ { 'name': '3dldf', 'url': 'https://ftp.gnu.org/gnu/3dldf/', 'time_modified': '2003-12-09T20:43:20+00:00' }, { 'name': '8sync', 'url': 'https://ftp.gnu.org/gnu/8sync/', 'time_modified': '2016-12-06T02:37:10+00:00' }, ... 
] """ return list(self.gnu_tree.projects.values()) - def get_model_from_repo(self, repo): + def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]: """Transform from repository representation to model """ return { 'uid': repo['url'], 'name': repo['name'], 'full_name': repo['name'], 'html_url': repo['url'], 'origin_url': repo['url'], 'time_last_updated': repo['time_modified'], 'origin_type': 'tar', } diff --git a/swh/lister/npm/lister.py b/swh/lister/npm/lister.py index 8560e66..5214032 100644 --- a/swh/lister/npm/lister.py +++ b/swh/lister/npm/lister.py @@ -1,150 +1,157 @@ # Copyright (C) 2018-2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from swh.lister.core.indexing_lister import IndexingHttpLister from swh.lister.npm.models import NpmModel from swh.scheduler.utils import create_task_dict +from typing import Any, Dict, Optional, List +from requests import Response + class NpmListerBase(IndexingHttpLister): """List packages available in the npm registry in a paginated way """ MODEL = NpmModel LISTER_NAME = 'npm' instance = 'npm' def __init__(self, url='https://replicate.npmjs.com', per_page=1000, override_config=None): super().__init__(url=url, override_config=override_config) self.per_page = per_page + 1 self.PATH_TEMPLATE += '&limit=%s' % self.per_page @property - def ADDITIONAL_CONFIG(self): + def ADDITIONAL_CONFIG(self) -> Dict[str, Any]: """(Override) Add extra configuration """ default_config = super().ADDITIONAL_CONFIG default_config['loading_task_policy'] = ('str', 'recurring') return default_config - def get_model_from_repo(self, repo_name): + def get_model_from_repo(self, repo_name: str) -> Dict[str, str]: """(Override) Transform from npm package name to model """ package_url = 'https://www.npmjs.com/package/%s' % repo_name return { 'uid': repo_name, 'indexable': repo_name, 'name': repo_name, 'full_name': repo_name, 'html_url': package_url, 'origin_url': package_url, 'origin_type': 'npm', } - def task_dict(self, origin_type, origin_url, **kwargs): + def task_dict(self, origin_type: str, origin_url: str, **kwargs): """(Override) Return task dict for loading a npm package into the archive. This is overridden from the lister_base as more information is needed for the ingestion task creation. """ task_type = 'load-%s' % origin_type task_policy = self.config['loading_task_policy'] return create_task_dict(task_type, task_policy, url=origin_url) - def request_headers(self): + def request_headers(self) -> Dict[str, Any]: """(Override) Set requests headers to send when querying the npm registry. 
""" headers = super().request_headers() headers['Accept'] = 'application/json' return headers - def string_pattern_check(self, inner, lower, upper=None): + def string_pattern_check(self, inner: int, lower: int, upper: int = None): """ (Override) Inhibit the effect of that method as packages indices correspond to package names and thus do not respect any kind of fixed length string pattern """ pass class NpmLister(NpmListerBase): """List all packages available in the npm registry in a paginated way """ PATH_TEMPLATE = '/_all_docs?startkey="%s"' - def get_next_target_from_response(self, response): + def get_next_target_from_response( + self, response: Response) -> Optional[str]: """(Override) Get next npm package name to continue the listing """ repos = response.json()['rows'] return repos[-1]['id'] if len(repos) == self.per_page else None - def transport_response_simplified(self, response): + def transport_response_simplified( + self, response: Response) -> List[Dict[str, str]]: """(Override) Transform npm registry response to list for model manipulation """ repos = response.json()['rows'] if len(repos) == self.per_page: repos = repos[:-1] return [self.get_model_from_repo(repo['id']) for repo in repos] class NpmIncrementalLister(NpmListerBase): """List packages in the npm registry, updated since a specific update_seq value of the underlying CouchDB database, in a paginated way. """ PATH_TEMPLATE = '/_changes?since=%s' @property def CONFIG_BASE_FILENAME(self): # noqa: N802 return 'lister_npm_incremental' - def get_next_target_from_response(self, response): + def get_next_target_from_response( + self, response: Response) -> Optional[str]: """(Override) Get next npm package name to continue the listing. """ repos = response.json()['results'] return repos[-1]['seq'] if len(repos) == self.per_page else None - def transport_response_simplified(self, response): + def transport_response_simplified( + self, response: Response) -> List[Dict[str, str]]: """(Override) Transform npm registry response to list for model manipulation. """ repos = response.json()['results'] if len(repos) == self.per_page: repos = repos[:-1] return [self.get_model_from_repo(repo['id']) for repo in repos] - def filter_before_inject(self, models_list): + def filter_before_inject(self, models_list: List[Dict[str, Any]]): """(Override) Filter out documents in the CouchDB database not related to a npm package. """ models_filtered = [] for model in models_list: package_name = model['name'] # document related to CouchDB internals if package_name.startswith('_design/'): continue models_filtered.append(model) return models_filtered def disable_deleted_repo_tasks(self, start, end, keep_these): """(Override) Disable the processing performed by that method as it is not relevant in this incremental lister context. It also raises an exception due to a different index type (int instead of str). 
""" pass diff --git a/swh/lister/phabricator/lister.py b/swh/lister/phabricator/lister.py index f198a1b..a32a6c5 100644 --- a/swh/lister/phabricator/lister.py +++ b/swh/lister/phabricator/lister.py @@ -1,182 +1,190 @@ # Copyright (C) 2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import random import urllib.parse from collections import defaultdict from sqlalchemy import func from swh.lister.core.indexing_lister import IndexingHttpLister from swh.lister.phabricator.models import PhabricatorModel +from typing import Any, Dict, List, Optional +from requests import Response logger = logging.getLogger(__name__) class PhabricatorLister(IndexingHttpLister): PATH_TEMPLATE = '?order=oldest&attachments[uris]=1&after=%s' DEFAULT_URL = \ 'https://forge.softwareheritage.org/api/diffusion.repository.search' MODEL = PhabricatorModel LISTER_NAME = 'phabricator' def __init__(self, url=None, instance=None, override_config=None): super().__init__(url=url, override_config=override_config) if not instance: instance = urllib.parse.urlparse(self.url).hostname self.instance = instance - def request_params(self, identifier): + def request_params(self, identifier: int) -> Dict[str, Any]: """Override the default params behavior to retrieve the api token Credentials are stored as: credentials: phabricator: : - username: password: """ creds = self.request_instance_credentials() if not creds: raise ValueError( 'Phabricator forge needs authentication credential to list.') api_token = random.choice(creds)['password'] return {'headers': self.request_headers() or {}, 'params': {'api.token': api_token}} def request_headers(self): """ (Override) Set requests headers to send when querying the Phabricator API """ headers = super().request_headers() headers['Accept'] = 'application/json' return headers - def get_model_from_repo(self, repo): + def get_model_from_repo( + self, repo: Dict[str, Any]) -> Optional[Dict[str, Any]]: url = get_repo_url(repo['attachments']['uris']['uris']) if url is None: return None return { 'uid': url, 'indexable': repo['id'], 'name': repo['fields']['shortName'], 'full_name': repo['fields']['name'], 'html_url': url, 'origin_url': url, 'origin_type': repo['fields']['vcs'], 'instance': self.instance, } - def get_next_target_from_response(self, response): + def get_next_target_from_response( + self, response: Response) -> Optional[int]: body = response.json()['result']['cursor'] if body['after'] and body['after'] != 'null': return int(body['after']) + return None - def transport_response_simplified(self, response): + def transport_response_simplified( + self, response: Response) -> List[Optional[Dict[str, Any]]]: repos = response.json() if repos['result'] is None: raise ValueError( 'Problem during information fetch: %s' % repos['error_code']) repos = repos['result']['data'] return [self.get_model_from_repo(repo) for repo in repos] def filter_before_inject(self, models_list): """ (Overrides) IndexingLister.filter_before_inject Bounds query results by this Lister's set max_index. 
""" models_list = [m for m in models_list if m is not None] return super().filter_before_inject(models_list) - def disable_deleted_repo_tasks(self, index, next_index, keep_these): + def disable_deleted_repo_tasks( + self, index: int, next_index: int, keep_these: str): """ (Overrides) Fix provided index value to avoid: - database query error - erroneously disabling some scheduler tasks """ # First call to the Phabricator API uses an empty 'after' parameter, # so set the index to 0 to avoid database query error if index == '': index = 0 # Next listed repository ids are strictly greater than the 'after' # parameter, so increment the index to avoid disabling the latest # created task when processing a new repositories page returned by # the Phabricator API else: index = index + 1 return super().disable_deleted_repo_tasks(index, next_index, keep_these) - def db_first_index(self): + def db_first_index(self) -> Optional[int]: """ (Overrides) Filter results by Phabricator instance Returns: the smallest indexable value of all repos in the db """ t = self.db_session.query(func.min(self.MODEL.indexable)) t = t.filter(self.MODEL.instance == self.instance).first() if t: return t[0] + return None def db_last_index(self): """ (Overrides) Filter results by Phabricator instance Returns: the largest indexable value of all instance repos in the db """ t = self.db_session.query(func.max(self.MODEL.indexable)) t = t.filter(self.MODEL.instance == self.instance).first() if t: return t[0] - def db_query_range(self, start, end): + def db_query_range(self, start: int, end: int): """ (Overrides) Filter the results by the Phabricator instance to avoid disabling loading tasks for repositories hosted on a different instance. Returns: a list of sqlalchemy.ext.declarative.declarative_base objects with indexable values within the given range for the instance """ retlist = super().db_query_range(start, end) return retlist.filter(self.MODEL.instance == self.instance) -def get_repo_url(attachments): +def get_repo_url(attachments: List[Dict[str, Any]]) -> Optional[int]: """ Return url for a hosted repository from its uris attachments according to the following priority lists: * protocol: https > http * identifier: shortname > callsign > id """ - processed_urls = defaultdict(dict) + processed_urls = defaultdict(dict) # type: Dict[str, Any] for uri in attachments: protocol = uri['fields']['builtin']['protocol'] url = uri['fields']['uri']['effective'] identifier = uri['fields']['builtin']['identifier'] if protocol in ('http', 'https'): processed_urls[protocol][identifier] = url elif protocol is None: for protocol in ('https', 'http'): if url.startswith(protocol): processed_urls[protocol]['undefined'] = url break for protocol in ['https', 'http']: for identifier in ['shortname', 'callsign', 'id', 'undefined']: if (protocol in processed_urls and identifier in processed_urls[protocol]): return processed_urls[protocol][identifier] return None diff --git a/swh/lister/pypi/lister.py b/swh/lister/pypi/lister.py index 09a731d..0f22ae0 100644 --- a/swh/lister/pypi/lister.py +++ b/swh/lister/pypi/lister.py @@ -1,65 +1,68 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import random import xmltodict from .models import PyPIModel from swh.scheduler import utils from swh.lister.core.simple_lister import SimpleLister from 
diff --git a/swh/lister/pypi/lister.py b/swh/lister/pypi/lister.py
index 09a731d..0f22ae0 100644
--- a/swh/lister/pypi/lister.py
+++ b/swh/lister/pypi/lister.py
@@ -1,65 +1,68 @@
 # Copyright (C) 2018-2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import random

 import xmltodict

 from .models import PyPIModel

 from swh.scheduler import utils
 from swh.lister.core.simple_lister import SimpleLister
 from swh.lister.core.lister_transports import ListerOnePageApiTransport
+from typing import Any, Dict, List
+from requests import Response
+

 class PyPILister(ListerOnePageApiTransport, SimpleLister):
     MODEL = PyPIModel
     LISTER_NAME = 'pypi'
     PAGE = 'https://pypi.org/simple/'
     instance = 'pypi'  # As of today only the main pypi.org is used

     def __init__(self, override_config=None):
         ListerOnePageApiTransport.__init__(self)
         SimpleLister.__init__(self, override_config=override_config)

-    def task_dict(self, origin_type, origin_url, **kwargs):
+    def task_dict(self, origin_type: str, origin_url: str, **kwargs):
         """(Override) Return task format dict

         This is overridden from the lister_base as more information is
         needed for the ingestion task creation.

         """
         _type = 'load-%s' % origin_type
         _policy = kwargs.get('policy', 'recurring')
         return utils.create_task_dict(
             _type, _policy,
             url=origin_url)

-    def list_packages(self, response):
+    def list_packages(self, response: Response) -> List[str]:
         """(Override) List the actual pypi origins from the response.

         """
         result = xmltodict.parse(response.content)
         _packages = [p['#text'] for p in result['html']['body']['a']]
         random.shuffle(_packages)
         return _packages

     def origin_url(self, repo_name: str) -> str:
         """Returns origin_url

         """
         return 'https://pypi.org/project/%s/' % repo_name

-    def get_model_from_repo(self, repo_name):
+    def get_model_from_repo(self, repo_name: str) -> Dict[str, Any]:
         """(Override) Transform from repository representation to model

         """
         origin_url = self.origin_url(repo_name)
         return {
             'uid': origin_url,
             'name': repo_name,
             'full_name': repo_name,
             'html_url': origin_url,
             'origin_url': origin_url,
             'origin_type': 'pypi',
         }
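To close, a hedged sketch of what list_packages above does with the PyPI simple index: xmltodict turns each anchor tag into a dict whose '#text' key holds the package name. The HTML snippet is a made-up two-package index:

    import xmltodict

    html = (b'<html><body>'
            b'<a href="/simple/foo/">foo</a>'
            b'<a href="/simple/bar/">bar</a>'
            b'</body></html>')
    result = xmltodict.parse(html)
    packages = [p['#text'] for p in result['html']['body']['a']]
    assert packages == ['foo', 'bar']

Note the lister shuffles this list before returning it, which spreads the subsequent loading work uniformly over the alphabet instead of hammering packages in lexicographic order.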