class AuthenticationError(Exception):
    """Error raised when authenticating against the OpenID Connect server
    does not succeed.

    Example: A bearer token has expired.

    """
diff --git a/swh/web/client/client.py b/swh/web/client/client.py --- a/swh/web/client/client.py +++ b/swh/web/client/client.py @@ -28,7 +28,8 @@ """ -from typing import Any, Callable, Dict, Generator, List, Union +from datetime import datetime, timedelta +from typing import Any, Callable, Dict, Generator, List, Optional, Union from urllib.parse import urlparse import dateutil.parser @@ -39,6 +40,9 @@ from swh.model.identifiers import PersistentId as PID from swh.model.identifiers import parse_persistent_identifier as parse_pid +from .auth import ( + AuthenticationError, OpenIDConnectSession, SWH_OIDC_SERVER_URL +) PIDish = Union[PID, str] @@ -115,7 +119,8 @@ """ - def __init__(self, api_url='https://archive.softwareheritage.org/api/1'): + def __init__(self, api_url='https://archive.softwareheritage.org/api/1', + auth_url=SWH_OIDC_SERVER_URL): """Create a client for the Software Heritage Web API See: https://archive.softwareheritage.org/api/ @@ -130,6 +135,8 @@ self.api_url = api_url self.api_path = u.path + self.oidc_session = OpenIDConnectSession(oidc_server_url=auth_url) + self.oidc_profile: Optional[Dict[str, Any]] = None self._getters: Dict[str, Callable[[PIDish], Any]] = { CONTENT: self.content, @@ -159,11 +166,20 @@ url = '/'.join([self.api_url, query]) r = None + headers = {} + if self.oidc_profile is not None: + # use bearer token authentication + if datetime.now() > self.oidc_profile['expires_at']: + # refresh access token if it has expired + self.authenticate(self.oidc_profile['refresh_token']) + access_token = self.oidc_profile['access_token'] + headers = {'Authorization': f'Bearer {access_token}'} + if http_method == 'get': - r = requests.get(url, **req_args) + r = requests.get(url, **req_args, headers=headers) r.raise_for_status() elif http_method == 'head': - r = requests.head(url, **req_args) + r = requests.head(url, **req_args, headers=headers) else: raise ValueError(f'unsupported HTTP method: {http_method}') @@ -397,3 +413,28 @@ 
def authenticate(self, refresh_token: str):
    """Authenticate API requests using OpenID Connect bearer token

    Exchange ``refresh_token`` for a fresh OpenID Connect profile through
    ``self.oidc_session`` and, on success only, store that profile in
    ``self.oidc_profile`` together with an ``expires_at`` timestamp.

    Args:
        refresh_token: A refresh token retrieved using the
            ``swh authentication login`` command (see
            :ref:`swh-web-client-auth` section in main documentation)

    Raises:
        swh.web.client.auth.AuthenticationError: if authentication fails;
            in that case ``self.oidc_profile`` keeps its previous value

    """
    # Sample the clock before the round-trip so the computed expiry is
    # never later than the real one. The value stays naive because it is
    # compared against ``datetime.now()`` when requests are sent.
    now = datetime.now()
    try:
        profile = self.oidc_session.refresh(refresh_token)
        authenticated = 'access_token' in profile
        if authenticated:
            if 'expires_in' in profile:
                expires_at = now + timedelta(seconds=profile['expires_in'])
            else:
                # No advertised lifetime: never refresh proactively, but
                # still provide the key the request code looks up.
                expires_at = datetime.max
            profile['expires_at'] = expires_at
    except Exception as e:
        # keep the original failure reachable through __cause__
        raise AuthenticationError(str(e)) from e
    if not authenticated:
        # JSON error response
        raise AuthenticationError(profile)
    # Commit last, so a failed attempt cannot install an error payload
    # as the active profile.
    self.oidc_profile = profile
mocker.patch('swh.web.client.cli.OpenIDConnectSession') mock_login = mock_oidc_session.return_value.login - mock_login.return_value = _oidc_profile + mock_login.return_value = oidc_profile result = runner.invoke(authentication, ['login', 'username'], input='password\n') assert result.exit_code == 0 - assert json.loads(result.output) == _oidc_profile + assert json.loads(result.output) == oidc_profile mock_login.side_effect = Exception('Auth error') @@ -45,16 +45,16 @@ mock_oidc_session = mocker.patch('swh.web.client.cli.OpenIDConnectSession') mock_refresh = mock_oidc_session.return_value.refresh - mock_refresh.return_value = _oidc_profile + mock_refresh.return_value = oidc_profile result = runner.invoke(authentication, - ['refresh', _oidc_profile['refresh_token']]) + ['refresh', oidc_profile['refresh_token']]) assert result.exit_code == 0 - assert json.loads(result.stdout) == _oidc_profile['access_token'] + assert json.loads(result.stdout) == oidc_profile['access_token'] mock_refresh.side_effect = Exception('Auth error') result = runner.invoke(authentication, - ['refresh', _oidc_profile['refresh_token']]) + ['refresh', oidc_profile['refresh_token']]) assert result.exit_code == 1 @@ -64,10 +64,10 @@ mock_logout = mock_oidc_session.return_value.logout result = runner.invoke(authentication, - ['logout', _oidc_profile['refresh_token']]) + ['logout', oidc_profile['refresh_token']]) assert result.exit_code == 0 mock_logout.side_effect = Exception('Auth error') result = runner.invoke(authentication, - ['logout', _oidc_profile['refresh_token']]) + ['logout', oidc_profile['refresh_token']]) assert result.exit_code == 1 diff --git a/swh/web/client/tests/test_web_api_client.py b/swh/web/client/tests/test_web_api_client.py --- a/swh/web/client/tests/test_web_api_client.py +++ b/swh/web/client/tests/test_web_api_client.py @@ -3,10 +3,18 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from copy import 
def test_authenticate_success(web_api_client, web_api_mock):
    """Once authenticated, requests carry the bearer access token."""
    rel_id = 'b9db10d00835e9a43e2eebef2db1d04d4ae82342'
    url = f'{web_api_client.api_url}/release/{rel_id}/'

    web_api_client.oidc_session = Mock()
    web_api_client.oidc_session.refresh.return_value = copy(oidc_profile)

    access_token = oidc_profile['access_token']
    refresh_token = 'user-refresh-token'

    web_api_client.authenticate(refresh_token)

    assert 'expires_at' in web_api_client.oidc_profile

    pid = parse_pid(f'swh:1:rel:{rel_id}')
    web_api_client.get(pid)

    # the access token is still valid: the request must not refresh it
    web_api_client.oidc_session.refresh.assert_called_once_with(refresh_token)

    # use the mocker's public request history rather than its adapter
    sent_request = web_api_mock.last_request

    assert sent_request.url == url
    assert 'Authorization' in sent_request.headers

    assert sent_request.headers['Authorization'] == f'Bearer {access_token}'


def test_authenticate_refresh_token(web_api_client, web_api_mock):
    """An expired access token is renewed before the request is sent."""
    rel_id = 'b9db10d00835e9a43e2eebef2db1d04d4ae82342'
    url = f'{web_api_client.api_url}/release/{rel_id}/'

    oidc_profile_cp = copy(oidc_profile)

    web_api_client.oidc_session = Mock()
    web_api_client.oidc_session.refresh.return_value = oidc_profile_cp

    refresh_token = 'user-refresh-token'
    web_api_client.authenticate(refresh_token)

    assert 'expires_at' in web_api_client.oidc_profile

    # simulate access token expiration; a date firmly in the past does not
    # depend on the clock advancing before the client compares it
    web_api_client.oidc_profile['expires_at'] = datetime.min

    # the mocked session hands back this same dict on the next refresh
    access_token = 'new-access-token'
    oidc_profile_cp['access_token'] = access_token

    pid = parse_pid(f'swh:1:rel:{rel_id}')
    web_api_client.get(pid)

    # exactly two refreshes: the explicit one, then the automatic renewal
    calls = [call(refresh_token), call(oidc_profile['refresh_token'])]
    assert web_api_client.oidc_session.refresh.call_args_list == calls

    sent_request = web_api_mock.last_request

    assert sent_request.url == url
    assert 'Authorization' in sent_request.headers

    assert sent_request.headers['Authorization'] == f'Bearer {access_token}'


def test_authenticate_failure(web_api_client, web_api_mock):
    """Session errors and JSON error responses both raise
    AuthenticationError."""
    msg = 'Authentication error'
    web_api_client.oidc_session = Mock()
    web_api_client.oidc_session.refresh.side_effect = Exception(msg)

    refresh_token = 'user-refresh-token'

    with pytest.raises(AuthenticationError) as e:
        web_api_client.authenticate(refresh_token)

    # compare the message literally instead of using it as a regex
    assert str(e.value) == msg

    oidc_error_response = {
        'error': 'invalid_grant',
        'error_description': 'Invalid refresh token',
    }

    web_api_client.oidc_session.refresh.side_effect = None
    web_api_client.oidc_session.refresh.return_value = oidc_error_response

    with pytest.raises(AuthenticationError) as e:
        web_api_client.authenticate(refresh_token)

    # a dict repr is full of regex metacharacters, so compare as text
    assert str(e.value) == str(oidc_error_response)