diff --git a/gso/utils/types/unique_field.py b/gso/utils/types/unique_field.py index 32c3d9de43560998dea262b512625de60363de44..87962bd7a2b6c1d08297fb4912c24034ec495aac 100644 --- a/gso/utils/types/unique_field.py +++ b/gso/utils/types/unique_field.py @@ -1,21 +1,27 @@ """An input field that must be unique in the database.""" +from functools import partial from typing import Annotated, TypeVar from pydantic import AfterValidator from pydantic_core.core_schema import ValidationInfo +from pydantic_forms.types import UUIDstr from gso.services import subscriptions T = TypeVar("T") -def validate_field_is_unique(value: T, info: ValidationInfo) -> T: +def validate_field_is_unique(subscription_id: UUIDstr, value: T, info: ValidationInfo) -> T: """Validate that a field is unique.""" - if len(subscriptions.get_active_subscriptions_by_field_and_value(str(info.field_name), str(value))) > 0: + matched_subscriptions = subscriptions.get_active_subscriptions_by_field_and_value(str(info.field_name), str(value)) + matched_subscriptions = list( + filter(lambda match: str(match.subscription_id) != subscription_id, matched_subscriptions) + ) + if len(matched_subscriptions) > 0: msg = f"{info.field_name} must be unique" raise ValueError(msg) return value -UniqueField = Annotated[T, AfterValidator(validate_field_is_unique)] +UniqueField = Annotated[T, AfterValidator(partial(validate_field_is_unique, ""))] diff --git a/gso/workflows/site/modify_site.py b/gso/workflows/site/modify_site.py index 9c94e55032ae9712856603463b71299cf8e37fe7..141fcc4320ddd925bc050f7afb1f0193c4a535ef 100644 --- a/gso/workflows/site/modify_site.py +++ b/gso/workflows/site/modify_site.py @@ -1,5 +1,6 @@ """A modification workflow for a site.""" +from functools import partial from typing import Annotated from orchestrator.forms import FormPage @@ -13,14 +14,14 @@ from orchestrator.workflows.steps import ( unsync, ) from orchestrator.workflows.utils import wrap_modify_initial_input_form -from pydantic import ConfigDict +from pydantic import AfterValidator, ConfigDict from pydantic_forms.validators import ReadOnlyField from gso.products.product_blocks.site import SiteTier from gso.products.product_types.site import Site from gso.utils.types.coordinates import LatitudeCoordinate, LongitudeCoordinate from gso.utils.types.ip_address import IPAddress -from gso.utils.types.unique_field import UniqueField +from gso.utils.types.unique_field import validate_field_is_unique def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator: @@ -36,10 +37,16 @@ def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator: site_country_code: ReadOnlyField(subscription.site.site_country_code, default_type=str) # type: ignore[valid-type] site_latitude: LatitudeCoordinate = subscription.site.site_latitude site_longitude: LongitudeCoordinate = subscription.site.site_longitude - site_bgp_community_id: UniqueField[int] = subscription.site.site_bgp_community_id - site_internal_id: UniqueField[int] = subscription.site.site_internal_id + site_bgp_community_id: Annotated[int, AfterValidator(partial(validate_field_is_unique, subscription_id))] = ( + subscription.site.site_bgp_community_id + ) + site_internal_id: Annotated[int, AfterValidator(partial(validate_field_is_unique, subscription_id))] = ( + subscription.site.site_internal_id + ) site_tier: ReadOnlyField(subscription.site.site_tier, default_type=SiteTier) # type: ignore[valid-type] - site_ts_address: Annotated[IPAddress, UniqueField] | None = subscription.site.site_ts_address + site_ts_address: ( + Annotated[IPAddress, AfterValidator(partial(validate_field_is_unique, subscription_id))] | None + ) = subscription.site.site_ts_address user_input = yield ModifySiteForm diff --git a/test/fixtures/site_fixtures.py b/test/fixtures/site_fixtures.py index e0a3a0bd6a38d3466f4fd6ec621183a31a816d3d..6fa904183f0d8effe142f2c199f180494cbfd7c2 100644 --- a/test/fixtures/site_fixtures.py +++ b/test/fixtures/site_fixtures.py @@ -7,42 +7,34 @@ from gso.products import ProductName from gso.products.product_blocks.site import SiteTier from gso.products.product_types.site import ImportedSiteInactive, SiteInactive from gso.services import subscriptions +from gso.utils.types.coordinates import LatitudeCoordinate, LongitudeCoordinate +from gso.utils.types.ip_address import IPAddress +from gso.utils.types.site_name import SiteName @pytest.fixture() def site_subscription_factory(faker, geant_partner): def subscription_create( - description=None, - start_date="2023-05-24T00:00:00+00:00", - site_name=None, - site_city=None, - site_country=None, - site_country_code=None, - site_latitude=None, - site_longitude=None, - site_bgp_community_id=None, - site_internal_id=None, - site_tier=SiteTier.TIER1, - site_ts_address=None, + description: str | None = None, + site_name: SiteName | None = None, + site_city: str | None = None, + site_country: str | None = None, + site_country_code: str | None = None, + site_latitude: LatitudeCoordinate | None = None, + site_longitude: LongitudeCoordinate | None = None, + site_bgp_community_id: int | None = None, + site_internal_id: int | None = None, + site_tier: SiteTier | None = None, + site_ts_address: IPAddress | None = None, status: SubscriptionLifecycle | None = None, partner: dict | None = None, + start_date="2023-05-24T00:00:00+00:00", *, is_imported: bool | None = True, ) -> UUIDstr: if partner is None: partner = geant_partner - description = description or "Site Subscription" - site_name = site_name or faker.site_name() - site_city = site_city or faker.city() - site_country = site_country or faker.country() - site_country_code = site_country_code or faker.country_code() - site_latitude = site_latitude or str(faker.latitude()) - site_longitude = site_longitude or str(faker.longitude()) - site_bgp_community_id = site_bgp_community_id or faker.pyint() - site_internal_id = site_internal_id or faker.pyint() - site_ts_address = site_ts_address or faker.ipv4() - if is_imported: product_id = subscriptions.get_product_id_by_name(ProductName.SITE) site_subscription = SiteInactive.from_product_id(product_id, customer_id=partner["partner_id"], insync=True) @@ -52,19 +44,19 @@ def site_subscription_factory(faker, geant_partner): product_id, customer_id=partner["partner_id"], insync=True ) - site_subscription.site.site_city = site_city - site_subscription.site.site_name = site_name - site_subscription.site.site_country = site_country - site_subscription.site.site_country_code = site_country_code - site_subscription.site.site_latitude = site_latitude - site_subscription.site.site_longitude = site_longitude - site_subscription.site.site_bgp_community_id = site_bgp_community_id - site_subscription.site.site_internal_id = site_internal_id - site_subscription.site.site_tier = site_tier - site_subscription.site.site_ts_address = site_ts_address + site_subscription.site.site_city = site_city or faker.city() + site_subscription.site.site_name = site_name or faker.site_name() + site_subscription.site.site_country = site_country or faker.country() + site_subscription.site.site_country_code = site_country_code or faker.country_code() + site_subscription.site.site_latitude = site_latitude or str(faker.latitude()) + site_subscription.site.site_longitude = site_longitude or str(faker.longitude()) + site_subscription.site.site_bgp_community_id = site_bgp_community_id or faker.pyint() + site_subscription.site.site_internal_id = site_internal_id or faker.pyint() + site_subscription.site.site_tier = site_tier or SiteTier.TIER1 + site_subscription.site.site_ts_address = site_ts_address or faker.ipv4() site_subscription = SubscriptionModel.from_other_lifecycle(site_subscription, SubscriptionLifecycle.ACTIVE) - site_subscription.description = description + site_subscription.description = description or "Site Subscription" site_subscription.start_date = start_date if status: site_subscription.status = status diff --git a/test/workflows/site/test_modify_site.py b/test/workflows/site/test_modify_site.py index 0db1a50fcd9f7880aeb574fa465afc08e832e808..a79b3b01b111f4aaf200d70b028bc6f487dd20db 100644 --- a/test/workflows/site/test_modify_site.py +++ b/test/workflows/site/test_modify_site.py @@ -6,14 +6,14 @@ from test.workflows import assert_complete, extract_state, run_workflow @pytest.mark.workflow() -def test_modify_site(responses, site_subscription_factory): +def test_modify_site(responses, site_subscription_factory, faker): subscription_id = site_subscription_factory() initial_site_data = [ {"subscription_id": subscription_id}, { - "site_bgp_community_id": 10, - "site_internal_id": 20, - "site_ts_address": "127.0.0.1", + "site_bgp_community_id": faker.pyint(), + "site_internal_id": faker.pyint(), + "site_ts_address": faker.ipv4(), }, ] result, _, _ = run_workflow("modify_site", initial_site_data) @@ -28,16 +28,54 @@ def test_modify_site(responses, site_subscription_factory): @pytest.mark.workflow() -def test_modify_site_with_invalid_data(responses, site_subscription_factory): - subscription_a = Site.from_subscription(site_subscription_factory()) - subscription_b = Site.from_subscription(site_subscription_factory()) +def test_modify_site_with_duplicate_bgp_community_id(faker, site_subscription_factory): + duplicate_bgp_community_id = faker.pyint() + + site_subscription_factory(site_bgp_community_id=duplicate_bgp_community_id) + subscription_b = site_subscription_factory() initial_site_data = [ - {"subscription_id": subscription_b.subscription_id}, + {"subscription_id": subscription_b}, { - "site_bgp_community_id": subscription_a.site.site_bgp_community_id, + "site_bgp_community_id": duplicate_bgp_community_id, }, ] with pytest.raises(FormValidationError, match="site_bgp_community_id must be unique"): run_workflow("modify_site", initial_site_data) + + +@pytest.mark.workflow() +def test_modify_site_with_duplicate_internal_id(faker, site_subscription_factory): + duplicate_internal_id = faker.pyint() + + site_subscription_factory(site_internal_id=duplicate_internal_id) + subscription_b = site_subscription_factory() + + initial_site_data = [ + {"subscription_id": subscription_b}, + { + "site_internal_id": duplicate_internal_id, + }, + ] + + with pytest.raises(FormValidationError, match="site_internal_id must be unique"): + run_workflow("modify_site", initial_site_data) + + +@pytest.mark.workflow() +def test_modify_site_with_duplicate_ts_address(faker, site_subscription_factory): + duplicate_ts_address = faker.ipv4() + + site_subscription_factory(site_ts_address=duplicate_ts_address) + subscription_b = site_subscription_factory() + + initial_site_data = [ + {"subscription_id": subscription_b}, + { + "site_ts_address": duplicate_ts_address, + }, + ] + + with pytest.raises(FormValidationError, match="site_ts_address must be unique"): + run_workflow("modify_site", initial_site_data)