From a769f9a70d7d3e2516cbeec09648e4a387b642f9 Mon Sep 17 00:00:00 2001
From: Mohammad Torkashvand <mohammad.torkashvand@geant.org>
Date: Wed, 1 May 2024 13:51:30 +0200
Subject: [PATCH] add graphql

---
 gso/__init__.py         |   7 ++-
 gso/graphql/__init__.py |   1 +
 gso/graphql/types.py    |  27 +++++++++++
 gso/middlewares.py      | 101 ----------------------------------------
 gso/monkeypatches.py    |   3 +-
 5 files changed, 35 insertions(+), 104 deletions(-)
 create mode 100644 gso/graphql/__init__.py
 create mode 100644 gso/graphql/types.py
 delete mode 100644 gso/middlewares.py

diff --git a/gso/__init__.py b/gso/__init__.py
index ecdfd940..74719f79 100644
--- a/gso/__init__.py
+++ b/gso/__init__.py
@@ -5,19 +5,22 @@ from gso import monkeypatches  # noqa: F401, isort:skip
 import typer
 from orchestrator import OrchestratorCore, app_settings
 from orchestrator.cli.main import app as cli_app
+from orchestrator.graphql import SCALAR_OVERRIDES
 
 # noinspection PyUnresolvedReferences
 import gso.products
 import gso.workflows  # noqa: F401
 from gso.api import router as api_router
-from gso.middlewares import ModifyProcessEndpointResponse
+from gso.graphql.types import GSO_SCALAR_OVERRIDES
+
+SCALAR_OVERRIDES.update(GSO_SCALAR_OVERRIDES)
 
 
 def init_gso_app() -> OrchestratorCore:
     """Initialise the :term:`GSO` app."""
     app = OrchestratorCore(base_settings=app_settings)
+    app.register_graphql()
     app.include_router(api_router, prefix="/api")
-    app.add_middleware(ModifyProcessEndpointResponse)
     return app
 
 
diff --git a/gso/graphql/__init__.py b/gso/graphql/__init__.py
new file mode 100644
index 00000000..98799c0c
--- /dev/null
+++ b/gso/graphql/__init__.py
@@ -0,0 +1 @@
+"""graphql module."""
diff --git a/gso/graphql/types.py b/gso/graphql/types.py
new file mode 100644
index 00000000..c6fb9200
--- /dev/null
+++ b/gso/graphql/types.py
@@ -0,0 +1,27 @@
+"""Map some Orchestrator types to scalars."""
+
+from ipaddress import IPv4Network, IPv6Network
+from typing import Any, NewType
+
+import strawberry
+from orchestrator.graphql.types import serialize_to_string
+from strawberry.custom_scalar import ScalarDefinition, ScalarWrapper
+
+IPv4NetworkType = strawberry.scalar(
+    NewType("IPv4NetworkType", str),
+    description="Represent the Orchestrator IPv4Network data type",
+    serialize=serialize_to_string,
+    parse_value=lambda v: v,
+)
+
+IPv6NetworkType = strawberry.scalar(
+    NewType("IPv6NetworkType", str),
+    description="Represent the Orchestrator IPv6Network data type",
+    serialize=serialize_to_string,
+    parse_value=lambda v: v,
+)
+
+GSO_SCALAR_OVERRIDES: dict[object, Any | ScalarWrapper | ScalarDefinition] = {
+    IPv4Network: IPv4NetworkType,
+    IPv6Network: IPv6NetworkType,
+}
diff --git a/gso/middlewares.py b/gso/middlewares.py
deleted file mode 100644
index 5ffca88e..00000000
--- a/gso/middlewares.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""Custom middlewares for the GSO API."""
-
-import json
-import re
-from collections.abc import Callable
-from typing import Any
-
-from fastapi import Request
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import Response
-from starlette.status import HTTP_200_OK
-
-
-class ModifyProcessEndpointResponse(BaseHTTPMiddleware):
-    """Middleware to modify the response for Process details endpoint."""
-
-    async def dispatch(self, request: Request, call_next: Callable) -> Response:
-        """Middleware to modify the response for Process details endpoint.
-
-        :param request: The incoming HTTP request.
-        :type request: Request
-
-        :param call_next: The next middleware or endpoint in the stack.
-        :type call_next: Callable
-
-        :return: The modified HTTP response.
-        :rtype: Response
-        """
-        response = await call_next(request)
-        path_pattern = re.compile(
-            r"/api/processes/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})"
-        )
-
-        match = path_pattern.match(request.url.path)
-
-        if match and response.status_code == HTTP_200_OK:
-            # Modify the response body as needed
-            response_body = b""
-            async for chunk in response.body_iterator:
-                response_body += chunk
-            try:
-                json_content = json.loads(response_body)
-                await self._modify_response_body(json_content, request)
-                modified_response_body = json.dumps(json_content).encode()
-                headers = dict(response.headers)
-                headers["content-length"] = str(len(modified_response_body))
-                return Response(
-                    content=modified_response_body,
-                    status_code=response.status_code,
-                    headers=headers,
-                    media_type=response.media_type,
-                )
-
-            except json.JSONDecodeError:
-                pass
-
-        return response
-
-    @staticmethod
-    async def _get_token(request: Request) -> str:
-        """Get the token from the request headers.
-
-        :param request: The incoming HTTP request.
-        :type request: Request
-
-        :return: The token from the request headers in specific format.
-        :rtype: str
-        """
-        bearer_prefix = "Bearer "
-        authorization_header = request.headers.get("Authorization")
-        if authorization_header:
-            # Remove the "Bearer " prefix from the token
-            token = authorization_header.replace(bearer_prefix, "")
-            return f"?token={token}"
-        return ""
-
-    async def _modify_response_body(self, response_body: dict[str, Any], request: Request) -> None:
-        """Modify the response body as needed.
-
-        :param response_body: The response body in dictionary format.
-        :type response_body: dict[str, Any]
-        :param request: The incoming HTTP request.
-        :type request: Request
-
-        :return: None
-        """
-        max_output_length = 500
-        token = await self._get_token(request)
-        try:
-            for step in response_body["steps"]:
-                if step["state"].get("callback_result", None):
-                    callback_result = step["state"]["callback_result"]
-                    if callback_result and isinstance(callback_result, str):
-                        callback_result = json.loads(callback_result)
-                    if callback_result.get("output") and len(callback_result["output"]) > max_output_length:
-                        callback_result["output"] = (
-                            f'{request.base_url}api/v1/processes/steps/{step["step_id"]}/callback-results{token}'
-                        )
-                    step["state"]["callback_result"] = callback_result
-        except (AttributeError, KeyError, TypeError):
-            pass
diff --git a/gso/monkeypatches.py b/gso/monkeypatches.py
index 2e94f50b..1b71f634 100644
--- a/gso/monkeypatches.py
+++ b/gso/monkeypatches.py
@@ -7,11 +7,12 @@ oauth2_lib package to meet specific requirements of the gso application.
 import oauth2_lib.fastapi
 import oauth2_lib.settings
 
-from gso.auth.oidc_policy_helper import HTTPX_SSL_CONTEXT, OIDCUser, OIDCUserModel, opa_decision
+from gso.auth.oidc_policy_helper import HTTPX_SSL_CONTEXT, OIDCUser, OIDCUserModel, _get_decision, opa_decision
 from gso.auth.settings import oauth2lib_settings
 
 oauth2_lib.fastapi.OIDCUser = OIDCUser  # type: ignore[assignment, misc]
 oauth2_lib.fastapi.OIDCUserModel = OIDCUserModel  # type: ignore[assignment, misc]
 oauth2_lib.fastapi.opa_decision = opa_decision  # type: ignore[assignment]
+oauth2_lib.fastapi._get_decision = _get_decision  # type: ignore[assignment] # noqa: SLF001
 oauth2_lib.fastapi.HTTPX_SSL_CONTEXT = HTTPX_SSL_CONTEXT
 oauth2_lib.settings.oauth2lib_settings = oauth2lib_settings  # type: ignore[assignment]
-- 
GitLab