diff --git a/gso/cli/imports.py b/gso/cli/imports.py index 3a6dcf1c814aecb96a2ecc1a406a78fd9fda9f9a..406fa85d1dbc0e11d58187909a63dc5f6285285c 100644 --- a/gso/cli/imports.py +++ b/gso/cli/imports.py @@ -13,14 +13,19 @@ import yaml from orchestrator.db import db from orchestrator.services.processes import start_process from orchestrator.types import SubscriptionLifecycle -from pydantic import BaseModel, ValidationError, field_validator, model_validator +from pydantic import BaseModel, EmailStr, ValidationError, field_validator, model_validator from sqlalchemy.exc import SQLAlchemyError from gso.db.models import PartnerTable from gso.products import ProductType from gso.products.product_blocks.iptrunk import IptrunkType, PhysicalPortCapacity from gso.products.product_blocks.router import RouterRole -from gso.services.partners import PartnerNotFoundError, get_partner_by_name +from gso.services.partners import ( + PartnerNotFoundError, + filter_partners_by_email, + filter_partners_by_name, + get_partner_by_name, +) from gso.services.subscriptions import ( get_active_router_subscriptions, get_active_subscriptions_by_field_and_value, @@ -32,6 +37,32 @@ from gso.utils.shared_enums import IPv4AddressType, IPv6AddressType, PortNumber, app: typer.Typer = typer.Typer() +class CreatePartner(BaseModel): + """Required inputs for creating a partner.""" + + name: str + email: EmailStr + + @field_validator("name") + def validate_name(cls, name: str) -> str: + """Validate name.""" + if filter_partners_by_name(name=name, case_sensitive=False): + msg = "Partner with this name already exists." + raise ValueError(msg) + + return name + + @field_validator("email") + def validate_email(cls, email: str) -> EmailStr: + """Validate email.""" + email = email.lower() + if filter_partners_by_email(email=email, case_sensitive=False): + msg = "Partner with this email already exists." + raise ValueError(msg) + + return email + + class SiteImportModel(BaseSiteValidatorModel): """The required input for importing an existing :class:`gso.products.product_types.site`.""" @@ -375,7 +406,7 @@ def import_partners(file_path: str = typer.Argument(..., help="Path to the CSV f if partner.get("created_at"): partner["created_at"] = datetime.strptime(partner["created_at"], "%Y-%m-%d").replace(tzinfo=UTC) - new_partner = PartnerTable(**partner) + new_partner = PartnerTable(**CreatePartner(**partner).model_dump()) db.session.add(new_partner) db.session.commit()