diff --git a/gso/utils/helpers.py b/gso/utils/helpers.py index 78bf4c28d02a57515b366c9a38608a5de336991d..5e3b149ac45f61168c89a925a3fb0426c48fdf1b 100644 --- a/gso/utils/helpers.py +++ b/gso/utils/helpers.py @@ -187,10 +187,9 @@ def validate_site_name(site_name: str) -> str: The site name must consist of three uppercase letters (A-Z) followed by an optional single digit (0-9). """ pattern = re.compile(r"^[A-Z]{3}[0-9]?$") - if pattern.match(site_name): - return site_name - else: + if not pattern.match(site_name): raise ValueError( "Enter a valid site name. It must consist of three uppercase letters (A-Z) followed by an optional single " "digit (0-9)." ) + return site_name diff --git a/gso/workflows/tasks/import_site.py b/gso/workflows/tasks/import_site.py index af96fca24e6e67d0ffac5013c593eed238a16706..fd3a55cc77b80250951076302e21668b75c151f2 100644 --- a/gso/workflows/tasks/import_site.py +++ b/gso/workflows/tasks/import_site.py @@ -5,12 +5,20 @@ from orchestrator.targets import Target from orchestrator.types import FormGenerator, State, SubscriptionLifecycle from orchestrator.workflow import StepList, done, init, step, workflow from orchestrator.workflows.steps import resync, set_status, store_process_subscription +from pydantic import validator +from pydantic.fields import ModelField from gso.products import ProductType from gso.products.product_blocks.site import SiteTier from gso.products.product_types.site import SiteInactive from gso.services import subscriptions from gso.services.crm import get_customer_by_name +from gso.utils.helpers import ( + validate_country_code, + validate_ipv4_or_ipv6, + validate_site_fields_is_unique, + validate_site_name, +) from gso.workflows.site.create_site import initialize_subscription @@ -43,6 +51,32 @@ def generate_initial_input_form() -> FormGenerator: site_ts_address: str customer: str + @validator("site_ts_address", allow_reuse=True) + def validate_ts_address(cls, site_ts_address: str) -> str: + validate_site_fields_is_unique("site_ts_address", site_ts_address) + validate_ipv4_or_ipv6(site_ts_address) + return site_ts_address + + @validator("site_country_code", allow_reuse=True) + def country_code_must_exist(cls, country_code: str) -> str: + validate_country_code(country_code) + return country_code + + @validator("site_internal_id", "site_bgp_community_id", allow_reuse=True) + def validate_unique_fields(cls, value: str, field: ModelField) -> str | int: + return validate_site_fields_is_unique(field.name, value) + + @validator("site_name", allow_reuse=True) + def site_name_must_be_valid(cls, site_name: str) -> str: + """Validate the site name. + + The site name must consist of three uppercase letters (A-Z) followed + by an optional single digit (0-9). + """ + validate_site_fields_is_unique("site_name", site_name) + validate_site_name(site_name) + return site_name + user_input = yield ImportSite return user_input.dict() diff --git a/test/imports/test_imports.py b/test/imports/test_imports.py index 5ebdc8e6ce19ab8a469fc70156cf340e5b0b74c2..3cc72fcdb839c3bacb28a538c7f5de175a810f5f 100644 --- a/test/imports/test_imports.py +++ b/test/imports/test_imports.py @@ -295,7 +295,3 @@ def test_import_iptrunk_fails_on_side_a_and_b_members_mismatch( assert response.json() == { "detail": [{"loc": ["body", "__root__"], "msg": "Mismatch between Side A and B members", "type": "value_error"}] } - - def test_site_name_is_valid(): - site_model = SiteImportModel(site_name="123456") - assert site_model is not None