From 7fdced2db5abb26a81fcbd2c8290ff2771097ec3 Mon Sep 17 00:00:00 2001
From: Karel van Klink <karel.vanklink@geant.org>
Date: Wed, 11 Dec 2024 10:08:05 +0100
Subject: [PATCH] Add input validation to modification workflows that contain
 GA- and GS-IDs

---
 gso/services/subscriptions.py                 |  2 +-
 gso/workflows/edge_port/modify_edge_port.py   | 11 +++++++++--
 .../iptrunk/create_imported_iptrunk.py        |  2 +-
 gso/workflows/iptrunk/create_iptrunk.py       |  2 +-
 gso/workflows/iptrunk/migrate_iptrunk.py      |  2 +-
 .../iptrunk/modify_trunk_interface.py         | 13 ++++++++++---
 .../create_imported_layer_2_circuit.py        |  3 ++-
 .../create_imported_l3_core_service.py        |  3 ++-
 .../edge_port/test_modify_edge_port.py        | 17 +++++++++++++++++
 test/workflows/iptrunk/test_create_iptrunk.py |  4 +---
 .../iptrunk/test_modify_trunk_interface.py    | 19 ++++++++++++++++++-
 11 files changed, 63 insertions(+), 15 deletions(-)

diff --git a/gso/services/subscriptions.py b/gso/services/subscriptions.py
index e6781244..83d88751 100644
--- a/gso/services/subscriptions.py
+++ b/gso/services/subscriptions.py
@@ -355,7 +355,7 @@ def generate_unique_ga_id() -> str:
     """Generate a unique GA ID using the ga_id_seq database sequence.
 
     Returns:
-        str: A unique GA ID in the format `GA<number>`.
+        str: A unique GA ID in the format `GA-<number>`.
 
     Raises:
         ValueError: If there is an error generating the ID.
diff --git a/gso/workflows/edge_port/modify_edge_port.py b/gso/workflows/edge_port/modify_edge_port.py
index 94388984..b176eafc 100644
--- a/gso/workflows/edge_port/modify_edge_port.py
+++ b/gso/workflows/edge_port/modify_edge_port.py
@@ -1,5 +1,6 @@
 """Modify an existing edge port subscription."""
 
+from functools import partial
 from typing import Annotated, Any, Self
 from uuid import uuid4
 
@@ -10,7 +11,7 @@ from orchestrator.targets import Target
 from orchestrator.workflow import StepList, begin, conditional, done, step
 from orchestrator.workflows.steps import resync, store_process_subscription, unsync
 from orchestrator.workflows.utils import wrap_modify_initial_input_form
-from pydantic import AfterValidator, ConfigDict, model_validator
+from pydantic import AfterValidator, ConfigDict, Field, model_validator
 from pydantic_forms.types import FormGenerator, State, UUIDstr
 from pydantic_forms.validators import ReadOnlyField, validate_unique_list
 
@@ -26,6 +27,7 @@ from gso.utils.helpers import (
 )
 from gso.utils.types.interfaces import LAGMember, PhysicalPortCapacity
 from gso.utils.types.tt_number import TTNumber
+from gso.utils.types.unique_field import validate_field_is_unique
 
 
 def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator:
@@ -43,7 +45,12 @@ def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator:
         minimum_links: int | None = subscription.edge_port.minimum_links or None
         mac_address: str | None = subscription.edge_port.mac_address or None
         ignore_if_down: bool = subscription.edge_port.ignore_if_down
-        ga_id: str | None = subscription.edge_port.ga_id or None
+        ga_id: (
+            Annotated[
+                str, AfterValidator(partial(validate_field_is_unique, subscription_id)), Field(pattern=r"^GA-\d{5}$")
+            ]
+            | None
+        ) = subscription.edge_port.ga_id or None
 
         @model_validator(mode="after")
         def validate_number_of_members(self) -> Self:
diff --git a/gso/workflows/iptrunk/create_imported_iptrunk.py b/gso/workflows/iptrunk/create_imported_iptrunk.py
index 1feec505..722121fe 100644
--- a/gso/workflows/iptrunk/create_imported_iptrunk.py
+++ b/gso/workflows/iptrunk/create_imported_iptrunk.py
@@ -114,7 +114,7 @@ def initialize_subscription(
         subscription.iptrunk.iptrunk_sides[0].iptrunk_side_node.router_site.site_name,
         subscription.iptrunk.iptrunk_sides[1].iptrunk_side_node.router_site.site_name,
     ])
-    subscription.description = f"IP trunk {side_names[0]} {side_names[1]}, gs_id:{gs_id}"
+    subscription.description = f"IP trunk {side_names[0]} {side_names[1]}, {gs_id}"
     return {"subscription": subscription}
 
 
diff --git a/gso/workflows/iptrunk/create_iptrunk.py b/gso/workflows/iptrunk/create_iptrunk.py
index 8f24e333..962cc5d0 100644
--- a/gso/workflows/iptrunk/create_iptrunk.py
+++ b/gso/workflows/iptrunk/create_iptrunk.py
@@ -361,7 +361,7 @@ def initialize_subscription(
             IptrunkInterfaceBlockInactive.new(subscription_id=uuid4(), **member),
         )
     side_names = sorted([side_a.router_site.site_name, side_b.router_site.site_name])
-    subscription.description = f"IP trunk {side_names[0]} {side_names[1]}, gs_id:{gs_id}"
+    subscription.description = f"IP trunk {side_names[0]} {side_names[1]}, {gs_id}"
 
     return {"subscription": subscription}
 
diff --git a/gso/workflows/iptrunk/migrate_iptrunk.py b/gso/workflows/iptrunk/migrate_iptrunk.py
index 03ee4465..7e7f32b4 100644
--- a/gso/workflows/iptrunk/migrate_iptrunk.py
+++ b/gso/workflows/iptrunk/migrate_iptrunk.py
@@ -769,7 +769,7 @@ def update_subscription_model(
         subscription.iptrunk.iptrunk_sides[0].iptrunk_side_node.router_site.site_name,
         subscription.iptrunk.iptrunk_sides[1].iptrunk_side_node.router_site.site_name,
     ])
-    subscription.description = f"IP trunk {side_names[0]} {side_names[1]}, gs_id:{subscription.iptrunk.gs_id}"
+    subscription.description = f"IP trunk {side_names[0]} {side_names[1]}, {subscription.iptrunk.gs_id}"
 
     return {"subscription": subscription}
 
diff --git a/gso/workflows/iptrunk/modify_trunk_interface.py b/gso/workflows/iptrunk/modify_trunk_interface.py
index 65fa452b..ea2309e4 100644
--- a/gso/workflows/iptrunk/modify_trunk_interface.py
+++ b/gso/workflows/iptrunk/modify_trunk_interface.py
@@ -7,6 +7,7 @@ necessary modifications will be applied.
 """
 
 import json
+from functools import partial
 from typing import Annotated
 from uuid import UUID, uuid4
 
@@ -18,7 +19,7 @@ from orchestrator.utils.json import json_dumps
 from orchestrator.workflow import StepList, begin, conditional, done, step, workflow
 from orchestrator.workflows.steps import resync, store_process_subscription, unsync
 from orchestrator.workflows.utils import wrap_modify_initial_input_form
-from pydantic import ConfigDict
+from pydantic import AfterValidator, ConfigDict, Field
 from pydantic_forms.validators import Label, ReadOnlyField
 
 from gso.products.product_blocks.iptrunk import (
@@ -39,6 +40,7 @@ from gso.utils.shared_enums import Vendor
 from gso.utils.types.interfaces import JuniperLAGMember, LAGMember, LAGMemberList, PhysicalPortCapacity
 from gso.utils.types.ip_address import IPv4AddressType, IPv6AddressType
 from gso.utils.types.tt_number import TTNumber
+from gso.utils.types.unique_field import validate_field_is_unique
 from gso.workflows.iptrunk.migrate_iptrunk import check_ip_trunk_optical_levels_pre
 from gso.workflows.iptrunk.validate_iptrunk import check_ip_trunk_isis
 
@@ -86,7 +88,12 @@ def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator:
 
     class ModifyIptrunkForm(FormPage):
         tt_number: TTNumber
-        gs_id: str | None = subscription.iptrunk.gs_id
+        gs_id: (
+            Annotated[
+                str, AfterValidator(partial(validate_field_is_unique, subscription_id)), Field(pattern=r"^GS-\d{5}$")
+            ]
+            | None
+        ) = subscription.iptrunk.gs_id
         iptrunk_description: str | None = subscription.iptrunk.iptrunk_description
         iptrunk_type: IptrunkType = subscription.iptrunk.iptrunk_type
         warning_label: Label = (
@@ -302,7 +309,7 @@ def modify_iptrunk_subscription(
         subscription.iptrunk.iptrunk_sides[0].iptrunk_side_node.router_site.site_name,
         subscription.iptrunk.iptrunk_sides[1].iptrunk_side_node.router_site.site_name,
     ])
-    subscription.description = f"IP trunk {side_names[0]} {side_names[1]}, gs_id:{gs_id}"
+    subscription.description = f"IP trunk {side_names[0]} {side_names[1]}, {gs_id}"
 
     return {
         "subscription": subscription,
diff --git a/gso/workflows/l2_circuit/create_imported_layer_2_circuit.py b/gso/workflows/l2_circuit/create_imported_layer_2_circuit.py
index 6102d86d..3e5e97c9 100644
--- a/gso/workflows/l2_circuit/create_imported_layer_2_circuit.py
+++ b/gso/workflows/l2_circuit/create_imported_layer_2_circuit.py
@@ -23,6 +23,7 @@ from gso.products.product_types.layer_2_circuit import (
 from gso.services.partners import get_partner_by_name
 from gso.services.subscriptions import get_product_id_by_name
 from gso.utils.shared_enums import SBPType
+from gso.utils.types.geant_ids import IMPORTED_GS_ID
 from gso.utils.types.interfaces import BandwidthString
 from gso.utils.types.virtual_identifiers import VC_ID, VLAN_ID
 
@@ -39,7 +40,7 @@ def initial_input_form_generator() -> FormGenerator:
 
         service_type: Layer2CircuitServiceType
         partner: str
-        gs_id: str
+        gs_id: IMPORTED_GS_ID
         vc_id: VC_ID
         layer_2_circuit_side_a: ServiceBindingPortInput
         layer_2_circuit_side_b: ServiceBindingPortInput
diff --git a/gso/workflows/l3_core_service/create_imported_l3_core_service.py b/gso/workflows/l3_core_service/create_imported_l3_core_service.py
index 7862a47f..7280a225 100644
--- a/gso/workflows/l3_core_service/create_imported_l3_core_service.py
+++ b/gso/workflows/l3_core_service/create_imported_l3_core_service.py
@@ -21,6 +21,7 @@ from gso.products.product_types.l3_core_service import ImportedL3CoreServiceInac
 from gso.services.partners import get_partner_by_name
 from gso.services.subscriptions import get_product_id_by_name
 from gso.utils.shared_enums import SBPType
+from gso.utils.types.geant_ids import IMPORTED_GS_ID
 from gso.utils.types.ip_address import IPAddress, IPv4AddressType, IPV4Netmask, IPv6AddressType, IPV6Netmask
 from gso.utils.types.virtual_identifiers import VLAN_ID
 
@@ -50,7 +51,7 @@ def initial_input_form_generator() -> FormGenerator:
     class ServiceBindingPort(BaseModel):
         edge_port: UUIDstr
         ap_type: str
-        gs_id: str
+        gs_id: IMPORTED_GS_ID
         sbp_type: SBPType = SBPType.L3
         is_tagged: bool = False
         vlan_id: VLAN_ID
diff --git a/test/workflows/edge_port/test_modify_edge_port.py b/test/workflows/edge_port/test_modify_edge_port.py
index 25adcd85..0f7ed140 100644
--- a/test/workflows/edge_port/test_modify_edge_port.py
+++ b/test/workflows/edge_port/test_modify_edge_port.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch
 
 import pytest
+from pydantic_forms.exceptions import FormValidationError
 
 from gso.products.product_types.edge_port import EdgePort
 from gso.utils.types.interfaces import PhysicalPortCapacity
@@ -37,6 +38,22 @@ def input_form_wizard_data(request, faker, edge_port_subscription_factory, partn
     ]
 
 
+@pytest.mark.workflow()
+@pytest.mark.parametrize("invalid_ga_id", ["GS-11111", "GA-1234", "GA_12345", "GA-100000"])
+def test_modify_edge_port_with_invalid_ga_id(
+    input_form_wizard_data, faker, invalid_ga_id, iptrunk_side_subscription_factory, iptrunk_subscription_factory
+):
+    input_data = input_form_wizard_data
+    input_data[1]["ga_id"] = invalid_ga_id
+    iptrunk_subscription_factory(
+        iptrunk_sides=[iptrunk_side_subscription_factory(ga_id="GA-11111"), iptrunk_side_subscription_factory()]
+    )
+
+    #  Run workflow
+    with pytest.raises(FormValidationError):
+        run_workflow("modify_edge_port", input_data)
+
+
 @pytest.mark.workflow()
 @patch("gso.services.lso_client._send_request")
 @patch("gso.services.netbox_client.NetboxClient.get_available_interfaces")
diff --git a/test/workflows/iptrunk/test_create_iptrunk.py b/test/workflows/iptrunk/test_create_iptrunk.py
index a45f3bf1..d0f9cfa5 100644
--- a/test/workflows/iptrunk/test_create_iptrunk.py
+++ b/test/workflows/iptrunk/test_create_iptrunk.py
@@ -151,9 +151,7 @@ def test_successful_iptrunk_creation_with_standard_lso_result(
     ])
     assert subscription.status == "provisioning"
     assert subscription.iptrunk.gs_id is not None
-    assert subscription.description == (
-        f"IP trunk {sorted_sides[0]} {sorted_sides[1]}, gs_id:{subscription.iptrunk.gs_id}"
-    )
+    assert subscription.description == f"IP trunk {sorted_sides[0]} {sorted_sides[1]}, {subscription.iptrunk.gs_id}"
 
     assert mock_execute_playbook.call_count == 6
     #  We search for 6 hosts in total, 2 in a /31 and 4 in a /126
diff --git a/test/workflows/iptrunk/test_modify_trunk_interface.py b/test/workflows/iptrunk/test_modify_trunk_interface.py
index 12f49b90..d8a96588 100644
--- a/test/workflows/iptrunk/test_modify_trunk_interface.py
+++ b/test/workflows/iptrunk/test_modify_trunk_interface.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch
 
 import pytest
+from pydantic_forms.exceptions import FormValidationError
 
 from gso.products.product_blocks.iptrunk import IptrunkType
 from gso.products.product_types.iptrunk import Iptrunk
@@ -164,7 +165,7 @@ def test_iptrunk_modify_trunk_interface_success(
         subscription.iptrunk.iptrunk_sides[0].iptrunk_side_node.router_site.site_name,
         subscription.iptrunk.iptrunk_sides[1].iptrunk_side_node.router_site.site_name,
     ])
-    assert subscription.description == f"IP trunk {side_names[0]} {side_names[1]}, gs_id:{new_sid}"
+    assert subscription.description == f"IP trunk {side_names[0]} {side_names[1]}, {new_sid}"
     assert subscription.iptrunk.gs_id == input_form_iptrunk_data[1]["gs_id"]
     assert subscription.iptrunk.iptrunk_description == input_form_iptrunk_data[1]["iptrunk_description"]
     assert subscription.iptrunk.iptrunk_type == input_form_iptrunk_data[1]["iptrunk_type"]
@@ -192,3 +193,19 @@ def test_iptrunk_modify_trunk_interface_success(
             member.interface_description
             == _find_interface_by_name(new_side_b_ae_members, member.interface_name).interface_description
         )
+
+
+@pytest.mark.workflow()
+@pytest.mark.parametrize("invalid_ga_id", ["GA-11111", "GS-1234", "GS_12345", "GS-100000"])
+def test_modify_iptrunk_interface_with_invalid_ga_id(
+    input_form_iptrunk_data, faker, invalid_ga_id, iptrunk_side_subscription_factory, iptrunk_subscription_factory
+):
+    input_data = input_form_iptrunk_data
+    input_data[3]["side_a_ga_id"] = invalid_ga_id
+    iptrunk_subscription_factory(
+        iptrunk_sides=[iptrunk_side_subscription_factory(ga_id="GA-11111"), iptrunk_side_subscription_factory()]
+    )
+
+    #  Run workflow
+    with pytest.raises(FormValidationError):
+        run_workflow("modify_edge_port", input_data)
-- 
GitLab