# oidc_policy_helper.py
"""OpenID Connect and Open Policy Agent Integration for GSO Application.

This module provides helper functions and classes for handling OpenID Connect (OIDC) and
Open Policy Agent (OPA) related functionalities within the GSO application. It includes
implementations for OIDC-based user authentication and user information modeling. Additionally,
it facilitates making authorization decisions based on policies defined in OPA. Key components
comprise OIDCUser, OIDCUserModel, OPAResult, and opa_decision. These elements integrate with
FastAPI to ensure secure API development.
"""

import re
import ssl
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from http import HTTPStatus
from json import JSONDecodeError
from typing import Any, ClassVar, cast

from fastapi.exceptions import HTTPException
from fastapi.param_functions import Depends
from fastapi.requests import Request
from fastapi.security.http import HTTPBearer
from httpx import AsyncClient, NetworkError
from pydantic import BaseModel
from starlette.requests import ClientDisconnect
from structlog import get_logger

from gso.auth.settings import oauth2lib_settings

# Module-level structlog logger, named after this module.
logger = get_logger(__name__)

# Shared SSL context, passed as ``verify`` to every AsyncClient this module creates.
HTTPX_SSL_CONTEXT = ssl.create_default_context()  # https://github.com/encode/httpx/issues/838

# Matches paths of the form ``/api/processes/<id>/callback/<token>``: group 1 captures a hyphenated
# 8-4-4-4-12 hexadecimal id, group 2 a non-empty token made of ASCII letters, digits, ``-`` and ``_``.
_CALLBACK_STEP_API_URL_PATTERN = re.compile(
    r"^/api/processes/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})"
    r"/callback/([0-9a-zA-Z\-_]+)$"
)


def _is_callback_step_endpoint(request: Request) -> bool:
    """Tell whether the request path targets a callback step API endpoint.

    :param Request request: The incoming request whose URL path is inspected.
    :return: True when the path matches the callback step URL pattern, False otherwise.
    """
    path_match = _CALLBACK_STEP_API_URL_PATTERN.match(request.url.path)
    return path_match is not None


class InvalidScopeValueError(ValueError):
    """Exception raised for invalid scope values in OIDC.

    ``OIDCUserModel.scopes`` raises it when the ``scope`` claim is not None and is neither a list nor a string.
    """


class OIDCUserModel(dict):
    """Dictionary of user claims with attribute-style access to the registered claims.

    Holds the standard claims of an OIDC user, defined per `Section 5.1`_, and AAI attributes.

    .. _`Section 5.1`: http://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
    """

    #: registered claims that OIDCUserModel supports
    REGISTERED_CLAIMS: ClassVar[list[str]] = [
        "sub",
        "name",
        "given_name",
        "family_name",
        "middle_name",
        "nickname",
        "preferred_username",
        "profile",
        "picture",
        "website",
        "email",
        "email_verified",
        "gender",
        "birthdate",
        "zoneinfo",
        "locale",
        "phone_number",
        "phone_number_verified",
        "address",
        "updated_at",
    ]

    def __getattr__(self, key: str) -> Any:
        """Resolve a missing attribute through the registered claims.

        Python only calls this hook after regular attribute lookup has failed. A registered claim is
        then read from the dictionary (``None`` when the claim is absent); every other name keeps
        raising the original AttributeError.

        :param str key: The attribute name to retrieve.
        :return: The attribute value, or the dictionary value of a registered claim.
        :raises AttributeError: If the name is neither an attribute nor a registered claim.
        """
        try:
            return object.__getattribute__(self, key)
        except AttributeError as error:
            if key not in self.REGISTERED_CLAIMS:
                raise error from None
            return self.get(key)

    @property
    def user_name(self) -> str:
        """Return the username of the user.

        The ``user_name`` claim wins over ``unspecified_id``; an empty string is returned when neither is present.
        """
        for claim in ("user_name", "unspecified_id"):
            if claim in self:
                return cast(str, self[claim])
        return ""

    @property
    def display_name(self) -> str:
        """Return the ``display_name`` claim, or an empty string when it is absent."""
        fallback = ""
        return self.get("display_name", fallback)

    @property
    def principal_name(self) -> str:
        """Return the ``eduperson_principal_name`` claim, or an empty string when it is absent."""
        fallback = ""
        return self.get("eduperson_principal_name", fallback)

    @property
    def scopes(self) -> set[str]:
        """Return the scopes of the user as a set of strings.

        A missing (or ``None``) ``scope`` claim yields an empty set. A string is split on spaces and
        commas with empty parts dropped; from a list only the string items are kept.

        :raises InvalidScopeValueError: If the claim has any other type.
        """
        scope_value = self.get("scope")
        if scope_value is None:
            return set()

        if isinstance(scope_value, str):
            return {part for part in re.split("[ ,]", scope_value) if part}
        if isinstance(scope_value, list):
            return {entry for entry in scope_value if isinstance(entry, str)}

        message = f"Invalid scope value: {scope_value}"
        raise InvalidScopeValueError(message)


async def _make_async_client() -> AsyncGenerator[AsyncClient, None]:
    """Yield an httpx client for the duration of its async context.

    The client is created with ``http1=True`` and the module's shared SSL context; leaving the
    ``async with`` block exits the client's context manager.
    """
    client = AsyncClient(http1=True, verify=HTTPX_SSL_CONTEXT)
    async with client:
        yield client


class OIDCConfig(BaseModel):
    """Configuration for OpenID Connect (OIDC) authentication and token validation.

    Within this module the model is populated by ``OIDCUser.check_openid_config`` from the JSON document
    served at ``<openid_url>/.well-known/openid-configuration``.
    """

    issuer: str
    authorization_endpoint: str
    token_endpoint: str
    # Endpoint that ``OIDCUser.userinfo`` posts the access token to.
    userinfo_endpoint: str
    # Two spellings of the introspection endpoint are accepted. ``OIDCUser.introspect_token`` uses
    # ``introspect_endpoint`` when it is set and falls back to ``introspection_endpoint`` otherwise.
    introspect_endpoint: str | None = None
    introspection_endpoint: str | None = None
    jwks_uri: str
    response_types_supported: list[str]
    response_modes_supported: list[str]
    grant_types_supported: list[str]
    subject_types_supported: list[str]
    id_token_signing_alg_values_supported: list[str]
    scopes_supported: list[str]
    token_endpoint_auth_methods_supported: list[str]
    claims_supported: list[str]
    claims_parameter_supported: bool
    request_parameter_supported: bool
    code_challenge_methods_supported: list[str]


class OPAResult(BaseModel):
    """Represents the outcome of an authorization decision made by the Open Policy Agent (OPA).

    Attributes
    ----------
    - result (bool): Indicates whether the access request is allowed or denied. Defaults to False.
    - decision_id (str): A unique identifier for the decision made by OPA. Has no default.
    """

    result: bool = False
    decision_id: str


class OIDCUser(HTTPBearer):
    """OIDCUser class extends the :term:`HTTPBearer` class to do extra verification.

    When called for a request the class will act as follows:
        1. Obtain the bearer token, either passed in directly or read from the request credentials.
        2. Introspect the token at the OpenID server and reject it unless it is reported as active.
        3. Return the user claims retrieved from the UserInfo endpoint.
    """

    openid_config: OIDCConfig | None = None
    openid_url: str
    resource_server_id: str
    resource_server_secret: str

    def __init__(
        self,
        openid_url: str,
        resource_server_id: str,
        resource_server_secret: str,
        *,
        auto_error: bool = True,
        scheme_name: str | None = None,
    ):
        """Set up OIDCUser with specified OpenID Connect configurations and credentials.

        :param str openid_url: Base URL of the OpenID server; the discovery path is appended to it.
        :param str resource_server_id: Sent as ``client_id`` when introspecting a token.
        :param str resource_server_secret: Stored on the instance.
        :param bool auto_error: Forwarded to :term:`HTTPBearer`.
        :param str scheme_name: Name of the security scheme; defaults to the class name.
        """
        super().__init__(auto_error=auto_error)
        self.openid_url = openid_url
        self.resource_server_id = resource_server_id
        self.resource_server_secret = resource_server_secret
        # Assigned after the parent initialiser has run, so this value is the one that sticks.
        self.scheme_name = scheme_name if scheme_name else type(self).__name__

    async def __call__(  # type: ignore[override]
        self, request: Request, token: str | None = None
    ) -> OIDCUserModel | None:
        """Return the OIDC user belonging to the token of this request.

        This is used as a security module in Fastapi projects. The token is introspected first and the
        user is then fetched from the UserInfo endpoint.

        :param Request request: Starlette request method.
        :param str token: Optional value to directly pass a token.
        :return: OIDCUserModel object, or None when OAuth2 is inactive, when no credentials were found, or
            when a token was passed for a callback step endpoint.
        :raises HTTPException: 401 when the introspection result lacks the ``active`` key or is not active.
        """
        if not oauth2lib_settings.OAUTH2_ACTIVE:
            return None

        async with AsyncClient(http1=True, verify=HTTPX_SSL_CONTEXT) as async_request:
            if token:
                # A token was handed in directly: callback step endpoints are skipped here.
                if _is_callback_step_endpoint(request):
                    logger.debug(
                        "callback step endpoint is called. verification will be done by endpoint itself.",
                        url=request.url,
                    )
                    return None
            else:
                credentials = await super().__call__(request)
                if not credentials:
                    return None
                token = credentials.credentials

            await self.check_openid_config(async_request)
            introspection = await self.introspect_token(async_request, token)

            if "active" not in introspection:
                logger.error("Token doesn't have the mandatory 'active' key, probably caused by a caching problem")
                raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Missing active key")
            if not introspection["active"]:
                logger.info("User is not active", url=request.url, user_info=introspection)
                raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User is not active")

            user_info = await self.userinfo(async_request, token)

            logger.debug("OIDCUserModel object.", intercepted_token=introspection)
            return user_info

    async def check_openid_config(self, async_request: AsyncClient) -> None:
        """Check if the openid config is loaded and load it if not.

        :param AsyncClient async_request: The async request
        """
        if self.openid_config is None:
            discovery_url = self.openid_url + "/.well-known/openid-configuration"
            response = await async_request.get(discovery_url)
            self.openid_config = OIDCConfig.parse_obj(response.json())

    async def userinfo(self, async_request: AsyncClient, token: str) -> OIDCUserModel:
        """Get the userinfo from the openid server.

        :param AsyncClient async_request: The async request
        :param str token: the access_token
        :return: OIDCUserModel: OIDC user model from openid server
        :raises HTTPException: 401 when the response is not valid JSON or has a non-2xx status code.
        """
        await self.check_openid_config(async_request)
        assert self.openid_config, "OpenID config should be loaded"  # noqa: S101

        bearer_header = {"Authorization": f"Bearer {token}"}
        response = await async_request.post(
            self.openid_config.userinfo_endpoint,
            data={"token": token},
            headers=bearer_header,
        )
        try:
            payload = response.json()
        except JSONDecodeError as err:
            logger.debug(
                "Unable to parse userinfo response",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) from err

        data = dict(payload)
        logger.debug("Response from openid userinfo", response=data)

        # The status is only checked after the body has been parsed and logged.
        if not 200 <= response.status_code < 300:
            logger.debug(
                "Userinfo cannot find an active token, user unauthorized",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text)

        return OIDCUserModel(data)

    async def introspect_token(self, async_request: AsyncClient, token: str) -> dict:
        """Introspect the access token to see if it is a valid token.

        :param async_request: The async request
        :param token: the access_token
        :return: dict from openid server
        :raises HTTPException: 401 when the response is not valid JSON or has a non-2xx status code.
        """
        await self.check_openid_config(async_request)
        assert self.openid_config, "OpenID config should be loaded"  # noqa: S101

        config = self.openid_config
        # Either spelling of the endpoint may be configured; an empty string is used when both are missing.
        endpoint = config.introspect_endpoint or config.introspection_endpoint or ""
        form_fields = {"token": token, "client_id": self.resource_server_id}
        response = await async_request.post(
            endpoint,
            data=form_fields,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        try:
            payload = response.json()
        except JSONDecodeError as err:
            logger.debug(
                "Unable to parse introspect response",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text) from err

        data = dict(payload)
        logger.debug("Response from openid introspect", response=data)

        # The status is only checked after the body has been parsed and logged.
        if not 200 <= response.status_code < 300:
            logger.debug(
                "Introspect cannot find an active token, user unauthorized",
                detail=response.text,
                resource_server_id=self.resource_server_id,
                openid_url=self.openid_url,
            )
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=response.text)

        return data


async def _get_decision(async_request: AsyncClient, opa_url: str, opa_input: dict) -> OPAResult:
    """Post the input document to the policy agent and wrap its answer in an OPAResult.

    :param AsyncClient async_request: The httpx client used to reach the policy agent.
    :param str opa_url: URL the input document is posted to.
    :param dict opa_input: The input document, sent as the JSON body.
    :return: OPAResult built from the ``result.allow`` and ``decision_id`` entries of the response body.
    :raises HTTPException: 503 when posting raises a NetworkError or TypeError.
    """
    logger.debug("Posting input json to Policy agent", opa_url=opa_url, input=opa_input)
    try:
        response = await async_request.post(opa_url, json=opa_input)
    except (NetworkError, TypeError) as exc:
        logger.debug("Could not get decision from policy agent", error=str(exc))
        unavailable = HTTPException(status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail="Policy agent is unavailable")
        raise unavailable from exc

    payload = response.json()
    logger.debug("Received response from Policy agent", response=payload)
    allowed = payload["result"]["allow"]
    decision_id = payload["decision_id"]
    return OPAResult(result=allowed, decision_id=decision_id)


def _evaluate_decision(decision: OPAResult, *, auto_error: bool, **context: Any) -> bool:
    """Turn an OPA decision into a boolean, or into an HTTP 403 when access is denied.

    :param OPAResult decision: The decision received from the policy agent.
    :param bool auto_error: If True a denied decision raises an HTTPException instead of returning False.
    :param context: Extra key/value pairs that are added to the log lines. The ``resource`` entry, when
        present, is also mentioned in the error detail. Values may be of any type, hence ``Any``.
    :return: True when the decision allows access, False when it is denied and ``auto_error`` is False.
    :raises HTTPException: 403 when the decision is denied and ``auto_error`` is True.
    """
    decision_id = decision.decision_id

    if decision.result:
        logger.debug("User is authorized to access the resource", decision_id=decision_id, **context)
        return True

    logger.debug("User is not allowed to access the resource", decision_id=decision_id, **context)
    if not auto_error:
        return False

    raise HTTPException(
        status_code=HTTPStatus.FORBIDDEN,
        detail=(
            f"User is not allowed to access resource: {context.get('resource')} "
            f"Decision was taken with id: {decision_id}"
        ),
    )


def opa_decision(
    opa_url: str,
    oidc_security: OIDCUser,
    *,
    auto_error: bool = True,
    opa_kwargs: Mapping[str, str] | None = None,
) -> Callable[[Request, OIDCUserModel, AsyncClient], Awaitable[bool | None]]:
    """Create a decision function for Open Policy Agent (OPA) authorization checks.

    This function generates an asynchronous decision function that can be used in FastAPI endpoints
    to authorize requests based on OPA policies. It utilizes OIDC for user information and makes a
    call to the OPA service to determine authorization.

    :param str opa_url: URL of the Open Policy Agent service.
    :param OIDCUser oidc_security: An instance of OIDCUser for user authentication.
    :param bool auto_error: If True, automatically raises an HTTPException on authorization failure.
    :param Mapping[str, str] | None opa_kwargs: Additional keyword arguments to be passed to the OPA input.
    :return: An asynchronous decision function that can be used as a dependency in FastAPI endpoints.
    """

    async def _opa_decision(
        request: Request,
        user_info: OIDCUserModel | None = Depends(oidc_security),  # noqa: B008
        async_request: AsyncClient = Depends(_make_async_client),  # noqa: B008
    ) -> bool | None:
        """Check OIDCUserModel against the OPA policy.

        This is used as a security module in Fastapi projects
        This method will make an async call towards the Policy agent.

        Args:
        ----
            request: Request object that will be used to retrieve request metadata.
            user_info: The OIDCUserModel object that will be checked, or None when the OIDC dependency
                did not yield a user; in that case no user claims are added to the OPA input.
            async_request: The :term:`httpx` client.

        Returns:
        -------
            None when authorization is inactive or the request targets a callback step endpoint,
            otherwise the outcome of the policy decision.
        """
        if not (oauth2lib_settings.OAUTH2_ACTIVE and oauth2lib_settings.OAUTH2_AUTHORIZATION_ACTIVE):
            return None

        if _is_callback_step_endpoint(request):
            return None

        try:
            json_body = await request.json()
        # Silencing the Decode error or Type error when request.json() does not return anything sane.
        # Some requests do not have a json body therefore as this code gets called on every request
        # we need to suppress the `None` case (TypeError) or the `other than json` case (JSONDecodeError)
        # Suppress AttributeError in case of websocket request, it doesn't have .json
        except (JSONDecodeError, TypeError, ClientDisconnect, AttributeError):
            json_body = {}

        # defaulting to GET request method for WebSocket request, it doesn't have .method
        request_method = getattr(request, "method", "GET")
        opa_input = {
            "input": {
                **(opa_kwargs or {}),
                # The OIDC dependency may return None; unpacking None would raise a TypeError.
                **(user_info or {}),
                "resource": request.url.path,
                "method": request_method,
                "arguments": {"path": request.path_params, "query": {**request.query_params}, "json": json_body},
            }
        }

        decision = await _get_decision(async_request, opa_url, opa_input)

        context = {
            "resource": opa_input["input"]["resource"],
            "method": opa_input["input"]["method"],
            "user_info": user_info,
            "input": opa_input,
            "url": request.url,
        }
        return _evaluate_decision(decision, auto_error=auto_error, **context)

    return _opa_decision