"""Custom middlewares for the GSO API."""

import json
import re
from collections.abc import Callable
from typing import Any

from fastapi import Request
from starlette.datastructures import MutableHeaders
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from starlette.status import HTTP_200_OK

# Matches exactly the process details path ``/api/processes/<uuid>`` (optionally with a
# trailing slash). Deeper paths such as ``/api/processes/<uuid>/<something>`` are not
# matched, so their responses pass through untouched.
_PROCESS_DETAIL_PATH = re.compile(
    r"/api/processes/"
    r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
    r"/?"
)

# Callback outputs longer than this many characters are replaced by a link.
_MAX_CALLBACK_OUTPUT_LENGTH = 1000


class ModifyProcessEndpointResponse(BaseHTTPMiddleware):
    """Middleware to modify the response for Process details endpoint.

    For successful (200) responses of the process details endpoint, any step whose
    ``state["callback_result"]["output"]`` is longer than the configured limit has that
    output replaced by a URL of the form
    ``<base_url>api/v1/processes/steps/<step_id>/callback-results/``.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Middleware to modify the response for Process details endpoint.

        Args:
        ----
        request (Request): The incoming HTTP request.
        call_next (Callable): The next middleware or endpoint in the stack.

        Returns:
        -------
        Response: The modified HTTP response. Responses for other paths or with a
        non-200 status are returned unchanged. A 200 body on the matching path that is
        not valid JSON is returned with its original bytes.

        """
        response = await call_next(request)
        if response.status_code != HTTP_200_OK or not _PROCESS_DETAIL_PATH.fullmatch(request.url.path):
            return response

        # ``body_iterator`` can only be consumed once. From here on, every return path
        # must build a fresh response from the buffered bytes; returning ``response``
        # itself would send an empty body.
        original_body = b"".join([chunk async for chunk in response.body_iterator])
        try:
            json_content = json.loads(original_body)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
            return self._rebuild_response(response, original_body)

        self.modify_response_body(json_content, request)
        return self._rebuild_response(response, json.dumps(json_content).encode())

    @staticmethod
    def _rebuild_response(response: Response, body: bytes) -> Response:
        """Build a new response carrying ``body`` with the status and headers of ``response``.

        ``content-length`` is recomputed for ``body``. All other headers, including
        repeated ones such as ``set-cookie``, are copied over as-is.
        """
        # Copy the raw header list so that mutating it does not touch the original
        # response. ``MutableHeaders`` keeps duplicate header names, unlike ``dict``.
        headers = MutableHeaders(raw=list(response.raw_headers))
        headers["content-length"] = str(len(body))
        return Response(
            content=body,
            status_code=response.status_code,
            headers=headers,
            media_type=response.media_type,
        )

    @staticmethod
    def modify_response_body(response_body: dict[str, Any], request: Request) -> None:
        """Modify the response body as needed.

        The body is modified in place. For each entry of ``response_body["steps"]``
        that has a truthy ``state["callback_result"]``:

        - a string callback result is parsed as JSON and written back as the parsed object;
        - if its ``output`` is longer than the limit, ``output`` is replaced by a link to
          ``api/v1/processes/steps/<step_id>/callback-results/`` under ``request.base_url``.

        A step that does not have the expected shape, or whose callback result is not
        valid JSON, is skipped. The remaining steps are still processed. A body without a
        ``steps`` list is left untouched.

        Args:
        ----
        response_body (Dict[str, Any]): The response body in dictionary format.
        request (Request): The incoming HTTP request.

        """
        steps = response_body.get("steps") if isinstance(response_body, dict) else None
        if not isinstance(steps, list):
            return

        for step in steps:
            try:
                state = step["state"]
                callback_result = state.get("callback_result")
                if not callback_result:
                    continue
                if isinstance(callback_result, str):
                    callback_result = json.loads(callback_result)
                output = callback_result.get("output")
                if output and len(output) > _MAX_CALLBACK_OUTPUT_LENGTH:
                    callback_result["output"] = (
                        f'{request.base_url}api/v1/processes/steps/{step["step_id"]}/callback-results/'
                    )
                state["callback_result"] = callback_result
            except (AttributeError, KeyError, TypeError, json.JSONDecodeError):
                # Best effort: a malformed step must not prevent shortening the others.
                continue