diff --git a/swh/web/auth/backends.py b/swh/web/auth/backends.py --- a/swh/web/auth/backends.py +++ b/swh/web/auth/backends.py @@ -57,7 +57,15 @@ def _oidc_user_from_profile(oidc_profile: Dict[str, Any]) -> OIDCUser: # decode JWT token - decoded_token = _oidc_client.decode_token(oidc_profile["access_token"]) + access_token = oidc_profile["access_token"]  # bound outside the try: reused by the cache check below + try: + decoded_token = _oidc_client.decode_token(access_token) + # access token has expired or is invalid + except Exception: + # get a new access token from authentication provider + oidc_profile = _oidc_client.refresh_token(oidc_profile["refresh_token"]) + # decode access token + decoded_token = _oidc_client.decode_token(oidc_profile["access_token"]) # create OIDCUser from decoded token user = _oidc_user_from_decoded_token(decoded_token) @@ -77,6 +85,16 @@ if hasattr(user, key): setattr(user, key, val) + # put OIDC profile in cache or update it after token renewal + cache_key = f"oidc_user_{user.id}" + if cache.get(cache_key) is None or access_token != oidc_profile["access_token"]: + # set cache key TTL as refresh token expiration time + assert user.refresh_expires_at + ttl = int(user.refresh_expires_at.timestamp() - timezone.now().timestamp()) + + # save oidc_profile in cache + cache.set(cache_key, oidc_profile, timeout=max(0, ttl)) + return user @@ -95,12 +113,6 @@ # create Django user user = _oidc_user_from_profile(oidc_profile) - # set cache key TTL as access token expiration time - assert user.expires_at - ttl = int(user.expires_at.timestamp() - timezone.now().timestamp()) - - # save oidc_profile in cache - cache.set(f"oidc_user_{user.id}", oidc_profile, timeout=max(0, ttl)) except Exception as e: sentry_sdk.capture_exception(e) diff --git a/swh/web/tests/auth/test_backends.py b/swh/web/tests/auth/test_backends.py --- a/swh/web/tests/auth/test_backends.py +++ b/swh/web/tests/auth/test_backends.py @@ -4,6 +4,7 @@ # See top-level LICENSE file for more information from datetime import datetime, timedelta +from 
unittest.mock import Mock import pytest @@ -92,6 +93,67 @@ assert user is None +@pytest.mark.django_db +def test_oidc_code_pkce_auth_backend_refresh_token_success(mocker, request_factory): + """ + Checks access token renewal success using refresh token. + """ + kc_oidc_mock = mock_keycloak(mocker) + + oidc_profile = sample_data.oidc_profile + decoded_token = kc_oidc_mock.decode_token(oidc_profile["access_token"]) + new_access_token = "new_access_token" + + def _refresh_token(refresh_token): + oidc_profile = dict(sample_data.oidc_profile) + oidc_profile["access_token"] = new_access_token + return oidc_profile + + def _decode_token(access_token): + if access_token != new_access_token: + raise Exception("access token has expired") + else: + return decoded_token + + kc_oidc_mock.decode_token = Mock() + kc_oidc_mock.decode_token.side_effect = _decode_token + kc_oidc_mock.refresh_token.side_effect = _refresh_token + + user = _authenticate_user(request_factory) + + kc_oidc_mock.refresh_token.assert_called_with( + sample_data.oidc_profile["refresh_token"] + ) + + assert user is not None + + +@pytest.mark.django_db +def test_oidc_code_pkce_auth_backend_refresh_token_failure(mocker, request_factory): + """ + Checks access token renewal failure using refresh token. 
+ """ + kc_oidc_mock = mock_keycloak(mocker) + + def _refresh_token(refresh_token): + raise Exception("OIDC session has expired") + + def _decode_token(access_token): + raise Exception("access token has expired") + + kc_oidc_mock.decode_token = Mock() + kc_oidc_mock.decode_token.side_effect = _decode_token + kc_oidc_mock.refresh_token.side_effect = _refresh_token + + user = _authenticate_user(request_factory) + + kc_oidc_mock.refresh_token.assert_called_with( + sample_data.oidc_profile["refresh_token"] + ) + + assert user is None + + @pytest.mark.django_db def test_oidc_code_pkce_auth_backend_permissions(mocker, request_factory): """ diff --git a/swh/web/tests/auth/test_middlewares.py b/swh/web/tests/auth/test_middlewares.py --- a/swh/web/tests/auth/test_middlewares.py +++ b/swh/web/tests/auth/test_middlewares.py @@ -3,10 +3,10 @@ # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information -from datetime import datetime import pytest +from django.core.cache import cache from django.test import modify_settings from swh.web.common.utils import reverse @@ -20,25 +20,38 @@ MIDDLEWARE={"remove": ["swh.web.auth.middlewares.OIDCSessionRefreshMiddleware"]} ) def test_oidc_session_refresh_middleware_disabled(client, mocker): - # authenticate but make session expires immediately - kc_oidc_mock = mock_keycloak(mocker, exp=int(datetime.now().timestamp())) + # authenticate user + kc_oidc_mock = mock_keycloak(mocker) client.login(code="", code_verifier="", redirect_uri="") kc_oidc_mock.authorization_code.assert_called() url = reverse("swh-web-homepage") + + # visit url first to get user from response + response = check_html_get_response(client, url, status_code=200) + + # simulate OIDC session expiration + cache.delete(f"oidc_user_{response.wsgi_request.user.id}") + # no redirection for silent refresh check_html_get_response(client, url, status_code=200) @pytest.mark.django_db def 
test_oidc_session_refresh_middleware_enabled(client, mocker): - # authenticate but make session expires immediately - kc_oidc_mock = mock_keycloak(mocker, exp=int(datetime.now().timestamp())) + # authenticate user + kc_oidc_mock = mock_keycloak(mocker) client.login(code="", code_verifier="", redirect_uri="") kc_oidc_mock.authorization_code.assert_called() url = reverse("swh-web-homepage") + # visit url first to get user from response + response = check_html_get_response(client, url, status_code=200) + + # simulate OIDC session expiration + cache.delete(f"oidc_user_{response.wsgi_request.user.id}") + # should redirect for silent session refresh resp = check_html_get_response(client, url, status_code=302) silent_refresh_url = reverse(