Newer
Older

Neda Moeini
committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Custom middlewares for the GSO API."""
import json
import re
from collections.abc import Callable
from typing import Any
from urllib.parse import quote

from fastapi import Request
from starlette.datastructures import MutableHeaders
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from starlette.status import HTTP_200_OK
class ModifyProcessEndpointResponse(BaseHTTPMiddleware):
    """Middleware to modify the response for Process details endpoint.

    For ``200 OK`` responses whose path starts with ``/api/processes/<uuid>``,
    the JSON body is buffered and, for every step whose
    ``state["callback_result"]["output"]`` is longer than ``MAX_OUTPUT_LENGTH``,
    that output is replaced with a link to the step's callback-results endpoint.
    Responses that are not JSON are passed through byte-for-byte.
    """

    # Compiled once at class creation instead of on every request.
    # NOTE(review): ``match`` anchors only at the start of the path, so
    # "/api/processes/<uuid>/..." sub-paths are buffered too (their bodies are
    # passed through unchanged if they don't have the expected shape) — confirm
    # the prefix match is intended.
    _PROCESS_PATH_PATTERN = re.compile(
        r"/api/processes/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})"
    )

    # Callback outputs whose ``len`` exceeds this are replaced by a link.
    # Subclasses may override it.
    MAX_OUTPUT_LENGTH = 500

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Middleware to modify the response for Process details endpoint.

        Args:
        ----
        request (Request): The incoming HTTP request.
        call_next (Callable): The next middleware or endpoint in the stack.

        Returns:
        -------
        Response: The modified HTTP response, or the original response when
        the path/status does not match.
        """
        response = await call_next(request)
        if response.status_code != HTTP_200_OK or not self._PROCESS_PATH_PATTERN.match(request.url.path):
            return response

        # The downstream body is a one-shot stream. Once it has been consumed,
        # ``response`` itself can no longer be returned, so every path below
        # builds a fresh Response from the buffered bytes.
        raw_body = b"".join([chunk async for chunk in response.body_iterator])
        # Copy the raw header list so repeated headers (e.g. Set-Cookie)
        # survive; a plain dict would keep only one value per header name.
        headers = MutableHeaders(raw=list(response.raw_headers))

        try:
            json_content = json.loads(raw_body)
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            # Not JSON: forward the original bytes and headers unchanged.
            return Response(
                content=raw_body,
                status_code=response.status_code,
                headers=headers,
                media_type=response.media_type,
            )

        await self.modify_response_body(json_content, request)
        modified_response_body = json.dumps(json_content).encode()
        headers["content-length"] = str(len(modified_response_body))
        return Response(
            content=modified_response_body,
            status_code=response.status_code,
            headers=headers,
            media_type=response.media_type,
        )

    @staticmethod
    async def _get_token(request: Request) -> str:
        """Build a ``?token=...`` query suffix from the request's Authorization header.

        Args:
        ----
        request (Request): The incoming HTTP request.

        Returns:
        -------
        str: ``"?token=<percent-encoded token>"``, or ``""`` when the request
        has no Authorization header, so the result can always be appended to
        a URL.
        """
        authorization_header = request.headers.get("Authorization")
        if not authorization_header:
            return ""
        # Authentication schemes are case-insensitive (RFC 7235). A header
        # with any other scheme is used verbatim, as before.
        scheme, _, credentials = authorization_header.partition(" ")
        token = credentials if scheme.lower() == "bearer" else authorization_header
        # NOTE(security): this embeds the caller's credential in a URL inside
        # the response body, where it may end up in logs/history — review.
        # Percent-encode so characters like '+', '&', '=' stay in one value.
        return f"?token={quote(token, safe='')}"

    async def modify_response_body(self, response_body: dict[str, Any], request: Request) -> None:
        """Shorten long callback outputs in the response body, in place.

        Each step is handled independently: a step with an unexpected shape
        or an unparseable ``callback_result`` string is left untouched and
        does not stop the remaining steps from being processed.

        Args:
        ----
        response_body (Dict[str, Any]): The response body in dictionary format.
        request (Request): The incoming HTTP request.
        """
        steps = response_body.get("steps") if isinstance(response_body, dict) else None
        if not isinstance(steps, list):
            return

        token = await self._get_token(request)
        for step in steps:
            self._shorten_step_callback_result(step, request, token)

    def _shorten_step_callback_result(self, step: Any, request: Request, token: str) -> None:
        """Parse one step's ``callback_result`` and replace a long ``output`` with a link."""
        if not isinstance(step, dict) or not isinstance(step.get("state"), dict):
            return
        state = step["state"]
        callback_result = state.get("callback_result")
        if not callback_result:
            return

        # callback_result may be stored as a JSON-encoded string; it is written
        # back to the response in parsed (object) form.
        if isinstance(callback_result, str):
            try:
                callback_result = json.loads(callback_result)
            except ValueError:
                return
        if not isinstance(callback_result, dict):
            return

        output = callback_result.get("output")
        try:
            too_long = bool(output) and len(output) > self.MAX_OUTPUT_LENGTH
        except TypeError:  # output without a length (e.g. a number) is kept as-is
            too_long = False

        if too_long:
            step_id = step.get("step_id")
            if step_id is None:
                return
            callback_result["output"] = (
                f"{request.base_url}api/v1/processes/steps/{step_id}/callback-results{token}"
            )

        state["callback_result"] = callback_result