From c1eadccbb4ecd780b3064812ba5ab60ec848a024 Mon Sep 17 00:00:00 2001
From: Karel van Klink <karel.vanklink@geant.org>
Date: Thu, 24 Oct 2024 15:43:33 +0200
Subject: [PATCH] Fix modify_site workflow where a site was modified to contain
 the same information as before

---
 gso/utils/types/unique_field.py         | 12 +++--
 gso/workflows/site/modify_site.py       | 17 ++++---
 test/fixtures/site_fixtures.py          | 60 +++++++++++--------------
 test/workflows/site/test_modify_site.py | 56 +++++++++++++++++++----
 4 files changed, 94 insertions(+), 51 deletions(-)

diff --git a/gso/utils/types/unique_field.py b/gso/utils/types/unique_field.py
index 32c3d9de..87962bd7 100644
--- a/gso/utils/types/unique_field.py
+++ b/gso/utils/types/unique_field.py
@@ -1,21 +1,27 @@
 """An input field that must be unique in the database."""
 
+from functools import partial
 from typing import Annotated, TypeVar
 
 from pydantic import AfterValidator
 from pydantic_core.core_schema import ValidationInfo
+from pydantic_forms.types import UUIDstr
 
 from gso.services import subscriptions
 
 T = TypeVar("T")
 
 
-def validate_field_is_unique(value: T, info: ValidationInfo) -> T:
+def validate_field_is_unique(subscription_id: UUIDstr, value: T, info: ValidationInfo) -> T:
     """Validate that a field is unique."""
-    if len(subscriptions.get_active_subscriptions_by_field_and_value(str(info.field_name), str(value))) > 0:
+    matched_subscriptions = subscriptions.get_active_subscriptions_by_field_and_value(str(info.field_name), str(value))
+    matched_subscriptions = list(
+        filter(lambda match: str(match.subscription_id) != subscription_id, matched_subscriptions)
+    )
+    if len(matched_subscriptions) > 0:
         msg = f"{info.field_name} must be unique"
         raise ValueError(msg)
     return value
 
 
-UniqueField = Annotated[T, AfterValidator(validate_field_is_unique)]
+UniqueField = Annotated[T, AfterValidator(partial(validate_field_is_unique, ""))]
diff --git a/gso/workflows/site/modify_site.py b/gso/workflows/site/modify_site.py
index 9c94e550..141fcc43 100644
--- a/gso/workflows/site/modify_site.py
+++ b/gso/workflows/site/modify_site.py
@@ -1,5 +1,6 @@
 """A modification workflow for a site."""
 
+from functools import partial
 from typing import Annotated
 
 from orchestrator.forms import FormPage
@@ -13,14 +14,14 @@ from orchestrator.workflows.steps import (
     unsync,
 )
 from orchestrator.workflows.utils import wrap_modify_initial_input_form
-from pydantic import ConfigDict
+from pydantic import AfterValidator, ConfigDict
 from pydantic_forms.validators import ReadOnlyField
 
 from gso.products.product_blocks.site import SiteTier
 from gso.products.product_types.site import Site
 from gso.utils.types.coordinates import LatitudeCoordinate, LongitudeCoordinate
 from gso.utils.types.ip_address import IPAddress
-from gso.utils.types.unique_field import UniqueField
+from gso.utils.types.unique_field import validate_field_is_unique
 
 
 def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator:
@@ -36,10 +37,16 @@ def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator:
         site_country_code: ReadOnlyField(subscription.site.site_country_code, default_type=str)  # type: ignore[valid-type]
         site_latitude: LatitudeCoordinate = subscription.site.site_latitude
         site_longitude: LongitudeCoordinate = subscription.site.site_longitude
-        site_bgp_community_id: UniqueField[int] = subscription.site.site_bgp_community_id
-        site_internal_id: UniqueField[int] = subscription.site.site_internal_id
+        site_bgp_community_id: Annotated[int, AfterValidator(partial(validate_field_is_unique, subscription_id))] = (
+            subscription.site.site_bgp_community_id
+        )
+        site_internal_id: Annotated[int, AfterValidator(partial(validate_field_is_unique, subscription_id))] = (
+            subscription.site.site_internal_id
+        )
         site_tier: ReadOnlyField(subscription.site.site_tier, default_type=SiteTier)  # type: ignore[valid-type]
-        site_ts_address: Annotated[IPAddress, UniqueField] | None = subscription.site.site_ts_address
+        site_ts_address: (
+            Annotated[IPAddress, AfterValidator(partial(validate_field_is_unique, subscription_id))] | None
+        ) = subscription.site.site_ts_address
 
     user_input = yield ModifySiteForm
 
diff --git a/test/fixtures/site_fixtures.py b/test/fixtures/site_fixtures.py
index e0a3a0bd..6fa90418 100644
--- a/test/fixtures/site_fixtures.py
+++ b/test/fixtures/site_fixtures.py
@@ -7,42 +7,34 @@ from gso.products import ProductName
 from gso.products.product_blocks.site import SiteTier
 from gso.products.product_types.site import ImportedSiteInactive, SiteInactive
 from gso.services import subscriptions
+from gso.utils.types.coordinates import LatitudeCoordinate, LongitudeCoordinate
+from gso.utils.types.ip_address import IPAddress
+from gso.utils.types.site_name import SiteName
 
 
 @pytest.fixture()
 def site_subscription_factory(faker, geant_partner):
     def subscription_create(
-        description=None,
-        start_date="2023-05-24T00:00:00+00:00",
-        site_name=None,
-        site_city=None,
-        site_country=None,
-        site_country_code=None,
-        site_latitude=None,
-        site_longitude=None,
-        site_bgp_community_id=None,
-        site_internal_id=None,
-        site_tier=SiteTier.TIER1,
-        site_ts_address=None,
+        description: str | None = None,
+        site_name: SiteName | None = None,
+        site_city: str | None = None,
+        site_country: str | None = None,
+        site_country_code: str | None = None,
+        site_latitude: LatitudeCoordinate | None = None,
+        site_longitude: LongitudeCoordinate | None = None,
+        site_bgp_community_id: int | None = None,
+        site_internal_id: int | None = None,
+        site_tier: SiteTier | None = None,
+        site_ts_address: IPAddress | None = None,
         status: SubscriptionLifecycle | None = None,
         partner: dict | None = None,
+        start_date="2023-05-24T00:00:00+00:00",
         *,
         is_imported: bool | None = True,
     ) -> UUIDstr:
         if partner is None:
             partner = geant_partner
 
-        description = description or "Site Subscription"
-        site_name = site_name or faker.site_name()
-        site_city = site_city or faker.city()
-        site_country = site_country or faker.country()
-        site_country_code = site_country_code or faker.country_code()
-        site_latitude = site_latitude or str(faker.latitude())
-        site_longitude = site_longitude or str(faker.longitude())
-        site_bgp_community_id = site_bgp_community_id or faker.pyint()
-        site_internal_id = site_internal_id or faker.pyint()
-        site_ts_address = site_ts_address or faker.ipv4()
-
         if is_imported:
             product_id = subscriptions.get_product_id_by_name(ProductName.SITE)
             site_subscription = SiteInactive.from_product_id(product_id, customer_id=partner["partner_id"], insync=True)
@@ -52,19 +44,19 @@ def site_subscription_factory(faker, geant_partner):
                 product_id, customer_id=partner["partner_id"], insync=True
             )
 
-        site_subscription.site.site_city = site_city
-        site_subscription.site.site_name = site_name
-        site_subscription.site.site_country = site_country
-        site_subscription.site.site_country_code = site_country_code
-        site_subscription.site.site_latitude = site_latitude
-        site_subscription.site.site_longitude = site_longitude
-        site_subscription.site.site_bgp_community_id = site_bgp_community_id
-        site_subscription.site.site_internal_id = site_internal_id
-        site_subscription.site.site_tier = site_tier
-        site_subscription.site.site_ts_address = site_ts_address
+        site_subscription.site.site_city = site_city or faker.city()
+        site_subscription.site.site_name = site_name or faker.site_name()
+        site_subscription.site.site_country = site_country or faker.country()
+        site_subscription.site.site_country_code = site_country_code or faker.country_code()
+        site_subscription.site.site_latitude = site_latitude or str(faker.latitude())
+        site_subscription.site.site_longitude = site_longitude or str(faker.longitude())
+        site_subscription.site.site_bgp_community_id = site_bgp_community_id or faker.pyint()
+        site_subscription.site.site_internal_id = site_internal_id or faker.pyint()
+        site_subscription.site.site_tier = site_tier or SiteTier.TIER1
+        site_subscription.site.site_ts_address = site_ts_address or faker.ipv4()
 
         site_subscription = SubscriptionModel.from_other_lifecycle(site_subscription, SubscriptionLifecycle.ACTIVE)
-        site_subscription.description = description
+        site_subscription.description = description or "Site Subscription"
         site_subscription.start_date = start_date
         if status:
             site_subscription.status = status
diff --git a/test/workflows/site/test_modify_site.py b/test/workflows/site/test_modify_site.py
index 0db1a50f..a79b3b01 100644
--- a/test/workflows/site/test_modify_site.py
+++ b/test/workflows/site/test_modify_site.py
@@ -6,14 +6,14 @@ from test.workflows import assert_complete, extract_state, run_workflow
 
 
 @pytest.mark.workflow()
-def test_modify_site(responses, site_subscription_factory):
+def test_modify_site(responses, site_subscription_factory, faker):
     subscription_id = site_subscription_factory()
     initial_site_data = [
         {"subscription_id": subscription_id},
         {
-            "site_bgp_community_id": 10,
-            "site_internal_id": 20,
-            "site_ts_address": "127.0.0.1",
+            "site_bgp_community_id": faker.pyint(),
+            "site_internal_id": faker.pyint(),
+            "site_ts_address": faker.ipv4(),
         },
     ]
     result, _, _ = run_workflow("modify_site", initial_site_data)
@@ -28,16 +28,54 @@ def test_modify_site(responses, site_subscription_factory):
 
 
 @pytest.mark.workflow()
-def test_modify_site_with_invalid_data(responses, site_subscription_factory):
-    subscription_a = Site.from_subscription(site_subscription_factory())
-    subscription_b = Site.from_subscription(site_subscription_factory())
+def test_modify_site_with_duplicate_bgp_community_id(faker, site_subscription_factory):
+    duplicate_bgp_community_id = faker.pyint()
+
+    site_subscription_factory(site_bgp_community_id=duplicate_bgp_community_id)
+    subscription_b = site_subscription_factory()
 
     initial_site_data = [
-        {"subscription_id": subscription_b.subscription_id},
+        {"subscription_id": subscription_b},
         {
-            "site_bgp_community_id": subscription_a.site.site_bgp_community_id,
+            "site_bgp_community_id": duplicate_bgp_community_id,
         },
     ]
 
     with pytest.raises(FormValidationError, match="site_bgp_community_id must be unique"):
         run_workflow("modify_site", initial_site_data)
+
+
+@pytest.mark.workflow()
+def test_modify_site_with_duplicate_internal_id(faker, site_subscription_factory):
+    duplicate_internal_id = faker.pyint()
+
+    site_subscription_factory(site_internal_id=duplicate_internal_id)
+    subscription_b = site_subscription_factory()
+
+    initial_site_data = [
+        {"subscription_id": subscription_b},
+        {
+            "site_internal_id": duplicate_internal_id,
+        },
+    ]
+
+    with pytest.raises(FormValidationError, match="site_internal_id must be unique"):
+        run_workflow("modify_site", initial_site_data)
+
+
+@pytest.mark.workflow()
+def test_modify_site_with_duplicate_ts_address(faker, site_subscription_factory):
+    duplicate_ts_address = faker.ipv4()
+
+    site_subscription_factory(site_ts_address=duplicate_ts_address)
+    subscription_b = site_subscription_factory()
+
+    initial_site_data = [
+        {"subscription_id": subscription_b},
+        {
+            "site_ts_address": duplicate_ts_address,
+        },
+    ]
+
+    with pytest.raises(FormValidationError, match="site_ts_address must be unique"):
+        run_workflow("modify_site", initial_site_data)
-- 
GitLab