From 4de070dd940d96476c74c69f0ca33cf92ecbcda4 Mon Sep 17 00:00:00 2001
From: Karel van Klink <karel.vanklink@geant.org>
Date: Tue, 10 Sep 2024 09:38:28 +0200
Subject: [PATCH] Update type for site name

---
 gso/products/product_blocks/site.py | 7 ++++---
 gso/types/site_name.py              | 3 ++-
 gso/types/unique_field.py           | 8 +++++---
 test/fixtures.py                    | 2 +-
 4 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/gso/products/product_blocks/site.py b/gso/products/product_blocks/site.py
index 0a5ec604..436288c5 100644
--- a/gso/products/product_blocks/site.py
+++ b/gso/products/product_blocks/site.py
@@ -5,6 +5,7 @@ from orchestrator.types import SubscriptionLifecycle, strEnum
 
 from gso.types.coordinates import LatitudeCoordinate, LongitudeCoordinate
 from gso.types.ip_address import IPAddress
+from gso.types.site_name import SiteName
 
 
 class SiteTier(strEnum):
@@ -27,7 +28,7 @@ class SiteBlockInactive(
 ):
     """A site that's currently inactive, see :class:`SiteBlock`."""
 
-    site_name: str | None = None
+    site_name: SiteName | None = None
     site_city: str | None = None
     site_country: str | None = None
     site_country_code: str | None = None
@@ -42,7 +43,7 @@ class SiteBlockInactive(
 class SiteBlockProvisioning(SiteBlockInactive, lifecycle=[SubscriptionLifecycle.PROVISIONING]):
     """A site that's currently being provisioned, see :class:`SiteBlock`."""
 
-    site_name: str
+    site_name: SiteName
     site_city: str
     site_country: str
     site_country_code: str
@@ -59,7 +60,7 @@ class SiteBlock(SiteBlockProvisioning, lifecycle=[SubscriptionLifecycle.ACTIVE])
 
     #:  The name of the site, that will dictate part of the :term:`FQDN` of routers that are hosted at this site. For
     #:  example: ``router.X.Y.geant.net``, where X denotes the name of the site.
-    site_name: str
+    site_name: SiteName
     #:  The city at which the site is located.
     site_city: str
     #:  The country in which the site is located.
diff --git a/gso/types/site_name.py b/gso/types/site_name.py
index 2d98096f..903b3468 100644
--- a/gso/types/site_name.py
+++ b/gso/types/site_name.py
@@ -6,7 +6,7 @@ from typing import Annotated
 from pydantic import AfterValidator
 
 
-def validate_site_name(site_name: str) -> None:
+def validate_site_name(site_name: str) -> str:
     """Validate the site name.
 
     The site name must consist of three uppercase letters, optionally followed by a single digit.
@@ -18,6 +18,7 @@ def validate_site_name(site_name: str) -> None:
             f"digit (0-9). Received: {site_name}"
         )
         raise ValueError(msg)
+    return site_name
 
 
 SiteName = Annotated[str, AfterValidator(validate_site_name)]
diff --git a/gso/types/unique_field.py b/gso/types/unique_field.py
index 1181a6b0..32c3d9de 100644
--- a/gso/types/unique_field.py
+++ b/gso/types/unique_field.py
@@ -7,13 +7,15 @@ from pydantic_core.core_schema import ValidationInfo
 
 from gso.services import subscriptions
 
+T = TypeVar("T")
+
 
-def validate_field_is_unique(value: str | int, info: ValidationInfo) -> None:
+def validate_field_is_unique(value: T, info: ValidationInfo) -> T:
     """Validate that a field is unique."""
     if len(subscriptions.get_active_subscriptions_by_field_and_value(str(info.field_name), str(value))) > 0:
         msg = f"{info.field_name} must be unique"
         raise ValueError(msg)
+    return value
 
 
-T = TypeVar("T")
-UniqueField = Annotated[T, str | int, AfterValidator(validate_field_is_unique)]
+UniqueField = Annotated[T, AfterValidator(validate_field_is_unique)]
diff --git a/test/fixtures.py b/test/fixtures.py
index e642c412..4fa5a2a0 100644
--- a/test/fixtures.py
+++ b/test/fixtures.py
@@ -65,7 +65,7 @@ def site_subscription_factory(faker, geant_partner):
             partner = geant_partner
 
         description = description or "Site Subscription"
-        site_name = site_name or faker.domain_word()
+        site_name = site_name or faker.site_name()
         site_city = site_city or faker.city()
         site_country = site_country or faker.country()
         site_country_code = site_country_code or faker.country_code()
-- 
GitLab