Skip to content
Snippets Groups Projects
oidc.py 5.87 KiB
"""Module contains the OIDC Authentication class."""

import re
from collections.abc import Callable
from functools import wraps
from http import HTTPStatus
from json import JSONDecodeError
from typing import Any

from fastapi.exceptions import HTTPException
from fastapi.requests import Request
from httpx import AsyncClient, Response
from oauth2_lib.fastapi import OIDCAuth, OIDCUserModel
from oauth2_lib.settings import oauth2lib_settings
from structlog import get_logger

# Module-level structlog logger, named after this module.
logger = get_logger(__name__)

# Matches ``/api/processes/<process UUID>/callback/<segment>``.
# Group 1 captures the UUID; group 2 captures the final path segment ([0-9a-zA-Z-_]+).
_CALLBACK_STEP_API_URL_PATTERN = re.compile(
    "".join(
        (
            r"^/api/processes/",
            r"([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})",
            r"/callback/",
            r"([0-9a-zA-Z\-_]+)$",
        )
    )
)


def _is_client_credentials_token(intercepted_token: dict) -> bool:
    """Tell whether an introspected token lacks a subject (``sub``) claim.

    A token without ``sub`` is treated as a client-credentials token.

    :param intercepted_token: the introspection response for the token
    :return: True when no ``sub`` key is present, False otherwise
    """
    has_subject = "sub" in intercepted_token
    return not has_subject


def _is_callback_step_endpoint(request: Request) -> bool:
    """Check if the request is a callback step API call.

    The entire URL path must match ``_CALLBACK_STEP_API_URL_PATTERN``. ``fullmatch`` is used
    instead of ``match`` because the pattern's ``$`` anchor also matches just before a
    trailing newline. With ``match``, a path ending in a newline would still be classified
    as a callback step call, which matters because this check is used to bypass auth.

    :param request: the incoming request
    :return: True when the path is exactly a callback step API path
    """
    return _CALLBACK_STEP_API_URL_PATTERN.fullmatch(request.url.path) is not None


def ensure_openid_config_loaded(func: Callable) -> Callable:
    """Ensure that the openid_config is loaded before calling the wrapped async method.

    The wrapped method must be an ``async`` method whose first two positional parameters are
    ``self`` (an :class:`OIDCAuth`) and ``async_request`` (an :class:`httpx.AsyncClient`).
    On every call the wrapper first awaits ``self.check_openid_config(async_request)``, then
    awaits the method with the original arguments and returns its result unchanged.

    :param func: the async method to wrap
    :return: the wrapping coroutine function (metadata copied via :func:`functools.wraps`)
    """

    @wraps(func)
    async def wrapper(self: OIDCAuth, async_request: AsyncClient, *args: Any, **kwargs: Any) -> Any:
        # The result is annotated Any rather than dict: decorated methods such as ``userinfo``
        # return an ``OIDCUserModel``, and the wrapper passes through whatever ``func`` returns.
        await self.check_openid_config(async_request)
        return await func(self, async_request, *args, **kwargs)

    return wrapper


class OIDCAuthentication(OIDCAuth):
    """OIDC authentication that validates bearer tokens against the OpenID provider.

    Extends :class:`oauth2_lib.fastapi.OIDCAuth` and behaves as follows:
        1. Introspect the access token at the provider's introspection endpoint and reject
           it unless the response contains a truthy ``active`` key.
        2. For tokens that carry a ``sub`` claim, fetch the user's claims from the UserInfo
           endpoint. Tokens without ``sub`` are treated as client-credentials tokens, and
           only their ``client_id`` is returned.
        3. Report requests to the process callback-step API as bypassable.
    """

    @staticmethod
    async def is_bypassable_request(request: Request) -> bool:
        """Return True if the request targets the process callback-step API."""
        return _is_callback_step_endpoint(request=request)

    def _decode_json_response(self, response: Response, label: str) -> dict:
        """Decode a JSON-object response from the OpenID server, or raise 401.

        :param response: the HTTP response returned by the OpenID server
        :param label: short endpoint name used in log messages ("userinfo" or "introspect")
        :return: the decoded JSON object
        :raises HTTPException: 401 when the body is not a JSON object or the status is not 2xx
        """
        log_context = {
            "detail": response.text,
            "resource_server_id": self.resource_server_id,
            "openid_url": self.openid_url,
        }
        try:
            payload = response.json()
        except JSONDecodeError as err:
            logger.debug(f"Unable to parse {label} response", **log_context)
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) from err
        # Valid JSON that is not an object (list, string, number) cannot be a token/claims
        # document. Reject it with 401 instead of letting dict() fail with a 500.
        if not isinstance(payload, dict):
            logger.debug(f"Unable to parse {label} response", **log_context)
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text)

        logger.debug(f"Response from openid {label}", response=payload)

        if not response.is_success:
            logger.debug(f"{label.capitalize()} cannot find an active token, user unauthorized", **log_context)
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text)

        return payload

    @ensure_openid_config_loaded
    async def userinfo(self, async_request: AsyncClient, token: str) -> OIDCUserModel:
        """Get the userinfo from the openid server.

        The token is introspected first. For tokens without a ``sub`` claim, no UserInfo
        request is made and only the ``client_id`` is returned.

        :param AsyncClient async_request: The async request
        :param str token: the access_token
        :return: OIDCUserModel: OIDC user model from openid server, with ``client_id``
            taken from the introspection response
        :raises HTTPException: 401 when introspection or the UserInfo call fails
        """
        introspected = await self.introspect_token(async_request, token)
        client_id = introspected.get("client_id")
        if _is_client_credentials_token(introspected):
            return OIDCUserModel(client_id=client_id)

        response = await async_request.post(
            self.openid_config.userinfo_endpoint,
            data={"token": token},
            headers={"Authorization": f"Bearer {token}"},
        )
        data = self._decode_json_response(response, "userinfo")
        # The client_id from introspection overrides any client_id in the userinfo response.
        data["client_id"] = client_id
        return OIDCUserModel(data)

    @ensure_openid_config_loaded
    async def introspect_token(self, async_request: AsyncClient, token: str) -> dict:
        """Introspect the access token to see if it is a valid token.

        :param async_request: The async request
        :param token: the access_token
        :return: dict from openid server
        :raises HTTPException: 401 when the response is unparsable or non-2xx, lacks the
            ``active`` key, or reports the token as inactive
        """
        # NOTE(review): if the config advertises neither endpoint, this posts to an empty URL;
        # confirm whether an explicit configuration error would be preferable.
        endpoint = self.openid_config.introspect_endpoint or self.openid_config.introspection_endpoint or ""
        response = await async_request.post(
            endpoint,
            data={"token": token, "client_id": self.resource_server_id},
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        data = self._decode_json_response(response, "introspect")

        if "active" not in data:
            logger.error("Token doesn't have the mandatory 'active' key, probably caused by a caching problem")
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Missing active key")
        if not data.get("active", False):
            logger.info("User is not active", user_info=data)
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User is not active")

        return data


# Shared authenticator instance, configured entirely from the oauth2_lib settings object.
oidc_instance = OIDCAuthentication(
    oidc_user_model_cls=OIDCUserModel,
    resource_server_id=oauth2lib_settings.OAUTH2_RESOURCE_SERVER_ID,
    resource_server_secret=oauth2lib_settings.OAUTH2_RESOURCE_SERVER_SECRET,
    openid_url=oauth2lib_settings.OIDC_BASE_URL,
    openid_config_url=oauth2lib_settings.OIDC_CONF_URL,
)