Skip to content
Snippets Groups Projects
Verified Commit 42ba200f authored by Karel van Klink's avatar Karel van Klink :smiley_cat:
Browse files

Move field validators into annotated types

parent 27934bdf
No related branches found
No related tags found
1 merge request!265Feature/refactor validators
Showing
with 293 additions and 156 deletions
......@@ -13,7 +13,7 @@ import yaml
from orchestrator.db import db
from orchestrator.services.processes import start_process
from orchestrator.types import SubscriptionLifecycle
from pydantic import BaseModel, EmailStr, ValidationError, field_validator, model_validator
from pydantic import BaseModel, ValidationError, field_validator, model_validator
from sqlalchemy.exc import SQLAlchemyError
from gso.db.models import PartnerTable
......@@ -21,9 +21,9 @@ from gso.products import ProductType
from gso.products.product_blocks.iptrunk import IptrunkType
from gso.products.product_blocks.router import RouterRole
from gso.services.partners import (
PartnerEmail,
PartnerName,
PartnerNotFoundError,
filter_partners_by_email,
filter_partners_by_name,
get_partner_by_name,
)
from gso.services.subscriptions import (
......@@ -41,27 +41,8 @@ app: typer.Typer = typer.Typer()
class CreatePartner(BaseModel):
"""Required inputs for creating a partner."""
name: str
email: EmailStr
@field_validator("name")
def validate_name(cls, name: str) -> str:
"""Validate name."""
if filter_partners_by_name(name=name, case_sensitive=False):
msg = "Partner with this name already exists."
raise ValueError(msg)
return name
@field_validator("email")
def validate_email(cls, email: str) -> EmailStr:
"""Validate email."""
email = email.lower()
if filter_partners_by_email(email=email, case_sensitive=False):
msg = "Partner with this email already exists."
raise ValueError(msg)
return email
name: PartnerName
email: PartnerEmail
class SiteImportModel(BaseSiteValidatorModel):
......
......@@ -4,6 +4,7 @@ from orchestrator.domain.base import ProductBlockModel
from orchestrator.types import SubscriptionLifecycle, strEnum
from gso.types.coordinates import LatitudeCoordinate, LongitudeCoordinate
from gso.types.ip_address import IPAddress
class SiteTier(strEnum):
......@@ -35,7 +36,7 @@ class SiteBlockInactive(
site_internal_id: int | None = None
site_bgp_community_id: int | None = None
site_tier: SiteTier | None = None
site_ts_address: str | None = None
site_ts_address: IPAddress | None = None
class SiteBlockProvisioning(SiteBlockInactive, lifecycle=[SubscriptionLifecycle.PROVISIONING]):
......@@ -50,7 +51,7 @@ class SiteBlockProvisioning(SiteBlockInactive, lifecycle=[SubscriptionLifecycle.
site_internal_id: int
site_bgp_community_id: int
site_tier: SiteTier
site_ts_address: str
site_ts_address: IPAddress
class SiteBlock(SiteBlockProvisioning, lifecycle=[SubscriptionLifecycle.ACTIVE]):
......@@ -79,4 +80,4 @@ class SiteBlock(SiteBlockProvisioning, lifecycle=[SubscriptionLifecycle.ACTIVE])
#: The address of the terminal server that this router is connected to. The terminal server provides out of band
#: access. This is required in case a link goes down, or when a router is initially added to the network and it
#: does not have any IP trunks connected to it.
site_ts_address: str
site_ts_address: IPAddress
"""A module that returns the partners available in :term:`GSO`."""
from datetime import datetime
from typing import Any
from typing import Annotated, Any
from uuid import uuid4
from orchestrator.db import db
from pydantic import BaseModel, ConfigDict, EmailStr, Field
from pydantic import AfterValidator, BaseModel, ConfigDict, EmailStr, Field
from sqlalchemy import func
from sqlalchemy.exc import NoResultFound
from gso.db.models import PartnerTable
def validate_partner_name_unique(name: str) -> str:
"""Validate that the name of a partner is unique."""
if filter_partners_by_name(name=name, case_sensitive=False):
msg = "Partner with this name already exists."
raise ValueError(msg)
return name
def validate_partner_email_unique(email: EmailStr) -> EmailStr:
"""Validate that the e-mail address of a partner is unique."""
email = email.lower()
if filter_partners_by_email(email=email, case_sensitive=False):
msg = "Partner with this email already exists."
raise ValueError(msg)
return email
PartnerName = Annotated[str, AfterValidator(validate_partner_name_unique)]
PartnerEmail = Annotated[EmailStr, AfterValidator(validate_partner_email_unique)]
class PartnerSchema(BaseModel):
"""Partner schema."""
partner_id: str = Field(default_factory=lambda: str(uuid4()))
name: str
email: EmailStr
name: PartnerName
email: PartnerEmail
created_at: datetime = Field(default_factory=lambda: datetime.now().astimezone())
updated_at: datetime = Field(default_factory=lambda: datetime.now().astimezone())
model_config = ConfigDict(from_attributes=True)
class ModifiedPartnerSchema(PartnerSchema):
"""Partner schema when making a modification.
The name and email can be empty in this case, if they don't need changing.
"""
name: PartnerName | None = None # type: ignore[assignment]
email: PartnerEmail | None = None # type: ignore[assignment]
class PartnerNotFoundError(Exception):
"""Exception raised when a partner is not found."""
......@@ -95,7 +126,7 @@ def create_partner(
def edit_partner(
partner_data: PartnerSchema,
partner_data: ModifiedPartnerSchema,
) -> PartnerTable:
"""Edit an existing partner and update it in the database."""
partner = get_partner_by_id(partner_id=partner_data.partner_id)
......
"""A base site type for validation purposes that can be extended elsewhere."""
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from pydantic import BaseModel
from gso.products.product_blocks.site import SiteTier
from gso.types.coordinates import LatitudeCoordinate, LongitudeCoordinate
from gso.types.country_code import validate_country_code
from gso.types.ip_address import validate_ipv4_or_ipv6
from gso.types.site_name import validate_site_name
from gso.types.unique_field import validate_field_is_unique
from gso.types.country_code import CountryCode
from gso.types.ip_address import IPAddress
from gso.types.site_name import SiteName
from gso.types.unique_field import UniqueField
class BaseSiteValidatorModel(BaseModel):
"""A base site validator model extended by create site and by import site."""
site_bgp_community_id: int
site_internal_id: int
site_bgp_community_id: UniqueField[int]
site_internal_id: UniqueField[int]
site_tier: SiteTier
site_ts_address: str
site_country_code: str
site_name: str
site_ts_address: UniqueField[IPAddress]
site_country_code: CountryCode
site_name: UniqueField[SiteName]
site_city: str
site_country: str
site_latitude: LatitudeCoordinate
site_longitude: LongitudeCoordinate
partner: str
@field_validator("site_ts_address")
def validate_ts_address(cls, site_ts_address: str) -> str:
"""Validate that a terminal server address is valid."""
validate_ipv4_or_ipv6(site_ts_address)
return site_ts_address
@field_validator("site_country_code")
def country_code_must_exist(cls, country_code: str) -> str:
"""Validate that the country code exists."""
validate_country_code(country_code)
return country_code
@field_validator("site_ts_address", "site_internal_id", "site_bgp_community_id", "site_name")
def field_must_be_unique(cls, value: str | int, info: ValidationInfo) -> str | int:
"""Validate that a field is unique."""
if not info.field_name:
msg = "Field name must be provided."
raise ValueError(msg)
validate_field_is_unique(info.field_name, value)
return value
@field_validator("site_name")
def site_name_must_be_valid(cls, site_name: str) -> str:
"""Validate the site name.
The site name must consist of three uppercase letters, followed by an optional single digit.
"""
validate_site_name(site_name)
return site_name
"""Country codes."""
from typing import Annotated
import pycountry
from pydantic import AfterValidator
def validate_country_code(country_code: str) -> str:
......@@ -13,3 +16,6 @@ def validate_country_code(country_code: str) -> str:
msg = "Invalid or non-existent country code, it must be in ISO 3166-1 alpha-2 format."
raise ValueError(msg) from e
return country_code
CountryCode = Annotated[str, AfterValidator(validate_country_code)]
"""IP addresses."""
import ipaddress
from typing import Annotated
from pydantic import AfterValidator
def validate_ipv4_or_ipv6(value: str) -> str:
......@@ -12,3 +15,6 @@ def validate_ipv4_or_ipv6(value: str) -> str:
raise ValueError(msg) from e
else:
return value
IPAddress = Annotated[str, AfterValidator(validate_ipv4_or_ipv6)]
"""A router that must be present in Netbox."""
from typing import Annotated, TypeVar
from pydantic import AfterValidator
from pydantic_forms.types import UUIDstr
from gso.products.product_types.router import Router
......@@ -25,3 +28,7 @@ def validate_router_in_netbox(subscription_id: UUIDstr) -> UUIDstr:
msg = "The selected router does not exist in Netbox."
raise ValueError(msg)
return subscription_id
T = TypeVar("T")
NetboxEnabledRouter = Annotated[T, UUIDstr, AfterValidator(validate_router_in_netbox)]
"""Type for the name of a site."""
import re
from typing import Annotated
from pydantic import AfterValidator
def validate_site_name(site_name: str) -> None:
......@@ -15,3 +18,6 @@ def validate_site_name(site_name: str) -> None:
f"digit (0-9). Received: {site_name}"
)
raise ValueError(msg)
SiteName = Annotated[str, AfterValidator(validate_site_name)]
"""An input field that must be unique in the database."""
from typing import Annotated, TypeVar
from pydantic import AfterValidator
from pydantic_core.core_schema import ValidationInfo
from gso.services import subscriptions
def validate_field_is_unique(field_name: str, value: str | int) -> None:
"""Validate that a site field is unique."""
if len(subscriptions.get_active_subscriptions_by_field_and_value(field_name, str(value))) > 0:
msg = f"{field_name} must be unique"
def validate_field_is_unique(value: str | int, info: ValidationInfo) -> None:
"""Validate that a field is unique."""
if len(subscriptions.get_active_subscriptions_by_field_and_value(str(info.field_name), str(value))) > 0:
msg = f"{info.field_name} must be unique"
raise ValueError(msg)
T = TypeVar("T")
UniqueField = Annotated[T, str | int, AfterValidator(validate_field_is_unique)]
......@@ -16,7 +16,7 @@ from orchestrator.workflow import StepList, begin, conditional, done, step, step
from orchestrator.workflows.steps import resync, set_status, store_process_subscription
from orchestrator.workflows.utils import wrap_create_initial_input_form
from ping3 import ping
from pydantic import ConfigDict, field_validator
from pydantic import ConfigDict
from pydantic_forms.validators import ReadOnlyField
from pynetbox.models.dcim import Interfaces
......@@ -35,7 +35,7 @@ from gso.services.sharepoint import SharePointClient
from gso.services.subscriptions import get_non_terminated_iptrunk_subscriptions
from gso.settings import load_oss_params
from gso.types.interfaces import JuniperLAGMember, LAGMember, LAGMemberList, PhysicalPortCapacity
from gso.types.netbox_router import validate_router_in_netbox
from gso.types.netbox_router import NetboxEnabledRouter
from gso.types.tt_number import TTNumber
from gso.utils.helpers import (
available_interfaces_choices,
......@@ -83,11 +83,7 @@ def initial_input_form_generator(product_name: str) -> FormGenerator:
class SelectRouterSideA(FormPage):
model_config = ConfigDict(title="Select a router for side A of the trunk.")
side_a_node_id: router_enum_a # type: ignore[valid-type]
@field_validator("side_a_node_id")
def validate_device_exists_in_netbox(cls, side_a_node_id: UUIDstr) -> str | None:
return validate_router_in_netbox(side_a_node_id)
side_a_node_id: NetboxEnabledRouter[router_enum_a] # type: ignore[valid-type]
user_input_router_side_a = yield SelectRouterSideA
router_a = user_input_router_side_a.side_a_node_id.name
......@@ -134,11 +130,7 @@ def initial_input_form_generator(product_name: str) -> FormGenerator:
class SelectRouterSideB(FormPage):
model_config = ConfigDict(title="Select a router for side B of the trunk.")
side_b_node_id: router_enum_b # type: ignore[valid-type]
@field_validator("side_b_node_id")
def validate_device_exists_in_netbox(cls, side_b_node_id: UUIDstr) -> str | None:
return validate_router_in_netbox(side_b_node_id)
side_b_node_id: NetboxEnabledRouter[router_enum_b] # type: ignore[valid-type]
user_input_router_side_b = yield SelectRouterSideB
router_b = user_input_router_side_b.side_b_node_id.name
......
"""Workflows for the LAN Switch interconnect product."""
"""A creation workflow for creating a new interconnect between a switch and a router."""
from typing import Annotated
from uuid import uuid4
from annotated_types import Len
from orchestrator.forms import FormPage
from orchestrator.targets import Target
from orchestrator.types import FormGenerator, State, SubscriptionLifecycle, UUIDstr
from orchestrator.workflow import StepList, begin, done, step, workflow
from orchestrator.workflows.steps import resync, set_status, store_process_subscription
from orchestrator.workflows.utils import wrap_create_initial_input_form
from pydantic import AfterValidator, ConfigDict
from pydantic_forms.validators import Divider, ReadOnlyField
from gso.products.product_blocks.lan_switch_interconnect import (
LanSwitchInterconnectAddressSpace,
LanSwitchInterconnectInterfaceBlockInactive,
)
from gso.products.product_types.lan_switch_interconnect import LanSwitchInterconnectInactive
from gso.products.product_types.router import Router
from gso.products.product_types.switch import Switch
from gso.services.partners import get_partner_by_name
from gso.types.interfaces import (
JuniperAEInterface,
JuniperLAGMember,
JuniperPhyInterface,
LAGMember,
PhysicalPortCapacity,
validate_interface_names_are_unique,
)
from gso.types.tt_number import TTNumber
from gso.utils.helpers import (
active_router_selector,
active_switch_selector,
available_interfaces_choices,
available_lags_choices,
)
from gso.utils.shared_enums import Vendor
def _initial_input_form(product_name: str) -> FormGenerator:
class CreateLANSwitchInterconnectForm(FormPage):
model_config = ConfigDict(title=product_name)
tt_number: TTNumber
router_side: active_router_selector() # type: ignore[valid-type]
switch_side: active_switch_selector() # type: ignore[valid-type]
address_space: LanSwitchInterconnectAddressSpace
description: str
minimum_link_count: int
divider: Divider
vlan_id: ReadOnlyField(111, default_type=int) # type: ignore[valid-type]
user_input = yield CreateLANSwitchInterconnectForm
router = Router.from_subscription(user_input.router_side)
if router.router.vendor == Vendor.NOKIA:
class NokiaLAGMemberA(LAGMember):
interface_name: available_interfaces_choices( # type: ignore[valid-type]
router.subscription_id,
PhysicalPortCapacity.TEN_GIGABIT_PER_SECOND,
)
router_side_ae_member_list = Annotated[
list[NokiaLAGMemberA],
AfterValidator(validate_interface_names_are_unique),
Len(min_length=user_input.minimum_link_count),
]
else:
router_side_ae_member_list = Annotated[ # type: ignore[assignment, misc]
list[JuniperLAGMember],
AfterValidator(validate_interface_names_are_unique),
Len(min_length=user_input.minimum_link_count),
]
class InterconnectRouterSideForm(FormPage):
model_config = ConfigDict(title="Please enter interface names and descriptions for the router side.")
router_side_iface: available_lags_choices(user_input.router_side) or JuniperAEInterface # type: ignore[valid-type]
router_side_ae_members: router_side_ae_member_list
router_side_input = yield InterconnectRouterSideForm
switch_side_ae_member_list = Annotated[
list[JuniperLAGMember],
AfterValidator(validate_interface_names_are_unique),
Len(
min_length=len(router_side_input.router_side_ae_members),
max_length=len(router_side_input.router_side_ae_members),
),
]
class InterconnectSwitchSideForm(FormPage):
model_config = ConfigDict(title="Please enter interface names and descriptions for the switch side.")
switch_side_iface: JuniperPhyInterface
switch_side_ae_members: switch_side_ae_member_list
switch_side_input = yield InterconnectSwitchSideForm
return user_input.model_dump() | router_side_input.model_dump() | switch_side_input.model_dump()
@step("Create subscription")
def create_subscription(product: UUIDstr, partner: str) -> State:
"""Create a new subscription object in the database."""
subscription = LanSwitchInterconnectInactive.from_product_id(product, get_partner_by_name(partner)["partner_id"])
return {"subscription": subscription}
@step("Initialize subscription")
def initialize_subscription(
subscription: LanSwitchInterconnectInactive,
description: str,
address_space: LanSwitchInterconnectAddressSpace,
minimum_link_count: int,
router_side: UUIDstr,
router_side_iface: JuniperPhyInterface,
router_side_ae_members: list[dict],
switch_side: UUIDstr,
switch_side_iface: JuniperPhyInterface,
switch_side_ae_members: list[dict],
) -> State:
"""Update the product model with all input from the operator."""
subscription.lan_switch_interconnect.lan_switch_interconnect_description = description
subscription.lan_switch_interconnect.address_space = address_space
subscription.lan_switch_interconnect.minimum_links = minimum_link_count
subscription.lan_switch_interconnect.router_side.node = Router.from_subscription(router_side).router
subscription.lan_switch_interconnect.router_side.ae_iface = router_side_iface
for member in router_side_ae_members:
subscription.lan_switch_interconnect.router_side.ae_members.append(
LanSwitchInterconnectInterfaceBlockInactive.new(subscription_id=uuid4(), **member)
)
subscription.lan_switch_interconnect.switch_side.node = Switch.from_subscription(switch_side).switch
subscription.lan_switch_interconnect.switch_side.ae_iface = switch_side_iface
for member in switch_side_ae_members:
subscription.lan_switch_interconnect.switch_side.ae_members.append(
LanSwitchInterconnectInterfaceBlockInactive.new(subscription_id=uuid4(), **member)
)
return {"subscription": subscription}
@workflow(
"Create LAN switch interconnect",
initial_input_form=wrap_create_initial_input_form(_initial_input_form),
target=Target.CREATE,
)
def create_lan_switch_interconnect() -> StepList:
"""Create a new LAN interconnect between a Switch and a Router."""
return (
begin
>> create_subscription
>> store_process_subscription(Target.CREATE)
>> initialize_subscription
>> set_status(SubscriptionLifecycle.ACTIVE)
>> resync
>> done
)
......@@ -16,6 +16,7 @@ from gso.services import subscriptions
from gso.services.partners import get_partner_by_name
from gso.types.base_site import BaseSiteValidatorModel
from gso.types.coordinates import LatitudeCoordinate, LongitudeCoordinate
from gso.types.ip_address import IPAddress
@step("Create subscription")
......@@ -52,7 +53,7 @@ def initialize_subscription(
site_longitude: LongitudeCoordinate,
site_bgp_community_id: int,
site_internal_id: int,
site_ts_address: str,
site_ts_address: IPAddress,
site_tier: SiteTier,
) -> State:
"""Initialise the subscription object with all input."""
......
......@@ -14,6 +14,7 @@ from gso.products.product_types import site
from gso.services.partners import get_partner_by_name
from gso.types.base_site import BaseSiteValidatorModel
from gso.types.coordinates import LatitudeCoordinate, LongitudeCoordinate
from gso.types.ip_address import IPAddress
def initial_input_form_generator(product_name: str) -> FormGenerator:
......@@ -50,7 +51,7 @@ def initialize_subscription(
site_longitude: LongitudeCoordinate,
site_bgp_community_id: int,
site_internal_id: int,
site_ts_address: str,
site_ts_address: IPAddress,
site_tier: site_pb.SiteTier,
) -> State:
"""Initialise the subscription object with all user input."""
......
"""A modification workflow for a site."""
from typing import Annotated
from orchestrator.forms import FormPage
from orchestrator.targets import Target
from orchestrator.types import FormGenerator, State, SubscriptionLifecycle, UUIDstr
......@@ -11,15 +13,14 @@ from orchestrator.workflows.steps import (
unsync,
)
from orchestrator.workflows.utils import wrap_modify_initial_input_form
from pydantic import ConfigDict, field_validator
from pydantic_core.core_schema import ValidationInfo
from pydantic import ConfigDict
from pydantic_forms.validators import ReadOnlyField
from gso.products.product_blocks.site import SiteTier
from gso.products.product_types.site import Site
from gso.types.coordinates import LatitudeCoordinate, LongitudeCoordinate
from gso.types.ip_address import validate_ipv4_or_ipv6
from gso.types.unique_field import validate_field_is_unique
from gso.types.ip_address import IPAddress
from gso.types.unique_field import UniqueField
def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator:
......@@ -35,30 +36,10 @@ def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator:
site_country_code: ReadOnlyField(subscription.site.site_country_code, default_type=str) # type: ignore[valid-type]
site_latitude: LatitudeCoordinate = subscription.site.site_latitude
site_longitude: LongitudeCoordinate = subscription.site.site_longitude
site_bgp_community_id: int = subscription.site.site_bgp_community_id
site_internal_id: int = subscription.site.site_internal_id
site_bgp_community_id: UniqueField[int] = subscription.site.site_bgp_community_id
site_internal_id: UniqueField[int] = subscription.site.site_internal_id
site_tier: ReadOnlyField(subscription.site.site_tier, default_type=SiteTier) # type: ignore[valid-type]
site_ts_address: str | None = subscription.site.site_ts_address
@field_validator("site_ts_address", "site_internal_id", "site_bgp_community_id")
def field_must_be_unique(cls, value: str | int, info: ValidationInfo) -> str | int:
if not info.field_name:
msg = "Field name must be provided."
raise ValueError(msg)
if value and value == getattr(subscription.site, info.field_name):
return value
validate_field_is_unique(info.field_name, value)
return value
@field_validator("site_ts_address")
def validate_ts_address(cls, site_ts_address: str) -> str:
if site_ts_address and site_ts_address != subscription.site.site_ts_address:
validate_ipv4_or_ipv6(site_ts_address)
return site_ts_address
site_ts_address: Annotated[IPAddress, UniqueField] | None = subscription.site.site_ts_address
user_input = yield ModifySiteForm
......@@ -73,7 +54,7 @@ def modify_site_subscription(
site_longitude: LongitudeCoordinate,
site_bgp_community_id: int,
site_internal_id: int,
site_ts_address: str,
site_ts_address: IPAddress,
) -> State:
"""Update the subscription model in the service database."""
subscription.site.site_city = site_city
......
......@@ -4,9 +4,9 @@ from orchestrator.forms import FormPage
from orchestrator.targets import Target
from orchestrator.types import FormGenerator, State
from orchestrator.workflow import StepList, begin, done, step, workflow
from pydantic import ConfigDict, EmailStr, field_validator
from pydantic import ConfigDict
from gso.services.partners import PartnerSchema, create_partner, filter_partners_by_email, filter_partners_by_name
from gso.services.partners import PartnerEmail, PartnerName, PartnerSchema, create_partner
def initial_input_form_generator() -> FormGenerator:
......@@ -15,25 +15,8 @@ def initial_input_form_generator() -> FormGenerator:
class CreatePartnerForm(FormPage):
model_config = ConfigDict(title="Create a Partner")
name: str
email: EmailStr
@field_validator("name")
def validate_name(cls, name: str) -> str:
if filter_partners_by_name(name=name, case_sensitive=False):
msg = "Partner with this name already exists."
raise ValueError(msg)
return name
@field_validator("email")
def validate_email(cls, email: str) -> EmailStr:
email = email.lower()
if filter_partners_by_email(email=email, case_sensitive=False):
msg = "Partner with this email already exists."
raise ValueError(msg)
return email
name: PartnerName
email: PartnerEmail
initial_user_input = yield CreatePartnerForm
......@@ -42,8 +25,8 @@ def initial_input_form_generator() -> FormGenerator:
@step("Save partner information to database")
def save_partner_to_database(
name: str,
email: EmailStr,
name: PartnerName,
email: PartnerEmail,
) -> State:
"""Save user input as a new partner in database."""
partner = create_partner(
......
......@@ -9,7 +9,7 @@ from pydantic_forms.types import UUIDstr
from pydantic_forms.validators import Choice
from gso.services.partners import (
PartnerSchema,
ModifiedPartnerSchema,
edit_partner,
filter_partners_by_email,
filter_partners_by_name,
......@@ -38,23 +38,25 @@ def initial_input_form_generator() -> FormGenerator:
class ModifyPartnerForm(FormPage):
model_config = ConfigDict(title="Modify a Partner")
name: str = partner["name"]
email: EmailStr = partner["email"]
name: str | None = partner["name"]
email: EmailStr | None = partner["email"]
@field_validator("name")
def validate_name(cls, name: str) -> str:
if partner["name"] != name and filter_partners_by_name(name=name, case_sensitive=False):
def validate_name(cls, name: str) -> str | None:
if partner["name"] == name:
return None
if filter_partners_by_name(name=name, case_sensitive=False):
msg = "Partner with this name already exists."
raise ValueError(msg)
return name
@field_validator("email")
def validate_email(cls, email: str) -> EmailStr:
if partner["email"] != email and filter_partners_by_email(email=email, case_sensitive=False):
def validate_email(cls, email: str) -> EmailStr | None:
if partner["email"] == email:
return None
if filter_partners_by_email(email=email, case_sensitive=False):
msg = "Partner with this email already exists."
raise ValueError(msg)
return email
user_input = yield ModifyPartnerForm
......@@ -65,12 +67,12 @@ def initial_input_form_generator() -> FormGenerator:
@step("Save partner information to database")
def save_partner_to_database(
partner_id: UUIDstr,
name: str,
email: EmailStr,
name: str | None,
email: EmailStr | None,
) -> State:
"""Save modified partner in database."""
partner = edit_partner(
partner_data=PartnerSchema(
partner_data=ModifiedPartnerSchema(
partner_id=partner_id,
name=name,
email=email,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment