"""Module contains the OIDC Authentication class."""

import re
from collections.abc import Callable
from functools import wraps
from http import HTTPStatus
from json import JSONDecodeError
from typing import Any

from fastapi.exceptions import HTTPException
from fastapi.requests import Request
from httpx import AsyncClient
from oauth2_lib.fastapi import OIDCAuth, OIDCUserModel
from oauth2_lib.settings import oauth2lib_settings
from structlog import get_logger

logger = get_logger(__name__)

# Matches ``/api/processes/<process uuid>/callback/<token>``. Requests whose path matches are
# reported as bypassable by ``OIDCAuthentication.is_bypassable_request``.
_CALLBACK_STEP_API_URL_PATTERN = re.compile(
    r"^/api/processes/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})"
    r"/callback/([0-9a-zA-Z\-_]+)$"
)


def _is_client_credentials_token(intercepted_token: dict) -> bool:
    """Return True when the introspected token has no ``sub`` claim.

    Such tokens are treated as client-credentials tokens: ``userinfo`` skips the userinfo
    endpoint for them and returns a model holding only the ``client_id``.
    """
    return "sub" not in intercepted_token


def _is_callback_step_endpoint(request: Request) -> bool:
    """Check if the request is a callback step API call.

    ``fullmatch`` requires the pattern to cover the entire path. With ``re.match``, the ``$``
    anchor would also accept a path that ends in a trailing newline. Because a match lets the
    request be treated as bypassable, the stricter form is used.
    """
    return _CALLBACK_STEP_API_URL_PATTERN.fullmatch(request.url.path) is not None


def ensure_openid_config_loaded(func: Callable) -> Callable:
    """Ensure that the openid_config is loaded before calling the function.

    The decorated coroutine must have the signature ``(self, async_request, ...)``.
    ``self.check_openid_config(async_request)`` is awaited before the call is delegated.
    """

    @wraps(func)
    async def wrapper(self: OIDCAuth, async_request: AsyncClient, *args: Any, **kwargs: Any) -> Any:
        await self.check_openid_config(async_request)
        return await func(self, async_request, *args, **kwargs)

    return wrapper


class OIDCAuthentication(OIDCAuth):
    """Extends ``OIDCAuth`` to validate access tokens against the OpenID server.

    Every token is first checked at the token introspection endpoint. For tokens that carry a
    ``sub`` claim (i.e. not client-credentials tokens), the UserInfo endpoint is queried as well.
    Any failure is reported as an ``HTTPException`` with status 401.
    """

    @staticmethod
    async def is_bypassable_request(request: Request) -> bool:
        """Return True for callback step API calls (``/api/processes/<uuid>/callback/<token>``)."""
        return _is_callback_step_endpoint(request=request)

    @ensure_openid_config_loaded
    async def userinfo(self, async_request: AsyncClient, token: str) -> OIDCUserModel:
        """Get the userinfo from the openid server.

        The token is introspected first; ``introspect_token`` raises 401 for invalid or
        inactive tokens. Client-credentials tokens (no ``sub`` claim) return a model holding
        only the ``client_id``. For all other tokens the userinfo endpoint is queried.

        Args:
            async_request: The async request.
            token: The access token.

        Returns:
            OIDCUserModel: OIDC user model from openid server. Its ``client_id`` is taken from
            the introspection result.

        Raises:
            HTTPException: 401 when introspection fails, when the userinfo body is not a JSON
                object, or when the userinfo status code is not 2xx.
        """
        # The decorator has already awaited check_openid_config; this assert only narrows the type.
        assert self.openid_config is not None, "OpenID config is not loaded"  # noqa: S101

        intercepted_token = await self.introspect_token(async_request, token)
        client_id = intercepted_token.get("client_id")
        if _is_client_credentials_token(intercepted_token):
            return OIDCUserModel(client_id=client_id)

        response = await async_request.post(
            self.openid_config.userinfo_endpoint,
            data={"token": token},
            headers={"Authorization": f"Bearer {token}"},
        )
        try:
            data = dict(response.json())
        except (JSONDecodeError, TypeError, ValueError) as err:
            # JSONDecodeError: the body is not JSON at all. TypeError/ValueError: the body is
            # valid JSON but not an object (e.g. a list, string or null), so dict() rejects it.
            # Without these, such a body would surface as a 500 instead of a 401.
            logger.debug(
                "Unable to parse userinfo response",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) from err

        logger.debug("Response from openid userinfo", response=data)

        if response.status_code not in range(200, 300):
            logger.debug(
                "Userinfo cannot find an active token, user unauthorized",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text)

        data["client_id"] = client_id
        return OIDCUserModel(data)

    @ensure_openid_config_loaded
    async def introspect_token(self, async_request: AsyncClient, token: str) -> dict:
        """Introspect the access token to see if it is a valid token.

        Args:
            async_request: The async request
            token: the access_token

        Returns:
            dict from openid server. It is guaranteed to contain a truthy ``active`` key.

        Raises:
            HTTPException: 401 when the body is not a JSON object, the status code is not 2xx,
                the ``active`` key is missing, or ``active`` is falsy.
        """
        # The decorator has already awaited check_openid_config; this assert only narrows the type.
        assert self.openid_config is not None, "OpenID config is not loaded"  # noqa: S101

        # Prefer ``introspect_endpoint`` and fall back to ``introspection_endpoint``.
        # NOTE(review): if neither is configured, the POST goes to "". That will likely fail
        # inside httpx unless the client has a base_url; consider raising a clear error instead.
        endpoint = self.openid_config.introspect_endpoint or self.openid_config.introspection_endpoint or ""
        response = await async_request.post(
            endpoint,
            data={"token": token, "client_id": self.resource_server_id},
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        try:
            data = dict(response.json())
        except (JSONDecodeError, TypeError, ValueError) as err:
            # JSONDecodeError: the body is not JSON at all. TypeError/ValueError: the body is
            # valid JSON but not an object (e.g. a list, string or null), so dict() rejects it.
            logger.debug(
                "Unable to parse introspect response",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) from err

        logger.debug("Response from openid introspect", response=data)

        if response.status_code not in range(200, 300):
            logger.debug(
                "Introspect cannot find an active token, user unauthorized",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text)

        if "active" not in data:
            logger.error("Token doesn't have the mandatory 'active' key, probably caused by a caching problem")
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Missing active key")

        if not data["active"]:
            logger.info("User is not active", user_info=data)
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User is not active")

        return data


# Module-level authenticator configured from ``oauth2lib_settings``.
oidc_instance = OIDCAuthentication(
    openid_url=oauth2lib_settings.OIDC_BASE_URL,
    openid_config_url=oauth2lib_settings.OIDC_CONF_URL,
    resource_server_id=oauth2lib_settings.OAUTH2_RESOURCE_SERVER_ID,
    resource_server_secret=oauth2lib_settings.OAUTH2_RESOURCE_SERVER_SECRET,
    oidc_user_model_cls=OIDCUserModel,
)