"""Module contains the OIDC Authentication class.""" import re from collections.abc import Callable from functools import wraps from http import HTTPStatus from json import JSONDecodeError from typing import Any from fastapi.exceptions import HTTPException from fastapi.requests import Request from httpx import AsyncClient from oauth2_lib.fastapi import OIDCAuth, OIDCUserModel from oauth2_lib.settings import oauth2lib_settings from structlog import get_logger logger = get_logger(__name__) _CALLBACK_STEP_API_URL_PATTERN = re.compile( r"^/api/processes/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})" r"/callback/([0-9a-zA-Z\-_]+)$" ) def _is_client_credentials_token(intercepted_token: dict) -> bool: return "sub" not in intercepted_token def _is_callback_step_endpoint(request: Request) -> bool: """Check if the request is a callback step API call.""" return re.match(_CALLBACK_STEP_API_URL_PATTERN, request.url.path) is not None def ensure_openid_config_loaded(func: Callable) -> Callable: """Ensure that the openid_config is loaded before calling the function.""" @wraps(func) async def wrapper(self: OIDCAuth, async_request: AsyncClient, *args: Any, **kwargs: Any) -> dict: await self.check_openid_config(async_request) return await func(self, async_request, *args, **kwargs) return wrapper class OIDCAuthentication(OIDCAuth): """OIDCUser class extends the :term:`HTTPBearer` class to do extra verification. The class will act as follows: 1. Validate the Credentials at :term: `AAI` proxy by calling the UserInfo endpoint """ @staticmethod async def is_bypassable_request(request: Request) -> bool: """Check if the request is a callback step API call.""" return _is_callback_step_endpoint(request=request) @ensure_openid_config_loaded async def userinfo(self, async_request: AsyncClient, token: str) -> OIDCUserModel: """Get the userinfo from the openid server. :param AsyncClient async_request: The async request :param str token: the access_token :return: OIDCUserModel: OIDC user model from openid server """ intercepted_token = await self.introspect_token(async_request, token) client_id = intercepted_token.get("client_id") if _is_client_credentials_token(intercepted_token): return OIDCUserModel(client_id=client_id) response = await async_request.post( self.openid_config.userinfo_endpoint, data={"token": token}, headers={"Authorization": f"Bearer {token}"}, ) try: data = dict(response.json()) except JSONDecodeError as err: logger.debug( "Unable to parse userinfo response", detail=response.text, resource_server_id=self.resource_server_id, openid_url=self.openid_url, ) raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) from err logger.debug("Response from openid userinfo", response=data) if response.status_code not in range(200, 300): logger.debug( "Userinfo cannot find an active token, user unauthorized", detail=response.text, resource_server_id=self.resource_server_id, openid_url=self.openid_url, ) raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) data["client_id"] = client_id return OIDCUserModel(data) @ensure_openid_config_loaded async def introspect_token(self, async_request: AsyncClient, token: str) -> dict: """Introspect the access token to see if it is a valid token. :param async_request: The async request :param token: the access_token :return: dict from openid server """ endpoint = self.openid_config.introspect_endpoint or self.openid_config.introspection_endpoint or "" response = await async_request.post( endpoint, data={"token": token, "client_id": self.resource_server_id}, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) try: data = dict(response.json()) except JSONDecodeError as err: logger.debug( "Unable to parse introspect response", detail=response.text, resource_server_id=self.resource_server_id, openid_url=self.openid_url, ) raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) from err logger.debug("Response from openid introspect", response=data) if response.status_code not in range(200, 300): logger.debug( "Introspect cannot find an active token, user unauthorized", detail=response.text, resource_server_id=self.resource_server_id, openid_url=self.openid_url, ) raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) if "active" not in data: logger.error("Token doesn't have the mandatory 'active' key, probably caused by a caching problem") raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Missing active key") if not data.get("active", False): logger.info("User is not active", user_info=data) raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User is not active") return data oidc_instance = OIDCAuthentication( openid_url=oauth2lib_settings.OIDC_BASE_URL, openid_config_url=oauth2lib_settings.OIDC_CONF_URL, resource_server_id=oauth2lib_settings.OAUTH2_RESOURCE_SERVER_ID, resource_server_secret=oauth2lib_settings.OAUTH2_RESOURCE_SERVER_SECRET, oidc_user_model_cls=OIDCUserModel, )