diff --git a/swh/lister/bitbucket/lister.py b/swh/lister/bitbucket/lister.py --- a/swh/lister/bitbucket/lister.py +++ b/swh/lister/bitbucket/lister.py @@ -7,9 +7,9 @@ import iso8601 from datetime import datetime, timezone -from typing import Any +from typing import Any, Dict, List, Optional from urllib import parse - +from requests import Response from swh.lister.bitbucket.models import BitBucketModel from swh.lister.core.indexing_lister import IndexingHttpLister @@ -26,14 +26,16 @@ instance = 'bitbucket' default_min_bound = datetime.fromtimestamp(0, timezone.utc) # type: Any - def __init__(self, url=None, override_config=None, per_page=100): + def __init__(self, url: Optional[str] = None, + override_config: Optional[bool] = None, + per_page: int = 100) -> None: super().__init__(url=url, override_config=override_config) per_page = self.config.get('per_page', per_page) self.PATH_TEMPLATE = '%s&pagelen=%s' % ( self.PATH_TEMPLATE, per_page) - def get_model_from_repo(self, repo): + def get_model_from_repo(self, repo: Dict) -> Dict[str, Any]: return { 'uid': repo['uuid'], 'indexable': iso8601.parse_date(repo['created_on']), @@ -44,7 +46,8 @@ 'origin_type': repo['scm'], } - def get_next_target_from_response(self, response): + def get_next_target_from_response(self, response: Response + ) -> Optional[datetime]: """This will read the 'next' link from the api response if any and return it as a datetime. 
@@ -60,23 +63,28 @@ if next_ is not None: next_ = parse.urlparse(next_) return iso8601.parse_date(parse.parse_qs(next_.query)['after'][0]) + return None - def transport_response_simplified(self, response): + def transport_response_simplified(self, response: Response + ) -> List[Dict[str, Any]]: repos = response.json()['values'] return [self.get_model_from_repo(repo) for repo in repos] - def request_uri(self, identifier): - identifier = parse.quote(identifier.isoformat()) - return super().request_uri(identifier or '1970-01-01') + def request_uri(self, identifier: datetime) -> str: # type: ignore + identifier_str = parse.quote(identifier.isoformat()) + return super().request_uri(identifier_str or '1970-01-01') - def is_within_bounds(self, inner, lower=None, upper=None): + def is_within_bounds(self, inner: int, lower: Optional[int] = None, + upper: Optional[int] = None) -> bool: # values are expected to be datetimes - if lower is None and upper is None: - ret = True - elif lower is None: - ret = inner <= upper - elif upper is None: - ret = inner >= lower + if lower is None: + if upper is None: + ret = True + else: + ret = inner <= upper else: - ret = lower <= inner <= upper + if upper is None: + ret = inner >= lower + else: + ret = lower <= inner <= upper return ret diff --git a/swh/lister/cgit/lister.py b/swh/lister/cgit/lister.py --- a/swh/lister/cgit/lister.py +++ b/swh/lister/cgit/lister.py @@ -8,8 +8,10 @@ from bs4 import BeautifulSoup from requests import Session -from requests.adapters import HTTPAdapter +from requests.adapters import HTTPAdapter +from requests.structures import CaseInsensitiveDict +from typing import Dict, Generator, Optional, Any from .models import CGitModel from swh.core.utils import grouper @@ -54,7 +56,9 @@ LISTER_NAME = 'cgit' url_prefix_present = True - def __init__(self, url=None, instance=None, override_config=None): + def __init__(self, url: Optional[str] = None, + instance: Optional[str] = None, + override_config: Optional[str] = 
None) -> None: """Lister class for CGit repositories. Args: @@ -75,11 +79,11 @@ self.instance = instance self.session = Session() self.session.mount(self.url, HTTPAdapter(max_retries=3)) - self.session.headers = { + self.session.headers = CaseInsensitiveDict({ 'User-Agent': USER_AGENT, - } + }) - def run(self): + def run(self) -> Dict[str, Any]: status = 'uneventful' total = 0 for repos in grouper(self.get_repos(), 10): @@ -94,7 +98,7 @@ return {'status': status} - def get_repos(self): + def get_repos(self) -> Generator[str, None, None]: """Generate git 'project' URLs found on the current CGit server """ @@ -116,7 +120,7 @@ # no pager, or no next page next_page = None - def build_model(self, repo_url): + def build_model(self, repo_url: str) -> Optional[Dict[str, str]]: """Given the URL of a git repo project page on a CGit server, return the repo description (dict) suitable for insertion in the db. """ @@ -124,7 +128,7 @@ urls = [x['href'] for x in bs.find_all('a', {'rel': 'vcs-git'})] if not urls: - return + return None # look for the http/https url, if any, and use it as origin_url for url in urls: @@ -138,11 +142,11 @@ return {'uid': repo_url, 'name': bs.find('a', title=re.compile('.+'))['title'], 'origin_type': 'git', - 'instance': self.instance, + 'instance': str(self.instance), 'origin_url': origin_url, } - def get_and_parse(self, url): + def get_and_parse(self, url: str) -> BeautifulSoup: "Get the given url and parse the retrieved HTML using BeautifulSoup" return BeautifulSoup(self.session.get(url).text, features='html.parser') diff --git a/swh/lister/core/indexing_lister.py b/swh/lister/core/indexing_lister.py --- a/swh/lister/core/indexing_lister.py +++ b/swh/lister/core/indexing_lister.py @@ -12,10 +12,14 @@ from .lister_transports import ListerHttpTransport from .lister_base import ListerBase +from requests import Response +from typing import Any, Dict, List, Tuple, Optional, Generic, TypeVar + logger = logging.getLogger(__name__) +T = TypeVar('T') -class 
IndexingLister(ListerBase): +class IndexingLister(ListerBase, Generic[T]): """Lister* intermediate class for any service that follows the pattern: - The service must report at least one stable unique identifier, known @@ -55,7 +59,9 @@ """ @abc.abstractmethod - def get_next_target_from_response(self, response): + def get_next_target_from_response( + self, response: Response + ) -> T: """Find the next server endpoint identifier given the entire response. Implementation of this method depends on the server API spec @@ -71,7 +77,8 @@ # You probably don't need to override anything below this line. - def filter_before_inject(self, models_list): + def filter_before_inject( + self, models_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: """Overrides ListerBase.filter_before_inject Bounds query results by this Lister's set max_index. @@ -100,7 +107,9 @@ retlist = retlist.filter(self.MODEL.indexable <= end) return retlist - def db_partition_indices(self, partition_size): + def db_partition_indices( + self, partition_size: int + ) -> List[Tuple[Optional[int], Optional[int]]]: """Describe an index-space compartmentalization of the db table in equal sized chunks. This is used to describe min&max bounds for parallelizing fetch tasks. @@ -165,6 +174,7 @@ t = self.db_session.query(func.min(self.MODEL.indexable)).first() if t: return t[0] + return None def db_last_index(self): """Look in the db for the largest indexable value @@ -175,6 +185,7 @@ t = self.db_session.query(func.max(self.MODEL.indexable)).first() if t: return t[0] + return None def disable_deleted_repo_tasks(self, start, end, keep_these): """Disable tasks for repos that no longer exist between start and end. 
@@ -254,6 +265,7 @@ class IndexingHttpLister(ListerHttpTransport, IndexingLister): """Convenience class for ensuring right lookup and init order when combining IndexingLister and ListerHttpTransport.""" + def __init__(self, url=None, override_config=None): IndexingLister.__init__(self, override_config=override_config) ListerHttpTransport.__init__(self, url=url) diff --git a/swh/lister/core/lister_base.py b/swh/lister/core/lister_base.py --- a/swh/lister/core/lister_base.py +++ b/swh/lister/core/lister_base.py @@ -2,6 +2,11 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from requests import Response +from .abstractattribute import AbstractAttribute +from swh.scheduler import get_scheduler, utils +from swh.core.utils import grouper +from swh.core import config import abc import datetime import gzip @@ -13,17 +18,13 @@ from sqlalchemy import create_engine, func from sqlalchemy.orm import sessionmaker -from typing import Any, Dict, List, Type, Union - -from swh.core import config -from swh.core.utils import grouper -from swh.scheduler import get_scheduler, utils - -from .abstractattribute import AbstractAttribute - +from typing import Any, Dict, List, Type, Union, Optional +from typing import Generic, TypeVar, Tuple logger = logging.getLogger(__name__) +T = TypeVar('T') + def utcnow(): return datetime.datetime.now(tz=datetime.timezone.utc) @@ -37,7 +38,7 @@ return repr(self.response) -class ListerBase(abc.ABC, config.SWHConfig): +class ListerBase(abc.ABC, config.SWHConfig, Generic[T]): """Lister core base class. Generally a source code hosting service provides an API endpoint for listing the set of stored repositories. 
A Lister is the discovery @@ -73,7 +74,7 @@ LISTER_NAME = AbstractAttribute( "Lister's name") # type: Union[AbstractAttribute, str] - def transport_request(self, identifier): + def transport_request(self, identifier: Union[str, int]): """Given a target endpoint identifier to query, try once to request it. Implementation of this method determines the network request protocol. @@ -93,7 +94,7 @@ """ pass - def transport_response_to_string(self, response): + def transport_response_to_string(self, response: Response): """Convert the server response into a formatted string for logging. Implementation of this method depends on the shape of the network @@ -106,7 +107,7 @@ """ pass - def transport_response_simplified(self, response): + def transport_response_simplified(self, response: Response): """Convert the server response into list of a dict for each repo in the response, mapping columns in the lister's MODEL class to repo data. @@ -122,7 +123,7 @@ """ pass - def transport_quota_check(self, response): + def transport_quota_check(self, response: Response): """Check server response to see if we're hitting request rate limits. Implementation of this method depends on the server communication @@ -137,7 +138,8 @@ """ pass - def filter_before_inject(self, models_list): + def filter_before_inject( + self, models_list: List[Dict]) -> List[Dict]: """Filter models_list entries prior to injection in the db. This is ran directly after `transport_response_simplified`. @@ -152,7 +154,8 @@ """ return models_list - def do_additional_checks(self, models_list): + def do_additional_checks( + self, models_list: List[Dict]) -> Union[List[Dict], bool]: """Execute some additional checks on the model list (after the filtering). 
@@ -169,7 +172,9 @@ """ return models_list - def is_within_bounds(self, inner, lower=None, upper=None): + def is_within_bounds( + self, inner: int, + lower: Optional[int] = None, upper: Optional[int] = None) -> bool: """See if a sortable value is inside the range [lower,upper]. MAY BE OVERRIDDEN, for example if the server indexable* key is @@ -188,7 +193,7 @@ if lower is None and upper is None: return True elif lower is None: - ret = inner <= upper + ret = inner <= upper # type: ignore elif upper is None: ret = inner >= lower else: @@ -238,7 +243,7 @@ MAX_RETRIES = 7 CONN_SLEEP = 10 - def __init__(self, override_config=None): + def __init__(self, override_config=None) -> None: self.backoff = self.INITIAL_BACKOFF logger.debug('Loading config from %s' % self.CONFIG_BASE_FILENAME) self.config = self.parse_config_file( @@ -258,17 +263,18 @@ self.mk_session = sessionmaker(bind=self.db_engine) self.db_session = self.mk_session() - def reset_backoff(self): + def reset_backoff(self) -> None: """Reset exponential backoff timeout to initial level.""" self.backoff = self.INITIAL_BACKOFF - def back_off(self): + def back_off(self) -> int: """Get next exponential backoff timeout.""" ret = self.backoff self.backoff *= 10 return ret - def safely_issue_request(self, identifier): + def safely_issue_request(self, identifier: Union[str, int] + ) -> Optional[Response]: """Make network request with retries, rate quotas, and response logs. 
Protocol is handled by the implementation of the transport_request @@ -315,7 +321,7 @@ return r - def db_query_equal(self, key, value): + def db_query_equal(self, key: Any, value: Any): """Look in the db for a row with key == value Args: @@ -396,7 +402,7 @@ kw = {'priority': priority} if priority else {} return utils.create_task_dict(_type, _policy, url=origin_url, **kw) - def string_pattern_check(self, a, b, c=None): + def string_pattern_check(self, a, b, c=None) -> None: """When comparing indexable types in is_within_bounds, complex strings may not be allowed to differ in basic structure. If they do, it could be a sign of not understanding the data well. For instance, @@ -418,8 +424,8 @@ a_pattern = re.sub('[a-zA-Z0-9]', '[a-zA-Z0-9]', re.escape(a)) - if (isinstance(b, str) and (re.match(a_pattern, b) is None) - or isinstance(c, str) and (re.match(a_pattern, c) is None)): + if (isinstance(b, str) and (re.match(a_pattern, b) is None) or + isinstance(c, str) and (re.match(a_pattern, c) is None)): logger.debug(a_pattern) raise TypeError('incomparable string patterns detected') @@ -456,7 +462,7 @@ """ tasks = {} - def _task_key(m): + def _task_key(m) -> str: return '%s-%s' % ( m['type'], json.dumps(m['arguments'], sort_keys=True) @@ -481,7 +487,8 @@ ir, m, _ = tasks[_task_key(task)] ir.task_id = task['id'] - def ingest_data(self, identifier, checks=False): + def ingest_data(self, identifier: Union[str, int], checks: bool = False + ) -> Tuple[Optional[Response], Union[Dict, List]]: """The core data fetch sequence. Request server endpoint. Simplify and filter response list of repositories. Inject repo information into local db. Queue loader tasks for linked repositories. @@ -506,7 +513,7 @@ self.schedule_missing_tasks(models_list, injected) return response, injected - def save_response(self, response): + def save_response(self, response: Response): """Log the response from a server request to a cache dir. 
Args: diff --git a/swh/lister/core/lister_transports.py b/swh/lister/core/lister_transports.py --- a/swh/lister/core/lister_transports.py +++ b/swh/lister/core/lister_transports.py @@ -12,7 +12,8 @@ import requests import xmltodict -from typing import Optional, Union +from typing import Optional, Union, Dict, Any, List +from requests import Response from swh.lister import USER_AGENT_TEMPLATE, __version__ @@ -39,7 +40,7 @@ EXPECTED_STATUS_CODES = (200, 429, 403, 404) - def request_headers(self): + def request_headers(self) -> Dict[str, Any]: """Returns dictionary of any request headers needed by the server. MAY BE OVERRIDDEN if request headers are needed. @@ -48,7 +49,7 @@ 'User-Agent': USER_AGENT_TEMPLATE % self.lister_version } - def request_instance_credentials(self): + def request_instance_credentials(self) -> List[Dict[str, Any]]: """Returns dictionary of any credentials configuration needed by the forge instance to list. @@ -81,23 +82,23 @@ list of credential dicts for the current lister. """ - all_creds = self.config.get('credentials') + all_creds = self.config.get('credentials') # type: ignore if not all_creds: return [] - lister_creds = all_creds.get(self.LISTER_NAME, {}) - creds = lister_creds.get(self.instance, []) + lister_creds = all_creds.get(self.LISTER_NAME, {}) # type: ignore + creds = lister_creds.get(self.instance, []) # type: ignore return creds - def request_uri(self, identifier): + def request_uri(self, identifier: Union[str, int]) -> str: """Get the full request URI given the transport_request identifier. MAY BE OVERRIDDEN if something more complex than the PATH_TEMPLATE is required. """ - path = self.PATH_TEMPLATE % identifier + path = self.PATH_TEMPLATE % identifier # type: ignore return self.url + path - def request_params(self, identifier): + def request_params(self, identifier: Union[str, int]) -> Dict[str, Any]: """Get the full parameters passed to requests given the transport_request identifier. @@ -108,14 +109,15 @@ is needed. 
""" - params = {} + params: Dict[str, Any] = {} params['headers'] = self.request_headers() or {} creds = self.request_instance_credentials() if not creds: return params auth = random.choice(creds) if creds else None if auth: - params['auth'] = (auth['username'], auth['password']) + params['auth'] = (auth['username'], + auth['password']) return params def transport_quota_check(self, response): @@ -152,7 +154,9 @@ self.session = requests.Session() self.lister_version = __version__ - def _transport_action(self, identifier, method='get'): + def _transport_action( + self, identifier: Union[str, int], + method: str = 'get') -> Response: """Permit to ask information to the api prior to actually executing query. @@ -176,13 +180,13 @@ raise FetchError(response) return response - def transport_head(self, identifier): + def transport_head(self, identifier: str) -> Response: """Retrieve head information on api. """ return self._transport_action(identifier, method='head') - def transport_request(self, identifier): + def transport_request(self, identifier: Union[str, int]) -> Response: """Implements ListerBase.transport_request for HTTP using Requests. Retrieve get information on api. @@ -190,7 +194,7 @@ """ return self._transport_action(identifier) - def transport_response_to_string(self, response): + def transport_response_to_string(self, response: Response) -> str: """Implements ListerBase.transport_response_to_string for HTTP given Requests responses. 
""" diff --git a/swh/lister/core/page_by_page_lister.py b/swh/lister/core/page_by_page_lister.py --- a/swh/lister/core/page_by_page_lister.py +++ b/swh/lister/core/page_by_page_lister.py @@ -7,6 +7,8 @@ from .lister_transports import ListerHttpTransport from .lister_base import ListerBase +from requests import Response +from typing import Optional, Tuple, List, Dict, Any, Union class PageByPageLister(ListerBase): @@ -38,7 +40,8 @@ """ @abc.abstractmethod - def get_next_target_from_response(self, response): + def get_next_target_from_response(self, response: Response + ) -> Optional[int]: """Find the next server endpoint page given the entire response. Implementation of this method depends on the server API spec @@ -58,7 +61,8 @@ pass @abc.abstractmethod - def get_pages_information(self): + def get_pages_information( + self) -> Tuple[Optional[int], Optional[int], Optional[int]]: """Find the total number of pages. Implementation of this method depends on the server API spec @@ -79,7 +83,9 @@ # You probably don't need to override anything below this line. - def do_additional_checks(self, models_list): + def do_additional_checks(self, + models_list: List[Dict[str, Any]] + ) -> Union[List[Dict[str, Any]], bool]: """Potentially check for existence of repositories in models_list. This will be called only if check_existence is flipped on in @@ -92,7 +98,9 @@ return False return models_list - def run(self, min_bound=None, max_bound=None, check_existence=False): + def run(self, min_bound: Optional[int] = None, + max_bound: Optional[int] = None, + check_existence: bool = False): """Main entry function. Sequentially fetches repository data from the service according to the basic outline in the class docstring. 
Continually fetching sublists until either there @@ -130,7 +138,8 @@ break status = 'eventful' - next_page = self.get_next_target_from_response(response) + next_page = self.get_next_target_from_response( + response) # type: ignore # termination condition @@ -159,6 +168,8 @@ combining PageByPageLister and ListerHttpTransport. """ - def __init__(self, url=None, override_config=None): + + def __init__(self, url: Optional[str] = None, + override_config: Optional[bool] = None) -> None: PageByPageLister.__init__(self, override_config=override_config) ListerHttpTransport.__init__(self, url=url) diff --git a/swh/lister/core/simple_lister.py b/swh/lister/core/simple_lister.py --- a/swh/lister/core/simple_lister.py +++ b/swh/lister/core/simple_lister.py @@ -9,6 +9,7 @@ from swh.core import utils from .lister_base import ListerBase +from requests import Response logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ db (see fn:`ingest_data`). """ - def list_packages(self, response: Any) -> List[Any]: + def list_packages(self, response: Response) -> List[Any]: """Listing packages method. """ diff --git a/swh/lister/cran/lister.py b/swh/lister/cran/lister.py --- a/swh/lister/cran/lister.py +++ b/swh/lister/cran/lister.py @@ -8,13 +8,15 @@ import pkg_resources import subprocess -from typing import List, Mapping, Tuple +from typing import List, Mapping, Tuple, Any, Dict, Optional, Union from swh.lister.cran.models import CRANModel from swh.lister.core.simple_lister import SimpleLister from swh.scheduler.utils import create_task_dict +from requests import Response + logger = logging.getLogger(__name__) @@ -27,8 +29,10 @@ LISTER_NAME = 'cran' instance = 'cran' - def task_dict(self, origin_type, origin_url, version=None, html_url=None, - policy=None, **kwargs): + def task_dict(self, origin_type: str, origin_url: str, + version: Optional[str] = None, + html_url: Optional[str] = None, policy: Optional[str] = None, + **kwargs) -> Dict[str, Any]: """Return task format dict. 
This creates tasks with args and kwargs set, for example:: @@ -54,7 +58,7 @@ }], retries_left=3 ) - def safely_issue_request(self, identifier): + def safely_issue_request(self, identifier: Union[str, int]) -> None: """Bypass the implementation. It's now the `list_packages` which returns data. @@ -65,7 +69,7 @@ """ return None - def list_packages(self, response) -> List[Mapping[str, str]]: + def list_packages(self, response: Response) -> List[Mapping[str, str]]: """Runs R script which uses inbuilt API to return a json response containing data about the R packages. diff --git a/swh/lister/debian/lister.py b/swh/lister/debian/lister.py --- a/swh/lister/debian/lister.py +++ b/swh/lister/debian/lister.py @@ -13,7 +13,8 @@ from debian.deb822 import Sources from sqlalchemy.orm import joinedload, load_only from sqlalchemy.schema import CreateTable, DropTable -from typing import Mapping, Optional +from typing import Mapping, Optional, Dict, Any, Union +from requests import Response from swh.lister.debian.models import ( AreaSnapshot, Distribution, DistributionSnapshot, Package, @@ -58,7 +59,7 @@ self.date = override_config.get('date', date) or datetime.datetime.now( tz=datetime.timezone.utc) - def transport_request(self, identifier): + def transport_request(self, identifier: Union[str, int]) -> Response: """Subvert ListerHttpTransport.transport_request, to try several index URIs in turn. @@ -88,13 +89,13 @@ self.decompressor = decompressors.get(compression) return response - def request_uri(self, identifier): + def request_uri(self, identifier: Union[str, int]): # In the overridden transport_request, we pass # ListerBase.transport_request() the full URI as identifier, so we # need to return it here. return identifier - def request_params(self, identifier): + def request_params(self, identifier: Union[str, int]) -> Dict[str, Any]: # Enable streaming to allow wrapping the response in the decompressor # in transport_response_simplified. 
params = super().request_params(identifier) @@ -108,13 +109,14 @@ Checksums-Sha1, Checksums-Sha256), to return a files dict mapping filenames to their checksums. """ + data = response.raw if self.decompressor: - data = self.decompressor(response.raw) + data = self.decompressor(data) else: data = response.raw for src_pkg in Sources.iter_paragraphs(data.readlines()): - files = defaultdict(dict) + files: Dict[str, Any] = defaultdict(dict) for field in src_pkg._multivalued_fields: if field.startswith('checksums-'): diff --git a/swh/lister/github/lister.py b/swh/lister/github/lister.py --- a/swh/lister/github/lister.py +++ b/swh/lister/github/lister.py @@ -5,11 +5,13 @@ import re -from typing import Any +from typing import Any, Dict, List, Tuple, Optional from swh.lister.core.indexing_lister import IndexingHttpLister from swh.lister.github.models import GitHubModel +from requests import Response + class GitHubLister(IndexingHttpLister): PATH_TEMPLATE = '/repositories?since=%d' @@ -20,7 +22,7 @@ instance = 'github' # There is only 1 instance of such lister default_min_bound = 0 # type: Any - def get_model_from_repo(self, repo): + def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]: return { 'uid': repo['id'], 'indexable': repo['id'], @@ -32,7 +34,7 @@ 'fork': repo['fork'], } - def transport_quota_check(self, response): + def transport_quota_check(self, response: Response) -> Tuple[bool, int]: x_rate_limit_remaining = response.headers.get('X-RateLimit-Remaining') if not x_rate_limit_remaining: return False, 0 @@ -42,17 +44,21 @@ return True, delay return False, 0 - def get_next_target_from_response(self, response): + def get_next_target_from_response(self, + response: Response) -> Optional[int]: if 'next' in response.links: next_url = response.links['next']['url'] - return int(self.API_URL_INDEX_RE.match(next_url).group(1)) + return int( + self.API_URL_INDEX_RE.match(next_url).group(1)) # type: ignore + return None - def 
transport_response_simplified(self, response): + def transport_response_simplified(self, response: Response + ) -> List[Dict[str, Any]]: repos = response.json() return [self.get_model_from_repo(repo) for repo in repos if repo and 'id' in repo] - def request_headers(self): + def request_headers(self) -> Dict[str, Any]: """(Override) Set requests headers to send when querying the GitHub API """ @@ -60,7 +66,8 @@ headers['Accept'] = 'application/vnd.github.v3+json' return headers - def disable_deleted_repo_tasks(self, index, next_index, keep_these): + def disable_deleted_repo_tasks(self, index: int, + next_index: int, keep_these: int): """ (Overrides) Fix provided index value to avoid erroneously disabling some scheduler tasks """ diff --git a/swh/lister/gitlab/lister.py b/swh/lister/gitlab/lister.py --- a/swh/lister/gitlab/lister.py +++ b/swh/lister/gitlab/lister.py @@ -9,6 +9,9 @@ from ..core.page_by_page_lister import PageByPageHttpLister from .models import GitLabModel +from typing import Any, Dict, List, Tuple, Union, MutableMapping, Optional +from requests import Response + class GitLabLister(PageByPageHttpLister): # Template path expecting an integer that represents the page id @@ -17,8 +20,10 @@ MODEL = GitLabModel LISTER_NAME = 'gitlab' - def __init__(self, url=None, instance=None, - override_config=None, sort='asc', per_page=20): + def __init__(self, url: Optional[str] = None, + instance: Optional[str] = None, + override_config: Optional[bool] = None, + sort: str = 'asc', per_page: int = 20) -> None: super().__init__(url=url, override_config=override_config) if instance is None: instance = parse_url(self.url).host @@ -26,10 +31,10 @@ self.PATH_TEMPLATE = '%s&sort=%s&per_page=%s' % ( self.PATH_TEMPLATE, sort, per_page) - def uid(self, repo): + def uid(self, repo: Dict[str, Any]) -> str: return '%s/%s' % (self.instance, repo['path_with_namespace']) - def get_model_from_repo(self, repo): + def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]: 
return { 'instance': self.instance, 'uid': self.uid(repo), @@ -40,7 +45,8 @@ 'origin_type': 'git', } - def transport_quota_check(self, response): + def transport_quota_check(self, response: Response + ) -> Tuple[bool, Union[int, float]]: """Deal with rate limit if any. """ @@ -53,22 +59,26 @@ return True, delay return False, 0 - def _get_int(self, headers, key): + def _get_int(self, headers: MutableMapping[str, Any], + key: str) -> Optional[int]: _val = headers.get(key) if _val: return int(_val) + return None - def get_next_target_from_response(self, response): + def get_next_target_from_response( + self, response: Response) -> Optional[int]: """Determine the next page identifier. """ return self._get_int(response.headers, 'x-next-page') - def get_pages_information(self): + def get_pages_information(self) -> Tuple[Optional[int], + Optional[int], Optional[int]]: """Determine pages information. """ - response = self.transport_head(identifier=1) + response = self.transport_head(identifier=1) # type: ignore if not response.ok: raise ValueError( 'Problem during information fetch: %s' % response.status_code) @@ -77,6 +87,7 @@ self._get_int(h, 'x-total-pages'), self._get_int(h, 'x-per-page')) - def transport_response_simplified(self, response): + def transport_response_simplified(self, response: Response + ) -> List[Dict[str, Any]]: repos = response.json() return [self.get_model_from_repo(repo) for repo in repos] diff --git a/swh/lister/gnu/lister.py b/swh/lister/gnu/lister.py --- a/swh/lister/gnu/lister.py +++ b/swh/lister/gnu/lister.py @@ -10,6 +10,8 @@ from swh.lister.gnu.models import GNUModel from swh.lister.gnu.tree import GNUTree +from typing import Any, Dict, List, Mapping, Union +from requests import Response logger = logging.getLogger(__name__) @@ -19,11 +21,12 @@ LISTER_NAME = 'gnu' instance = 'gnu' - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs: Mapping[str, str]) -> None: super().__init__(*args, **kwargs) self.gnu_tree = 
GNUTree('https://ftp.gnu.org/tree.json.gz') - def task_dict(self, origin_type, origin_url, **kwargs): + def task_dict(self, origin_type: str, + origin_url: str, **kwargs) -> Dict[str, Any]: """Return task format dict This is overridden from the lister_base as more information is @@ -58,7 +61,7 @@ retries_left=3, ) - def safely_issue_request(self, identifier): + def safely_issue_request(self, identifier: Union[str, int]) -> None: """Bypass the implementation. It's now the GNUTree which deals with querying the gnu mirror. @@ -69,7 +72,7 @@ """ return None - def list_packages(self, response): + def list_packages(self, response: Response) -> List[Dict[str, Any]]: """List the actual gnu origins (package name) with their name, url and associated tarballs. @@ -96,7 +99,7 @@ """ return list(self.gnu_tree.projects.values()) - def get_model_from_repo(self, repo): + def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]: """Transform from repository representation to model """ diff --git a/swh/lister/npm/lister.py b/swh/lister/npm/lister.py --- a/swh/lister/npm/lister.py +++ b/swh/lister/npm/lister.py @@ -6,6 +6,9 @@ from swh.lister.npm.models import NpmModel from swh.scheduler.utils import create_task_dict +from typing import Dict, Optional, List, Any, Mapping +from requests import Response + class NpmListerBase(IndexingHttpLister): """List packages available in the npm registry in a paginated way @@ -15,14 +18,16 @@ LISTER_NAME = 'npm' instance = 'npm' - def __init__(self, url='https://replicate.npmjs.com', - per_page=1000, override_config=None): + def __init__(self, url: str = 'https://replicate.npmjs.com', + per_page: int = 1000, + override_config: Optional[str] = None) -> None: super().__init__(url=url, override_config=override_config) self.per_page = per_page + 1 - self.PATH_TEMPLATE += '&limit=%s' % self.per_page + self.PATH_TEMPLATE: str = self.PATH_TEMPLATE + '&limit=%s' % ( + self.per_page) @property - def ADDITIONAL_CONFIG(self): + def 
ADDITIONAL_CONFIG(self) -> Dict[str, str]: """(Override) Add extra configuration """ @@ -30,7 +35,7 @@ default_config['loading_task_policy'] = ('str', 'recurring') return default_config - def get_model_from_repo(self, repo_name): + def get_model_from_repo(self, repo_name: str) -> Dict[str, str]: """(Override) Transform from npm package name to model """ @@ -45,7 +50,9 @@ 'origin_type': 'npm', } - def task_dict(self, origin_type, origin_url, **kwargs): + def task_dict(self, origin_type: str, + origin_url: str, + **kwargs: Mapping[str, str]) -> Dict[str, Any]: """(Override) Return task dict for loading a npm package into the archive. @@ -58,7 +65,7 @@ return create_task_dict(task_type, task_policy, url=origin_url) - def request_headers(self): + def request_headers(self) -> Dict[str, str]: """(Override) Set requests headers to send when querying the npm registry. @@ -67,7 +74,7 @@ headers['Accept'] = 'application/json' return headers - def string_pattern_check(self, inner, lower, upper=None): + def string_pattern_check(self, inner: int, lower: int, upper: int = None): """ (Override) Inhibit the effect of that method as packages indices correspond to package names and thus do not respect any kind of fixed length string pattern @@ -82,14 +89,16 @@ """ PATH_TEMPLATE = '/_all_docs?startkey="%s"' - def get_next_target_from_response(self, response): + def get_next_target_from_response( + self, response: Response) -> Optional[str]: """(Override) Get next npm package name to continue the listing """ repos = response.json()['rows'] return repos[-1]['id'] if len(repos) == self.per_page else None - def transport_response_simplified(self, response): + def transport_response_simplified( + self, response: Response) -> List[Dict[str, str]]: """(Override) Transform npm registry response to list for model manipulation """ @@ -110,14 +119,16 @@ def CONFIG_BASE_FILENAME(self): # noqa: N802 return 'lister_npm_incremental' - def get_next_target_from_response(self, response): + def 
get_next_target_from_response( + self, response: Response) -> Optional[str]: """(Override) Get next npm package name to continue the listing. """ repos = response.json()['results'] return repos[-1]['seq'] if len(repos) == self.per_page else None - def transport_response_simplified(self, response): + def transport_response_simplified( + self, response: Response) -> List[Dict[str, str]]: """(Override) Transform npm registry response to list for model manipulation. @@ -127,7 +138,8 @@ repos = repos[:-1] return [self.get_model_from_repo(repo['id']) for repo in repos] - def filter_before_inject(self, models_list): + def filter_before_inject( + self, models_list: List[Dict[str, str]]) -> List[Dict[str, str]]: """(Override) Filter out documents in the CouchDB database not related to a npm package. diff --git a/swh/lister/packagist/lister.py b/swh/lister/packagist/lister.py --- a/swh/lister/packagist/lister.py +++ b/swh/lister/packagist/lister.py @@ -7,7 +7,7 @@ import logging import random -from typing import Any, Dict, List, Mapping +from typing import Any, Dict, List, Mapping, Optional from swh.scheduler import utils from swh.lister.core.simple_lister import SimpleLister @@ -57,7 +57,7 @@ PAGE = 'https://packagist.org/packages/list.json' instance = 'packagist' - def __init__(self, override_config=None): + def __init__(self, override_config: Optional[bool] = None) -> None: ListerOnePageApiTransport .__init__(self) SimpleLister.__init__(self, override_config=override_config) diff --git a/swh/lister/phabricator/lister.py b/swh/lister/phabricator/lister.py --- a/swh/lister/phabricator/lister.py +++ b/swh/lister/phabricator/lister.py @@ -14,6 +14,8 @@ from swh.lister.core.indexing_lister import IndexingHttpLister from swh.lister.phabricator.models import PhabricatorModel +from typing import Any, Dict, List, Optional, Union +from requests import Response logger = logging.getLogger(__name__) @@ -25,13 +27,15 @@ MODEL = PhabricatorModel LISTER_NAME = 'phabricator' - def 
__init__(self, url=None, instance=None, override_config=None): + def __init__(self, url: Optional[str] = None, + instance: Optional[str] = None, + override_config: Optional[bool] = None) -> None: super().__init__(url=url, override_config=override_config) if not instance: instance = urllib.parse.urlparse(self.url).hostname self.instance = instance - def request_params(self, identifier): + def request_params(self, identifier: Union[str, int]) -> Dict[str, Any]: """Override the default params behavior to retrieve the api token Credentials are stored as: @@ -52,7 +56,7 @@ return {'headers': self.request_headers() or {}, 'params': {'api.token': api_token}} - def request_headers(self): + def request_headers(self) -> Dict[str, Any]: """ (Override) Set requests headers to send when querying the Phabricator API @@ -61,7 +65,8 @@ headers['Accept'] = 'application/json' return headers - def get_model_from_repo(self, repo): + def get_model_from_repo( + self, repo: Dict[str, Any]) -> Optional[Dict[str, Any]]: url = get_repo_url(repo['attachments']['uris']['uris']) if url is None: return None @@ -76,12 +81,15 @@ 'instance': self.instance, } - def get_next_target_from_response(self, response): + def get_next_target_from_response( + self, response: Response) -> Optional[int]: body = response.json()['result']['cursor'] if body['after'] and body['after'] != 'null': return int(body['after']) + return None - def transport_response_simplified(self, response): + def transport_response_simplified( + self, response: Response) -> List[Optional[Dict[str, Any]]]: repos = response.json() if repos['result'] is None: raise ValueError( @@ -89,7 +97,8 @@ repos = repos['result']['data'] return [self.get_model_from_repo(repo) for repo in repos] - def filter_before_inject(self, models_list): + def filter_before_inject( + self, models_list: List[Dict[str, str]]) -> List[Dict[str, str]]: """ (Overrides) IndexingLister.filter_before_inject Bounds query results by this Lister's set max_index. 
@@ -97,7 +106,8 @@ models_list = [m for m in models_list if m is not None] return super().filter_before_inject(models_list) - def disable_deleted_repo_tasks(self, index, next_index, keep_these): + def disable_deleted_repo_tasks( + self, index: int, next_index: int, keep_these: str): """ (Overrides) Fix provided index value to avoid: @@ -117,7 +127,7 @@ return super().disable_deleted_repo_tasks(index, next_index, keep_these) - def db_first_index(self): + def db_first_index(self) -> Optional[int]: """ (Overrides) Filter results by Phabricator instance @@ -128,8 +138,9 @@ t = t.filter(self.MODEL.instance == self.instance).first() if t: return t[0] + return None - def db_last_index(self): + def db_last_index(self) -> Optional[int]: """ (Overrides) Filter results by Phabricator instance @@ -140,8 +151,9 @@ t = t.filter(self.MODEL.instance == self.instance).first() if t: return t[0] + return None - def db_query_range(self, start, end): + def db_query_range(self, start: int, end: int) -> List: """ (Overrides) Filter the results by the Phabricator instance to avoid disabling loading tasks for repositories hosted on a @@ -155,14 +167,14 @@ return retlist.filter(self.MODEL.instance == self.instance) -def get_repo_url(attachments): +def get_repo_url(attachments: List[Dict[str, Any]]) -> Optional[str]: """ Return url for a hosted repository from its uris attachments according to the following priority lists: * protocol: https > http * identifier: shortname > callsign > id """ - processed_urls = defaultdict(dict) + processed_urls = defaultdict(dict) # type: Dict[str, Any] for uri in attachments: protocol = uri['fields']['builtin']['protocol'] url = uri['fields']['uri']['effective'] diff --git a/swh/lister/pypi/lister.py b/swh/lister/pypi/lister.py --- a/swh/lister/pypi/lister.py +++ b/swh/lister/pypi/lister.py @@ -12,6 +12,9 @@ from swh.lister.core.simple_lister import SimpleLister from swh.lister.core.lister_transports import ListerOnePageApiTransport +from typing import Any,
Dict, List, Optional, Mapping +from requests import Response + class PyPILister(ListerOnePageApiTransport, SimpleLister): MODEL = PyPIModel @@ -19,11 +22,13 @@ PAGE = 'https://pypi.org/simple/' instance = 'pypi' # As of today only the main pypi.org is used - def __init__(self, override_config=None): + def __init__(self, override_config: Optional[str] = None) -> None: ListerOnePageApiTransport .__init__(self) SimpleLister.__init__(self, override_config=override_config) - def task_dict(self, origin_type, origin_url, **kwargs): + def task_dict(self, origin_type: str, + origin_url: str, + **kwargs: Mapping[str, str]) -> Dict[str, Any]: """(Override) Return task format dict This is overridden from the lister_base as more information is @@ -35,7 +40,7 @@ return utils.create_task_dict( _type, _policy, url=origin_url) - def list_packages(self, response): + def list_packages(self, response: Response) -> List[Dict]: """(Override) List the actual pypi origins from the response. """ @@ -50,7 +55,7 @@ """ return 'https://pypi.org/project/%s/' % repo_name - def get_model_from_repo(self, repo_name): + def get_model_from_repo(self, repo_name: str) -> Dict[str, Any]: """(Override) Transform from repository representation to model """