"""Custom middlewares for the GSO API."""

import json
import re
from collections.abc import Callable
from typing import Any

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from starlette.status import HTTP_200_OK


class ModifyProcessEndpointResponse(BaseHTTPMiddleware):
    """Middleware to modify the response for Process details endpoint.

    For successful responses on ``/api/processes/<uuid>`` the JSON body is inspected and any step
    whose ``callback_result`` carries an ``output`` longer than ``_MAX_OUTPUT_LENGTH`` gets that
    output replaced by a URL pointing at the step's callback-results endpoint.
    """

    # Compiled once at class creation instead of on every request.
    _PROCESS_PATH_PATTERN = re.compile(
        r"/api/processes/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})"
    )
    # Outputs longer than this many characters/items are replaced by a link.
    _MAX_OUTPUT_LENGTH = 500
    _BEARER_PREFIX = "Bearer "

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Middleware to modify the response for Process details endpoint.

        :param request: The incoming HTTP request.
        :type request: Request

        :param call_next: The next middleware or endpoint in the stack.
        :type call_next: Callable

        :return: The modified HTTP response, or the untouched response when the request path or
            status code does not qualify.
        :rtype: Response
        """
        response = await call_next(request)

        is_process_path = self._PROCESS_PATH_PATTERN.match(request.url.path) is not None
        if not is_process_path or response.status_code != HTTP_200_OK:
            return response

        # From here on the body iterator is consumed, so ``response`` itself must never be
        # returned again: its body would be empty. A fresh Response is always built instead.
        response_body = b"".join([chunk async for chunk in response.body_iterator])

        try:
            json_content = json.loads(response_body)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError (e.g. a compressed body):
            # pass the original bytes through unchanged.
            new_body = response_body
        else:
            await self._modify_response_body(json_content, request)
            new_body = json.dumps(json_content).encode()

        # NOTE(review): dict() keeps only one value per header name; repeated headers such as
        # multiple Set-Cookie entries would be collapsed - confirm this endpoint never sets them.
        headers = dict(response.headers)
        headers["content-length"] = str(len(new_body))
        return Response(
            content=new_body,
            status_code=response.status_code,
            headers=headers,
            media_type=response.media_type,
        )

    @staticmethod
    async def _get_token(request: Request) -> str:
        """Get the token from the request headers.

        :param request: The incoming HTTP request.
        :type request: Request

        :return: The token formatted as a ``?token=...`` query string, or an empty string when the
            request has no ``Authorization`` header.
        :rtype: str
        """
        authorization_header = request.headers.get("Authorization")
        if not authorization_header:
            return ""
        # Strip only a leading "Bearer " scheme; str.replace would also remove the substring
        # from the middle of the value.
        token = authorization_header.removeprefix(ModifyProcessEndpointResponse._BEARER_PREFIX)
        return f"?token={token}"

    @staticmethod
    def _shorten_step_output(step: Any, base_url: Any, token: str, max_output_length: int) -> None:
        """Replace an oversized callback output of a single step with a link, in place.

        A ``callback_result`` stored as a JSON string is decoded and written back as a parsed
        object, even when its output is short enough to be kept.

        :param step: A single step entry of the process details body.
        :type step: Any
        :param base_url: Base URL of the request, expected to end with a slash.
        :type base_url: Any
        :param token: Query-string suffix appended to the generated link (may be empty).
        :type token: str
        :param max_output_length: Outputs longer than this are replaced by a link.
        :type max_output_length: int

        :raises AttributeError, KeyError, TypeError, ValueError: When the step does not have the
            expected shape; the step is left unmodified in that case.
        :return: None
        """
        state = step["state"]
        callback_result = state.get("callback_result", None)
        if not callback_result:
            return
        if isinstance(callback_result, str):
            callback_result = json.loads(callback_result)
        output = callback_result.get("output")
        if output and len(output) > max_output_length:
            callback_result["output"] = f'{base_url}api/v1/processes/steps/{step["step_id"]}/callback-results{token}'
        state["callback_result"] = callback_result

    async def _modify_response_body(self, response_body: dict[str, Any], request: Request) -> None:
        """Modify the response body as needed.

        Each step is handled independently: a step with an unexpected shape or an undecodable
        ``callback_result`` is skipped and does not prevent later steps from being processed.

        :param response_body: The response body in dictionary format.
        :type response_body: dict[str, Any]
        :param request: The incoming HTTP request.
        :type request: Request

        :return: None
        """
        token = await self._get_token(request)
        try:
            steps = list(response_body["steps"])
        except (AttributeError, KeyError, TypeError):
            # Body is not a mapping with an iterable "steps" entry: nothing to modify.
            return

        base_url = request.base_url
        for step in steps:
            try:
                self._shorten_step_output(step, base_url, token, self._MAX_OUTPUT_LENGTH)
            except (AttributeError, KeyError, TypeError, ValueError):
                # Best effort: leave a malformed step untouched and carry on with the rest.
                continue