diff --git a/swh/scanner/cli.py b/swh/scanner/cli.py
--- a/swh/scanner/cli.py
+++ b/swh/scanner/cli.py
@@ -155,8 +155,23 @@
     type=click.Choice(["origin"]),
     help="Add selected additional information about known software artifacts.",
 )
+@click.option(
+    "--debug-http",
+    is_flag=True,
+    help="Add debug information about the performed HTTP requests.",
+)
 @click.pass_context
-def scan(ctx, root_path, api_url, patterns, out_fmt, interactive, policy, extra_info):
+def scan(
+    ctx,
+    root_path,
+    api_url,
+    patterns,
+    out_fmt,
+    interactive,
+    policy,
+    extra_info,
+    debug_http,
+):
     """Scan a source code project to discover files and directories already
     present in the archive.
 
@@ -202,7 +217,16 @@
     config = setup_config(ctx, api_url)
     extra_info = set(extra_info)
 
-    scanner.scan(config, root_path, patterns, out_fmt, interactive, policy, extra_info)
+    scanner.scan(
+        config,
+        root_path,
+        patterns,
+        out_fmt,
+        interactive,
+        policy,
+        extra_info,
+        debug_http=debug_http,
+    )
 
 
 @scanner.group("db", help="Manage local knowledge base for swh-scanner")
diff --git a/swh/scanner/client.py b/swh/scanner/client.py
--- a/swh/scanner/client.py
+++ b/swh/scanner/client.py
@@ -12,7 +12,9 @@
 import asyncio
 import itertools
-from typing import Any, Dict, List, Optional
+import sys
+import time
+from typing import Any, Dict, List, Optional, Tuple
 
 import aiohttp
 
@@ -23,17 +25,46 @@
 # Maximum number of SWHIDs that can be requested by a single call to the
 # Web API endpoint /known/
 QUERY_LIMIT = 1000
+MAX_RETRY = 10
 
 KNOWN_EP = "known/"
 GRAPH_RANDOMWALK_EP = "graph/randomwalk/"
 
 
+def _get_chunk(swhids):
+    """Slice a list of `swhids` into smaller lists of at most QUERY_LIMIT items"""
+    for i in range(0, len(swhids), QUERY_LIMIT):
+        yield swhids[i : i + QUERY_LIMIT]
+
+
+def _parse_limit_header(response) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+    """Parse the X-RateLimit-* headers if present"""
+    limit = response.headers.get("X-RateLimit-Limit")
+    if limit is not None:
+        limit = int(limit)
+    remaining = response.headers.get("X-RateLimit-Remaining")
+    if remaining is not None:
+        remaining = int(remaining)
+    reset = response.headers.get("X-RateLimit-Reset")
+    if reset is not None:
+        reset = int(reset)
+    return (limit, remaining, reset)
+
+
 class Client:
     """Manage requests to the Software Heritage Web API."""
 
-    def __init__(self, api_url: str, session: aiohttp.ClientSession):
+    def __init__(
+        self,
+        api_url: str,
+        session: aiohttp.ClientSession,
+        debug=False,
+    ):
+        self._sleep = 0
+        self._debug = debug
         self.api_url = api_url
         self.session = session
+        self._known_endpoint = self.api_url + KNOWN_EP
 
     async def get_origin(self, swhid: CoreSWHID) -> Optional[Any]:
         """Walk the compressed graph to discover the origin of a given swhid"""
@@ -70,27 +101,119 @@
         value['known'] = False if the SWHID is not found
 
         """
-        endpoint = self.api_url + KNOWN_EP
         requests = []
 
-        def get_chunk(swhids):
-            for i in range(0, len(swhids), QUERY_LIMIT):
-                yield swhids[i : i + QUERY_LIMIT]
-
-        async def make_request(swhids):
-            swhids = [str(swhid) for swhid in swhids]
-            async with self.session.post(endpoint, json=swhids) as resp:
-                if resp.status != 200:
-                    error_response(resp.reason, resp.status, endpoint)
+        swh_ids = [str(swhid) for swhid in swhids]
 
-                return await resp.json()
-
-        if len(swhids) > QUERY_LIMIT:
-            for swhids_chunk in get_chunk(swhids):
-                requests.append(asyncio.create_task(make_request(swhids_chunk)))
+        if len(swhids) <= QUERY_LIMIT:
+            return await self._make_request(swh_ids)
+        else:
+            for swhids_chunk in _get_chunk(swh_ids):
+                task = asyncio.create_task(self._make_request(swhids_chunk))
+                requests.append(task)
             res = await asyncio.gather(*requests)
             # concatenate list of dictionaries
             return dict(itertools.chain.from_iterable(e.items() for e in res))
+
+    def _mark_success(self, limit=None, remaining=None, reset=None):
+        """Called when a request succeeds; adjust the request rate accordingly.
+
+        The extra arguments can be used to pass the X-RateLimit information
+        returned by the server, which is used to adjust the request rate."""
+        self._sleep = 0
+        factor = 0
+        current = time.time()
+        if self._debug:
+            dbg_msg = f"HTTP GOOD {current:.2f}:"
+        if limit is None or remaining is None or reset is None:
+            if self._debug:
+                dbg_msg += " no rate limit data;"
         else:
-            return await make_request(swhids)
+            time_windows = reset - current
+            if self._debug:
+                dbg_msg += f" remaining-request={remaining}/{limit}"
+                dbg_msg += f" reset-in={time_windows:.2f}"
+            if time_windows <= 0:
+                return
+            credit_left = remaining / limit
+            if remaining <= 0:
+                # no more credit, we have to wait for the reset.
+                #
+                # XXX we should warn the user. This can get very long.
+                self._sleep = time_windows
+                factor = -1
+            elif 0.6 < credit_left:
+                # more than 60% of the credit is left, do not throttle yet.
+                factor = 0
+            else:
+                # the more of the credit we have used up, the harder we
+                # brake on our own request rate
+                #
+                # (the factor ranges from 1 to about 4)
+                factor = (0.4 + credit_left) ** -1.5
+            if factor >= 0:
+                self._sleep = (time_windows / remaining) * factor
+        if self._debug:
+            dbg_msg += f"; current-sleep={self._sleep}"
+            print(dbg_msg, file=sys.stderr)
+
+    def _mark_failure(self, limit=None, remaining=None, reset=None):
+        """Called when a request fails; reduce the request rate.
+
+        The extra arguments can be used to pass the X-RateLimit information
+        returned by the server, which is used to adjust the request rate."""
+        current = time.time()
+        if self._debug:
+            dbg_msg = f"HTTP BAD {current:.2f}:"
+        time_set = False
+        if remaining is None or reset is None:
+            if self._debug:
+                dbg_msg += " no rate limit data"
+        else:
+            wait_for = reset - current
+            if self._debug:
+                dbg_msg += f" remaining-request={remaining}/{limit}"
+                dbg_msg += f" reset-in={wait_for:.2f}"
+            if remaining <= 0:
+                # add some margin to stay on the safe side of the rate limit
+                wait_for *= 1.1
+            if wait_for > 0 and wait_for >= self._sleep:
+                self._sleep = wait_for
+                time_set = True
+        if not time_set:
+            if self._sleep <= 0:
+                self._sleep = 1
+            else:
+                self._sleep *= 2
+        if self._debug:
+            dbg_msg += f"; current-sleep={self._sleep}"
+            print(dbg_msg, file=sys.stderr)
+
+    async def _make_request(self, swhids):
+        endpoint = self._known_endpoint
+
+        data = None
+
+        retry = MAX_RETRY
+
+        while data is None:
+            # slow the pace of requests if needed
+            if self._sleep > 0:
+                time.sleep(self._sleep)
+            async with self.session.post(endpoint, json=swhids) as resp:
+                rate_limit = _parse_limit_header(resp)
+                if resp.status == 200:
+                    try:
+                        # inform of success before the await
+                        self._mark_success(*rate_limit)
+                        data = await resp.json()
+                    except aiohttp.client_exceptions.ClientConnectionError:
+                        raise
+                    else:
+                        break
+                self._mark_failure(*rate_limit)
+                retry -= 1
+                if retry <= 0 or resp.status == 413:  # 413: Payload Too Large
+                    error_response(resp.reason, resp.status, endpoint)
+        return data
diff --git a/swh/scanner/scanner.py b/swh/scanner/scanner.py
--- a/swh/scanner/scanner.py
+++ b/swh/scanner/scanner.py
@@ -31,6 +31,7 @@
     source_tree: Directory,
     nodes_data: MerkleNodeInfo,
    extra_info: set,
+    debug_http=False,
 ) -> None:
     """Scan a given source code according to the policy given in input."""
     api_url = config["web-api"]["url"]
@@ -41,7 +42,7 @@
         headers = {}
 
     async with aiohttp.ClientSession(headers=headers, trust_env=True) as session:
-        client = Client(api_url, session)
+        client = Client(api_url, session, debug=debug_http)
         for info in extra_info:
             if info == "known":
                 await policy.run(client)
@@ -78,6 +79,7 @@
     interactive: bool,
     policy: str,
     extra_info: set,
+    debug_http=False,
 ):
     """Scan a source code project to discover files and directories already
     present in the archive"""
@@ -91,7 +93,16 @@
     policy = get_policy_obj(source_tree, nodes_data, policy)
 
     loop = asyncio.get_event_loop()
-    loop.run_until_complete(run(config, policy, source_tree, nodes_data, extra_info))
+    loop.run_until_complete(
+        run(
+            config,
+            policy,
+            source_tree,
+            nodes_data,
+            extra_info,
+            debug_http=debug_http,
+        )
+    )
 
     out = Output(root_path, nodes_data, source_tree)
     if interactive:
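
The pacing heuristic introduced in `_mark_success` can be summarised with a small standalone sketch. It is an illustration only, not part of the patch: the helper name `pacing_delay` is made up, but the arithmetic mirrors the hunk above. No throttling is applied while more than 60% of the rate-limit credit is left; below that, requests are spread over the remaining window with a braking factor that grows from 1 to roughly 4; once the credit is exhausted the client simply waits until the reset.

import time


def pacing_delay(limit, remaining, reset, now=None):
    """Reproduce the sleep that Client._mark_success derives from X-RateLimit data."""
    now = time.time() if now is None else now
    time_window = reset - now
    if time_window <= 0:
        return 0.0
    if remaining <= 0:
        # credit exhausted: wait until the rate-limit window resets
        return float(time_window)
    credit_left = remaining / limit
    if credit_left > 0.6:
        # more than 60% of the credit is left: no throttling yet
        return 0.0
    # below 60%, space requests over the remaining window and brake harder
    # as the credit shrinks (the factor grows from 1 to about 4)
    factor = (0.4 + credit_left) ** -1.5
    return (time_window / remaining) * factor


if __name__ == "__main__":
    now = 0.0
    for remaining in (1200, 700, 500, 200, 50, 0):
        delay = pacing_delay(limit=1200, remaining=remaining, reset=3600, now=now)
        print(f"remaining={remaining:4d} -> sleep {delay:7.2f}s between requests")

Running it with a 1200-request budget over a one-hour window shows the delay climbing from 0 seconds to the full 3600-second wait as the credit drains.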
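The failure path works the other way around: when a request fails, `_mark_failure` either waits for the announced reset (with a 10% margin once the credit is exhausted) or falls back to doubling the previous sleep, and `_make_request` gives up after MAX_RETRY attempts, or immediately on HTTP 413, by calling `error_response`. A rough sketch, again purely illustrative and with a made-up helper name:

MAX_RETRY = 10  # same bound as in the patch


def next_failure_sleep(current_sleep, remaining=None, reset=None, now=0.0):
    """Mirror how Client._mark_failure grows the sleep after a failed request."""
    if remaining is not None and reset is not None:
        wait_for = reset - now
        if remaining <= 0:
            # add a 10% margin so we land safely after the reset
            wait_for *= 1.1
        if wait_for > 0 and wait_for >= current_sleep:
            return wait_for
    # no usable rate-limit data (or it would not increase the delay):
    # plain exponential backoff starting at one second
    return 1 if current_sleep <= 0 else current_sleep * 2


if __name__ == "__main__":
    sleep = 0
    for attempt in range(1, MAX_RETRY + 1):
        sleep = next_failure_sleep(sleep)
        print(f"failed attempt {attempt:2d}: wait {sleep}s before retrying")

Without rate-limit headers the sleep grows 1, 2, 4, ... seconds over the ten attempts allowed by MAX_RETRY.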