diff --git a/swh/lister/github/lister.py b/swh/lister/github/lister.py index f0ace80..8947595 100644 --- a/swh/lister/github/lister.py +++ b/swh/lister/github/lister.py @@ -1,363 +1,205 @@ -# Copyright (C) 2020 The Software Heritage developers +# Copyright (C) 2020-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import asdict, dataclass import datetime import logging -import random -import time from typing import Any, Dict, Iterator, List, Optional, Set from urllib.parse import parse_qs, urlparse import iso8601 -import requests -from tenacity import ( - retry, - retry_any, - retry_if_exception_type, - retry_if_result, - wait_exponential, -) from swh.scheduler.interface import SchedulerInterface from swh.scheduler.model import ListedOrigin -from .. import USER_AGENT from ..pattern import CredentialsType, Lister +from .utils import GitHubSession, MissingRateLimitReset logger = logging.getLogger(__name__) -class RateLimited(Exception): - def __init__(self, response): - self.reset_time: Optional[int] - - # Figure out how long we need to sleep because of that rate limit - ratelimit_reset = response.headers.get("X-Ratelimit-Reset") - retry_after = response.headers.get("Retry-After") - if ratelimit_reset is not None: - self.reset_time = int(ratelimit_reset) - elif retry_after is not None: - self.reset_time = int(time.time()) + int(retry_after) + 1 - else: - logger.warning( - "Received a rate-limit-like status code %s, but no rate-limit " - "headers set. Response content: %s", - response.status_code, - response.content, - ) - self.reset_time = None - self.response = response - - -class MissingRateLimitReset(Exception): - pass - - -class GitHubSession: - """Manages a :class:`requests.Session` with (optionally) multiple credentials, - and cycles through them when reaching rate-limits.""" - - def __init__(self, credentials: Optional[List[Dict[str, str]]] = None) -> None: - """Initialize a requests session with the proper headers for requests to - GitHub.""" - self.credentials = credentials - if self.credentials: - random.shuffle(self.credentials) - - self.session = requests.Session() - - self.session.headers.update( - {"Accept": "application/vnd.github.v3+json", "User-Agent": USER_AGENT} - ) - - self.anonymous = not self.credentials - - if self.anonymous: - logger.warning("No tokens set in configuration, using anonymous mode") - - self.token_index = -1 - self.current_user: Optional[str] = None - - if not self.anonymous: - # Initialize the first token value in the session headers - self.set_next_session_token() - - def set_next_session_token(self) -> None: - """Update the current authentication token with the next one in line.""" - - assert self.credentials - - self.token_index = (self.token_index + 1) % len(self.credentials) - - auth = self.credentials[self.token_index] - - self.current_user = auth["username"] - logger.debug("Using authentication token for user %s", self.current_user) - - if "password" in auth: - token = auth["password"] - else: - token = auth["token"] - - self.session.headers.update({"Authorization": f"token {token}"}) - - @retry( - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_any( - # ChunkedEncodingErrors happen when the TLS connection gets reset, e.g. - # when running the lister on a connection with high latency - retry_if_exception_type(requests.exceptions.ChunkedEncodingError), - # 502 status codes happen for a Server Error, sometimes - retry_if_result(lambda r: r.status_code == 502), - ), - ) - def _request(self, url: str) -> requests.Response: - response = self.session.get(url) - - if ( - # GitHub returns inconsistent status codes between unauthenticated - # rate limit and authenticated rate limits. Handle both. - response.status_code == 429 - or (self.anonymous and response.status_code == 403) - ): - raise RateLimited(response) - - return response - - def request(self, url) -> requests.Response: - """Repeatedly requests the given URL, cycling through credentials and sleeping - if necessary; until either a successful response or :exc:`MissingRateLimitReset` - """ - # The following for/else loop handles rate limiting; if successful, - # it provides the rest of the function with a `response` object. - # - # If all tokens are rate-limited, we sleep until the reset time, - # then `continue` into another iteration of the outer while loop, - # attempting to get data from the same URL again. - - while True: - max_attempts = len(self.credentials) if self.credentials else 1 - reset_times: Dict[int, int] = {} # token index -> time - for attempt in range(max_attempts): - try: - return self._request(url) - except RateLimited as e: - reset_info = "(unknown reset)" - if e.reset_time is not None: - reset_times[self.token_index] = e.reset_time - reset_info = "(resetting in %ss)" % (e.reset_time - time.time()) - - if not self.anonymous: - logger.info( - "Rate limit exhausted for current user %s %s", - self.current_user, - reset_info, - ) - # Use next token in line - self.set_next_session_token() - # Wait one second to avoid triggering GitHub's abuse rate limits - time.sleep(1) - - # All tokens have been rate-limited. What do we do? - - if not reset_times: - logger.warning( - "No X-Ratelimit-Reset value found in responses for any token; " - "Giving up." - ) - raise MissingRateLimitReset() - - sleep_time = max(reset_times.values()) - time.time() + 1 - logger.info( - "Rate limits exhausted for all tokens. Sleeping for %f seconds.", - sleep_time, - ) - time.sleep(sleep_time) - - @dataclass class GitHubListerState: """State of the GitHub lister""" last_seen_id: int = 0 """Numeric id of the last repository listed on an incremental pass""" class GitHubLister(Lister[GitHubListerState, List[Dict[str, Any]]]): """List origins from GitHub. By default, the lister runs in incremental mode: it lists all repositories, starting with the `last_seen_id` stored in the scheduler backend. Providing the `first_id` and `last_id` arguments enables the "relisting" mode: in that mode, the lister finds the origins present in the range **excluding** `first_id` and **including** `last_id`. In this mode, the lister can overrun the `last_id`: it will always record all the origins seen in a given page. As the lister is fully idempotent, this is not a practical problem. Once relisting completes, the lister state in the scheduler backend is not updated. When the config contains a set of credentials, we shuffle this list at the beginning of the listing. To follow GitHub's `abuse rate limit policy`_, we keep using the same token over and over again, until its rate limit runs out. Once that happens, we switch to the next token over in our shuffled list. When a request fails with a rate limit exception for all tokens, we pause the listing until the largest value for X-Ratelimit-Reset over all tokens. When the credentials aren't set in the lister config, the lister can run in anonymous mode too (e.g. for testing purposes). .. _abuse rate limit policy: https://developer.github.com/v3/guides/best-practices-for-integrators/#dealing-with-abuse-rate-limits Args: first_id: the id of the first repo to list last_id: stop listing after seeing a repo with an id higher than this value. """ # noqa: B950 LISTER_NAME = "github" API_URL = "https://api.github.com/repositories" PAGE_SIZE = 1000 def __init__( self, scheduler: SchedulerInterface, credentials: CredentialsType = None, first_id: Optional[int] = None, last_id: Optional[int] = None, ): super().__init__( scheduler=scheduler, credentials=credentials, url=self.API_URL, instance="github", ) self.first_id = first_id self.last_id = last_id self.relisting = self.first_id is not None or self.last_id is not None self.github_session = GitHubSession(credentials=self.credentials) def state_from_dict(self, d: Dict[str, Any]) -> GitHubListerState: return GitHubListerState(**d) def state_to_dict(self, state: GitHubListerState) -> Dict[str, Any]: return asdict(state) def get_pages(self) -> Iterator[List[Dict[str, Any]]]: current_id = 0 if self.first_id is not None: current_id = self.first_id elif self.state is not None: current_id = self.state.last_seen_id current_url = f"{self.API_URL}?since={current_id}&per_page={self.PAGE_SIZE}" while self.last_id is None or current_id < self.last_id: logger.debug("Getting page %s", current_url) try: response = self.github_session.request(current_url) except MissingRateLimitReset: # Give up break # We've successfully retrieved a (non-ratelimited) `response`. We # still need to check it for validity. if response.status_code != 200: logger.warning( "Got unexpected status_code %s: %s", response.status_code, response.content, ) break yield response.json() if "next" not in response.links: # No `next` link, we've reached the end of the world logger.debug( "No next link found in the response headers, all caught up" ) break # GitHub strongly advises to use the next link directly. We still # parse it to get the id of the last repository we've reached so # far. next_url = response.links["next"]["url"] parsed_url = urlparse(next_url) if not parsed_url.query: logger.warning("Failed to parse url %s", next_url) break parsed_query = parse_qs(parsed_url.query) current_id = int(parsed_query["since"][0]) current_url = next_url def get_origins_from_page( self, page: List[Dict[str, Any]] ) -> Iterator[ListedOrigin]: """Convert a page of GitHub repositories into a list of ListedOrigins. This records the html_url, as well as the pushed_at value if it exists. """ assert self.lister_obj.id is not None seen_in_page: Set[str] = set() for repo in page: if not repo: # null repositories in listings happen sometimes... continue if repo["html_url"] in seen_in_page: continue seen_in_page.add(repo["html_url"]) pushed_at_str = repo.get("pushed_at") pushed_at: Optional[datetime.datetime] = None if pushed_at_str: pushed_at = iso8601.parse_date(pushed_at_str) yield ListedOrigin( lister_id=self.lister_obj.id, url=repo["html_url"], visit_type="git", last_update=pushed_at, ) def commit_page(self, page: List[Dict[str, Any]]): """Update the currently stored state using the latest listed page""" if self.relisting: # Don't update internal state when relisting return if not page: # Sometimes, when you reach the end of the world, GitHub returns an empty # page of repositories return last_id = page[-1]["id"] if last_id > self.state.last_seen_id: self.state.last_seen_id = last_id def finalize(self): if self.relisting: return # Pull fresh lister state from the scheduler backend scheduler_state = self.get_state_from_scheduler() # Update the lister state in the backend only if the last seen id of # the current run is higher than that stored in the database. if self.state.last_seen_id > scheduler_state.last_seen_id: self.updated = True diff --git a/swh/lister/github/tests/test_lister.py b/swh/lister/github/tests/test_lister.py index 6bb2264..2c874ae 100644 --- a/swh/lister/github/tests/test_lister.py +++ b/swh/lister/github/tests/test_lister.py @@ -1,417 +1,418 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import logging +import time from typing import Any, Dict, Iterator, List, Optional, Union import pytest import requests_mock -from swh.lister.github.lister import GitHubLister, time +from swh.lister.github.lister import GitHubLister from swh.lister.pattern import CredentialsType, ListerStats from swh.scheduler.interface import SchedulerInterface from swh.scheduler.model import Lister NUM_PAGES = 10 ORIGIN_COUNT = GitHubLister.PAGE_SIZE * NUM_PAGES def github_repo(i: int) -> Dict[str, Union[int, str]]: """Basic repository information returned by the GitHub API""" repo: Dict[str, Union[int, str]] = { "id": i, "html_url": f"https://github.com/origin/{i}", } # Set the pushed_at date on one of the origins if i == 4321: repo["pushed_at"] = "2018-11-08T13:16:24Z" return repo def github_response_callback( request: requests_mock.request._RequestObjectProxy, context: requests_mock.response._Context, ) -> List[Dict[str, Union[str, int]]]: """Return minimal GitHub API responses for the common case where the loader hasn't been rate-limited""" # Check request headers assert request.headers["Accept"] == "application/vnd.github.v3+json" assert "Software Heritage Lister" in request.headers["User-Agent"] # Check request parameters: per_page == 1000, since = last_repo_id assert "per_page" in request.qs assert request.qs["per_page"] == [str(GitHubLister.PAGE_SIZE)] assert "since" in request.qs since = int(request.qs["since"][0]) next_page = since + GitHubLister.PAGE_SIZE if next_page < ORIGIN_COUNT: # the first id for the next page is within our origin count; add a Link # header to the response next_url = ( GitHubLister.API_URL + f"?per_page={GitHubLister.PAGE_SIZE}&since={next_page}" ) context.headers["Link"] = f"<{next_url}>; rel=next" return [github_repo(i) for i in range(since + 1, min(next_page, ORIGIN_COUNT) + 1)] @pytest.fixture() def requests_mocker() -> Iterator[requests_mock.Mocker]: with requests_mock.Mocker() as mock: mock.get(GitHubLister.API_URL, json=github_response_callback) yield mock def get_lister_data(swh_scheduler: SchedulerInterface) -> Lister: """Retrieve the data for the GitHub Lister""" return swh_scheduler.get_or_create_lister(name="github", instance_name="github") def set_lister_state(swh_scheduler: SchedulerInterface, state: Dict[str, Any]) -> None: """Set the state of the lister in database""" lister = swh_scheduler.get_or_create_lister(name="github", instance_name="github") lister.current_state = state swh_scheduler.update_lister(lister) def check_origin_4321(swh_scheduler: SchedulerInterface, lister: Lister) -> None: """Check that origin 4321 exists and has the proper last_update timestamp""" origin_4321_req = swh_scheduler.get_listed_origins( url="https://github.com/origin/4321" ) assert len(origin_4321_req.results) == 1 origin_4321 = origin_4321_req.results[0] assert origin_4321.lister_id == lister.id assert origin_4321.visit_type == "git" assert origin_4321.last_update == datetime.datetime( 2018, 11, 8, 13, 16, 24, tzinfo=datetime.timezone.utc ) def check_origin_5555(swh_scheduler: SchedulerInterface, lister: Lister) -> None: """Check that origin 5555 exists and has no last_update timestamp""" origin_5555_req = swh_scheduler.get_listed_origins( url="https://github.com/origin/5555" ) assert len(origin_5555_req.results) == 1 origin_5555 = origin_5555_req.results[0] assert origin_5555.lister_id == lister.id assert origin_5555.visit_type == "git" assert origin_5555.last_update is None def test_from_empty_state( swh_scheduler, caplog, requests_mocker: requests_mock.Mocker ) -> None: caplog.set_level(logging.DEBUG, "swh.lister.github.lister") # Run the lister in incremental mode lister = GitHubLister(scheduler=swh_scheduler) res = lister.run() assert res == ListerStats(pages=NUM_PAGES, origins=ORIGIN_COUNT) listed_origins = swh_scheduler.get_listed_origins(limit=ORIGIN_COUNT + 1) assert len(listed_origins.results) == ORIGIN_COUNT assert listed_origins.next_page_token is None lister_data = get_lister_data(swh_scheduler) assert lister_data.current_state == {"last_seen_id": ORIGIN_COUNT} check_origin_4321(swh_scheduler, lister_data) check_origin_5555(swh_scheduler, lister_data) def test_incremental(swh_scheduler, caplog, requests_mocker) -> None: caplog.set_level(logging.DEBUG, "swh.lister.github.lister") # Number of origins to skip skip_origins = 2000 expected_origins = ORIGIN_COUNT - skip_origins # Bump the last_seen_id in the scheduler backend set_lister_state(swh_scheduler, {"last_seen_id": skip_origins}) # Run the lister in incremental mode lister = GitHubLister(scheduler=swh_scheduler) res = lister.run() # add 1 page to the number of full_pages if partial_page_len is not 0 full_pages, partial_page_len = divmod(expected_origins, GitHubLister.PAGE_SIZE) expected_pages = full_pages + bool(partial_page_len) assert res == ListerStats(pages=expected_pages, origins=expected_origins) listed_origins = swh_scheduler.get_listed_origins(limit=expected_origins + 1) assert len(listed_origins.results) == expected_origins assert listed_origins.next_page_token is None lister_data = get_lister_data(swh_scheduler) assert lister_data.current_state == {"last_seen_id": ORIGIN_COUNT} check_origin_4321(swh_scheduler, lister_data) check_origin_5555(swh_scheduler, lister_data) def test_relister(swh_scheduler, caplog, requests_mocker) -> None: caplog.set_level(logging.DEBUG, "swh.lister.github.lister") # Only set this state as a canary: in the currently tested mode, the lister # should not be touching it. set_lister_state(swh_scheduler, {"last_seen_id": 123}) # Use "relisting" mode to list origins between id 10 and 1011 lister = GitHubLister(scheduler=swh_scheduler, first_id=10, last_id=1011) res = lister.run() # Make sure we got two full pages of results assert res == ListerStats(pages=2, origins=2000) # Check that the relisting mode hasn't touched the stored state. lister_data = get_lister_data(swh_scheduler) assert lister_data.current_state == {"last_seen_id": 123} def github_ratelimit_callback( request: requests_mock.request._RequestObjectProxy, context: requests_mock.response._Context, ratelimit_reset: Optional[int], ) -> Dict[str, str]: """Return a rate-limited GitHub API response.""" # Check request headers assert request.headers["Accept"] == "application/vnd.github.v3+json" assert "Software Heritage Lister" in request.headers["User-Agent"] if "Authorization" in request.headers: context.status_code = 429 else: context.status_code = 403 if ratelimit_reset is not None: context.headers["X-Ratelimit-Reset"] = str(ratelimit_reset) return { "message": "API rate limit exceeded for .", "documentation_url": "https://developer.github.com/v3/#rate-limiting", } @pytest.fixture() def num_before_ratelimit() -> int: """Number of successful requests before the ratelimit hits""" return 0 @pytest.fixture() def num_ratelimit() -> Optional[int]: """Number of rate-limited requests; None means infinity""" return None @pytest.fixture() def ratelimit_reset() -> Optional[int]: """Value of the X-Ratelimit-Reset header on ratelimited responses""" return None @pytest.fixture() def requests_ratelimited( num_before_ratelimit: int, num_ratelimit: Optional[int], ratelimit_reset: Optional[int], ) -> Iterator[requests_mock.Mocker]: """Mock requests to the GitHub API, returning a rate-limiting status code after `num_before_ratelimit` requests. GitHub does inconsistent rate-limiting: - Anonymous requests return a 403 status code - Authenticated requests return a 429 status code, with an X-Ratelimit-Reset header. This fixture takes multiple arguments (which can be overridden with a :func:`pytest.mark.parametrize` parameter): - num_before_ratelimit: the global number of requests until the ratelimit triggers - num_ratelimit: the number of requests that return a rate-limited response. - ratelimit_reset: the timestamp returned in X-Ratelimit-Reset if the request is authenticated. The default values set in the previous fixtures make all requests return a rate limit response. """ current_request = 0 def response_callback(request, context): nonlocal current_request current_request += 1 if num_before_ratelimit < current_request and ( num_ratelimit is None or current_request < num_before_ratelimit + num_ratelimit + 1 ): return github_ratelimit_callback(request, context, ratelimit_reset) else: return github_response_callback(request, context) with requests_mock.Mocker() as mock: mock.get(GitHubLister.API_URL, json=response_callback) yield mock def test_anonymous_ratelimit(swh_scheduler, caplog, requests_ratelimited) -> None: - caplog.set_level(logging.DEBUG, "swh.lister.github.lister") + caplog.set_level(logging.DEBUG, "swh.lister.github.utils") lister = GitHubLister(scheduler=swh_scheduler) assert lister.github_session.anonymous assert "using anonymous mode" in caplog.records[-1].message caplog.clear() res = lister.run() assert res == ListerStats(pages=0, origins=0) last_log = caplog.records[-1] assert last_log.levelname == "WARNING" assert "No X-Ratelimit-Reset value found in responses" in last_log.message @pytest.fixture def github_credentials() -> List[Dict[str, str]]: """Return a static list of GitHub credentials""" return sorted( [{"username": f"swh{i:d}", "token": f"token-{i:d}"} for i in range(3)] + [ {"username": f"swh-legacy{i:d}", "password": f"token-legacy-{i:d}"} for i in range(3) ], key=lambda c: c["username"], ) @pytest.fixture def all_tokens(github_credentials) -> List[str]: """Return the list of tokens matching the static credential""" return [t.get("token", t.get("password")) for t in github_credentials] @pytest.fixture def lister_credentials(github_credentials: List[Dict[str, str]]) -> CredentialsType: """Return the credentials formatted for use by the lister""" return {"github": {"github": github_credentials}} def test_authenticated_credentials( swh_scheduler, caplog, github_credentials, lister_credentials, all_tokens ): """Test credentials management when the lister is authenticated""" caplog.set_level(logging.DEBUG, "swh.lister.github.lister") lister = GitHubLister(scheduler=swh_scheduler, credentials=lister_credentials) assert lister.github_session.token_index == 0 assert sorted(lister.credentials, key=lambda t: t["username"]) == github_credentials assert lister.github_session.session.headers["Authorization"] in [ "token %s" % t for t in all_tokens ] def fake_time_sleep(duration: float, sleep_calls: Optional[List[float]] = None): """Record calls to time.sleep in the sleep_calls list""" if duration < 0: raise ValueError("Can't sleep for a negative amount of time!") if sleep_calls is not None: sleep_calls.append(duration) def fake_time_time(): """Return 0 when running time.time()""" return 0 @pytest.fixture def monkeypatch_sleep_calls(monkeypatch) -> Iterator[List[float]]: """Monkeypatch `time.time` and `time.sleep`. Returns a list cumulating the arguments passed to time.sleep().""" sleeps: List[float] = [] monkeypatch.setattr(time, "sleep", lambda d: fake_time_sleep(d, sleeps)) monkeypatch.setattr(time, "time", fake_time_time) yield sleeps @pytest.mark.parametrize( "num_ratelimit", [1] ) # return a single rate-limit response, then continue def test_ratelimit_once_recovery( swh_scheduler, caplog, requests_ratelimited, num_ratelimit, monkeypatch_sleep_calls, lister_credentials, ): """Check that the lister recovers from hitting the rate-limit once""" - caplog.set_level(logging.DEBUG, "swh.lister.github.lister") + caplog.set_level(logging.DEBUG, "swh.lister.github.utils") lister = GitHubLister(scheduler=swh_scheduler, credentials=lister_credentials) res = lister.run() # check that we used all the pages assert res == ListerStats(pages=NUM_PAGES, origins=ORIGIN_COUNT) token_users = [] for record in caplog.records: if "Using authentication token" in record.message: token_users.append(record.args[0]) # check that we used one more token than we saw rate limited requests assert len(token_users) == 1 + num_ratelimit # check that we slept for one second between our token uses assert monkeypatch_sleep_calls == [1] @pytest.mark.parametrize( # Do 5 successful requests, return 6 ratelimits (to exhaust the credentials) with a # set value for X-Ratelimit-Reset, then resume listing successfully. "num_before_ratelimit, num_ratelimit, ratelimit_reset", [(5, 6, 123456)], ) def test_ratelimit_reset_sleep( swh_scheduler, caplog, requests_ratelimited, monkeypatch_sleep_calls, num_before_ratelimit, ratelimit_reset, github_credentials, lister_credentials, ): """Check that the lister properly handles rate-limiting when providing it with authentication tokens""" - caplog.set_level(logging.DEBUG, "swh.lister.github.lister") + caplog.set_level(logging.DEBUG, "swh.lister.github.utils") lister = GitHubLister(scheduler=swh_scheduler, credentials=lister_credentials) res = lister.run() assert res == ListerStats(pages=NUM_PAGES, origins=ORIGIN_COUNT) # We sleep 1 second every time we change credentials, then we sleep until # ratelimit_reset + 1 expected_sleep_calls = len(github_credentials) * [1] + [ratelimit_reset + 1] assert monkeypatch_sleep_calls == expected_sleep_calls found_exhaustion_message = False for record in caplog.records: if record.levelname == "INFO": if "Rate limits exhausted for all tokens" in record.message: found_exhaustion_message = True break assert found_exhaustion_message diff --git a/swh/lister/github/utils.py b/swh/lister/github/utils.py new file mode 100644 index 0000000..9313fba --- /dev/null +++ b/swh/lister/github/utils.py @@ -0,0 +1,170 @@ +# Copyright (C) 2020-2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import logging +import random +import time +from typing import Dict, List, Optional + +import requests +from tenacity import ( + retry, + retry_any, + retry_if_exception_type, + retry_if_result, + wait_exponential, +) + +from .. import USER_AGENT + +logger = logging.getLogger(__name__) + + +class RateLimited(Exception): + def __init__(self, response): + self.reset_time: Optional[int] + + # Figure out how long we need to sleep because of that rate limit + ratelimit_reset = response.headers.get("X-Ratelimit-Reset") + retry_after = response.headers.get("Retry-After") + if ratelimit_reset is not None: + self.reset_time = int(ratelimit_reset) + elif retry_after is not None: + self.reset_time = int(time.time()) + int(retry_after) + 1 + else: + logger.warning( + "Received a rate-limit-like status code %s, but no rate-limit " + "headers set. Response content: %s", + response.status_code, + response.content, + ) + self.reset_time = None + self.response = response + + +class MissingRateLimitReset(Exception): + pass + + +class GitHubSession: + """Manages a :class:`requests.Session` with (optionally) multiple credentials, + and cycles through them when reaching rate-limits.""" + + def __init__(self, credentials: Optional[List[Dict[str, str]]] = None) -> None: + """Initialize a requests session with the proper headers for requests to + GitHub.""" + self.credentials = credentials + if self.credentials: + random.shuffle(self.credentials) + + self.session = requests.Session() + + self.session.headers.update( + {"Accept": "application/vnd.github.v3+json", "User-Agent": USER_AGENT} + ) + + self.anonymous = not self.credentials + + if self.anonymous: + logger.warning("No tokens set in configuration, using anonymous mode") + + self.token_index = -1 + self.current_user: Optional[str] = None + + if not self.anonymous: + # Initialize the first token value in the session headers + self.set_next_session_token() + + def set_next_session_token(self) -> None: + """Update the current authentication token with the next one in line.""" + + assert self.credentials + + self.token_index = (self.token_index + 1) % len(self.credentials) + + auth = self.credentials[self.token_index] + + self.current_user = auth["username"] + logger.debug("Using authentication token for user %s", self.current_user) + + if "password" in auth: + token = auth["password"] + else: + token = auth["token"] + + self.session.headers.update({"Authorization": f"token {token}"}) + + @retry( + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_any( + # ChunkedEncodingErrors happen when the TLS connection gets reset, e.g. + # when running the lister on a connection with high latency + retry_if_exception_type(requests.exceptions.ChunkedEncodingError), + # 502 status codes happen for a Server Error, sometimes + retry_if_result(lambda r: r.status_code == 502), + ), + ) + def _request(self, url: str) -> requests.Response: + response = self.session.get(url) + + if ( + # GitHub returns inconsistent status codes between unauthenticated + # rate limit and authenticated rate limits. Handle both. + response.status_code == 429 + or (self.anonymous and response.status_code == 403) + ): + raise RateLimited(response) + + return response + + def request(self, url) -> requests.Response: + """Repeatedly requests the given URL, cycling through credentials and sleeping + if necessary; until either a successful response or :exc:`MissingRateLimitReset` + """ + # The following for/else loop handles rate limiting; if successful, + # it provides the rest of the function with a `response` object. + # + # If all tokens are rate-limited, we sleep until the reset time, + # then `continue` into another iteration of the outer while loop, + # attempting to get data from the same URL again. + + while True: + max_attempts = len(self.credentials) if self.credentials else 1 + reset_times: Dict[int, int] = {} # token index -> time + for attempt in range(max_attempts): + try: + return self._request(url) + except RateLimited as e: + reset_info = "(unknown reset)" + if e.reset_time is not None: + reset_times[self.token_index] = e.reset_time + reset_info = "(resetting in %ss)" % (e.reset_time - time.time()) + + if not self.anonymous: + logger.info( + "Rate limit exhausted for current user %s %s", + self.current_user, + reset_info, + ) + # Use next token in line + self.set_next_session_token() + # Wait one second to avoid triggering GitHub's abuse rate limits + time.sleep(1) + + # All tokens have been rate-limited. What do we do? + + if not reset_times: + logger.warning( + "No X-Ratelimit-Reset value found in responses for any token; " + "Giving up." + ) + raise MissingRateLimitReset() + + sleep_time = max(reset_times.values()) - time.time() + 1 + logger.info( + "Rate limits exhausted for all tokens. Sleeping for %f seconds.", + sleep_time, + ) + time.sleep(sleep_time)