-
Karel van Klink authored
oidc.py 6.06 KiB
"""Module contains the OIDC Authentication class."""
import re
from collections.abc import Callable
from functools import wraps
from http import HTTPStatus
from json import JSONDecodeError
from typing import Any

from fastapi.exceptions import HTTPException
from fastapi.requests import Request
from httpx import AsyncClient, Response
from oauth2_lib.fastapi import OIDCAuth, OIDCUserModel
from oauth2_lib.settings import oauth2lib_settings
from structlog import get_logger
# Module-level structlog logger, named after this module.
logger = get_logger(__name__)
_CALLBACK_STEP_API_URL_PATTERN = re.compile(
r"^/api/processes/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})"
r"/callback/([0-9a-zA-Z\-_]+)$"
)
def _is_client_credentials_token(intercepted_token: dict) -> bool:
return "sub" not in intercepted_token
def _is_callback_step_endpoint(request: Request) -> bool:
    """Tell whether ``request`` targets the process callback-step API.

    Args:
        request: The incoming request; only ``request.url.path`` is inspected.

    Returns:
        ``True`` if the path matches ``_CALLBACK_STEP_API_URL_PATTERN`` from its start.
    """
    path = request.url.path
    match = _CALLBACK_STEP_API_URL_PATTERN.match(path)
    return match is not None
def ensure_openid_config_loaded(func: Callable) -> Callable:
    """Ensure that the openid_config is loaded before calling the decorated async method.

    The returned wrapper first awaits ``self.check_openid_config(async_request)``. It then awaits
    ``func`` with the same arguments and returns its result. ``func`` must be an async method whose
    first argument after ``self`` is the ``AsyncClient``.

    Args:
        func: The async method to wrap.

    Returns:
        The wrapping coroutine function, with ``func``'s metadata copied via ``functools.wraps``.
    """

    @wraps(func)
    async def wrapper(self: "OIDCAuth", async_request: "AsyncClient", *args: Any, **kwargs: Any) -> Any:
        # The return type is ``Any`` because the wrapped methods return different types:
        # ``userinfo`` returns an ``OIDCUserModel`` and ``introspect_token`` returns a ``dict``.
        await self.check_openid_config(async_request)
        return await func(self, async_request, *args, **kwargs)

    return wrapper
class OIDCAuthentication(OIDCAuth):
    """OIDC authentication that validates access tokens against the configured OpenID provider.

    Extends ``OIDCAuth`` from ``oauth2_lib``. Every token is checked at the provider's introspection
    endpoint. For tokens that carry a ``sub`` claim, the userinfo endpoint is also called to obtain
    the user's claims. Requests whose path matches the process callback-step API are reported as
    bypassable by ``is_bypassable_request``.
    """

    @staticmethod
    async def is_bypassable_request(request: Request) -> bool:
        """Check if the request is a callback step API call.

        Args:
            request: The incoming request.

        Returns:
            ``True`` when the request path matches the process callback-step endpoint.
        """
        return _is_callback_step_endpoint(request=request)

    def _decode_provider_response(self, response: Response, endpoint_name: str) -> dict:
        """Decode a provider response into a dict, raising 401 on any failure.

        Args:
            response: The HTTP response returned by the OpenID provider.
            endpoint_name: Lower-case endpoint label (``"userinfo"`` or ``"introspect"``).
                It is interpolated into the log messages.

        Returns:
            The decoded response body as a dict.

        Raises:
            HTTPException: 401 when the body cannot be turned into a dict, or when the status code
                is not 2xx.
        """
        try:
            data = dict(response.json())
        except (JSONDecodeError, TypeError, ValueError) as err:
            # JSONDecodeError means the body is not JSON at all. TypeError/ValueError mean the body
            # is valid JSON that ``dict()`` rejects (e.g. a bare string, number or null). Without
            # catching them, those cases would surface as an unhandled 500 instead of a 401.
            logger.debug(
                f"Unable to parse {endpoint_name} response",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) from err

        logger.debug(f"Response from openid {endpoint_name}", response=data)

        if response.status_code not in range(200, 300):
            logger.debug(
                f"{endpoint_name.capitalize()} cannot find an active token, user unauthorized",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text)

        return data

    @ensure_openid_config_loaded
    async def userinfo(self, async_request: AsyncClient, token: str) -> OIDCUserModel:
        """Get the userinfo from the openid server.

        The token is introspected first. A token without a ``sub`` claim yields a model that
        contains only ``client_id``. For any other token, the userinfo endpoint is queried and
        the ``client_id`` from introspection is stored in the returned claims.

        Args:
            async_request: The async HTTP client used to reach the provider.
            token: The access token.

        Returns:
            OIDCUserModel: OIDC user model from openid server.

        Raises:
            HTTPException: 401 when introspection rejects the token, or when the userinfo response
                is unparseable or has a non-2xx status.
        """
        # Narrows the type for the checker; the decorator has already awaited check_openid_config.
        assert self.openid_config is not None, "OpenID config is not loaded"  # noqa: S101
        introspected_token = await self.introspect_token(async_request, token)
        client_id = introspected_token.get("client_id")
        if _is_client_credentials_token(introspected_token):
            return OIDCUserModel(client_id=client_id)

        response = await async_request.post(
            self.openid_config.userinfo_endpoint,
            data={"token": token},
            headers={"Authorization": f"Bearer {token}"},
        )
        data = self._decode_provider_response(response, "userinfo")
        # The client_id reported by introspection takes precedence over any value from userinfo.
        data["client_id"] = client_id
        return OIDCUserModel(data)

    @ensure_openid_config_loaded
    async def introspect_token(self, async_request: AsyncClient, token: str) -> dict:
        """Introspect the access token to see if it is a valid token.

        Args:
            async_request: The async HTTP client used to reach the provider.
            token: The access token.

        Returns:
            The introspection response from the openid server as a dict.

        Raises:
            HTTPException: 401 when the response is unparseable, has a non-2xx status, lacks the
                ``active`` key, or reports the token as not active.
        """
        # Narrows the type for the checker; the decorator has already awaited check_openid_config.
        assert self.openid_config is not None, "OpenID config is not loaded"  # noqa: S101
        # Prefer ``introspect_endpoint``, then ``introspection_endpoint``; use "" if neither is set.
        endpoint = self.openid_config.introspect_endpoint or self.openid_config.introspection_endpoint or ""
        response = await async_request.post(
            endpoint,
            data={"token": token, "client_id": self.resource_server_id},
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        data = self._decode_provider_response(response, "introspect")

        if "active" not in data:
            logger.error("Token doesn't have the mandatory 'active' key, probably caused by a caching problem")
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Missing active key")

        if not data.get("active", False):
            logger.info("User is not active", user_info=data)
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User is not active")

        return data
# Module-level authentication instance, created at import time.
# It is configured entirely from ``oauth2lib_settings``.
oidc_instance = OIDCAuthentication(
    openid_url=oauth2lib_settings.OIDC_BASE_URL,
    openid_config_url=oauth2lib_settings.OIDC_CONF_URL,
    resource_server_id=oauth2lib_settings.OAUTH2_RESOURCE_SERVER_ID,
    resource_server_secret=oauth2lib_settings.OAUTH2_RESOURCE_SERVER_SECRET,
    oidc_user_model_cls=OIDCUserModel,
)