"""A module that returns the partners available in :term:`GSO`.""" from datetime import datetime from typing import Annotated, Any from uuid import uuid4 from orchestrator.db import db from pydantic import AfterValidator, BaseModel, ConfigDict, EmailStr, Field from sqlalchemy import func from sqlalchemy.exc import NoResultFound from gso.db.models import PartnerTable def validate_partner_name_unique(name: str) -> str: """Validate that the name of a partner is unique.""" if filter_partners_by_name(name=name, case_sensitive=False): msg = "Partner with this name already exists." raise ValueError(msg) return name def validate_partner_email_unique(email: EmailStr) -> EmailStr: """Validate that the e-mail address of a partner is unique.""" email = email.lower() if filter_partners_by_email(email=email, case_sensitive=False): msg = "Partner with this email already exists." raise ValueError(msg) return email PartnerName = Annotated[str, AfterValidator(validate_partner_name_unique)] PartnerEmail = Annotated[EmailStr, AfterValidator(validate_partner_email_unique)] class PartnerSchema(BaseModel): """Partner schema.""" partner_id: str = Field(default_factory=lambda: str(uuid4())) name: PartnerName email: PartnerEmail created_at: datetime = Field(default_factory=lambda: datetime.now().astimezone()) updated_at: datetime = Field(default_factory=lambda: datetime.now().astimezone()) model_config = ConfigDict(from_attributes=True) class ModifiedPartnerSchema(PartnerSchema): """Partner schema when making a modification. The name and email can be empty in this case, if they don't need changing. """ name: PartnerName | None = None # type: ignore[assignment] email: PartnerEmail | None = None # type: ignore[assignment] class PartnerNotFoundError(Exception): """Exception raised when a partner is not found.""" def get_all_partners() -> list[dict]: """Fetch all partners from the database and serialize them to JSON.""" partners = PartnerTable.query.all() return [partner.__json__() for partner in partners] def get_partner_by_name(name: str) -> dict[str, Any]: """Try to get a partner by their name.""" try: partner = db.session.query(PartnerTable).filter(PartnerTable.name == name).one() return partner.__json__() except NoResultFound as e: msg = f"partner {name} not found" raise PartnerNotFoundError(msg) from e def get_partner_by_id(partner_id: str) -> PartnerTable: """Try to get a partner by their id.""" partner = db.session.query(PartnerTable).filter_by(partner_id=partner_id).first() if not partner: raise PartnerNotFoundError return partner def filter_partners_by_attribute( attribute: str, value: str, *, case_sensitive: bool = True ) -> list[dict[str, Any]] | None: """Filter the list of partners by a specified attribute.""" if case_sensitive: partners = db.session.query(PartnerTable).filter(getattr(PartnerTable, attribute) == value).all() else: partners = ( db.session.query(PartnerTable) .filter(func.lower(getattr(PartnerTable, attribute)) == func.lower(value)) .all() ) return [partner.__json__() for partner in partners] if partners else None def filter_partners_by_name(name: str, *, case_sensitive: bool = True) -> list[dict[str, Any]] | None: """Filter the list of partners by name.""" return filter_partners_by_attribute("name", name, case_sensitive=case_sensitive) def filter_partners_by_email(email: str, *, case_sensitive: bool = True) -> list[dict[str, Any]] | None: """Filter the list of partners by email.""" return filter_partners_by_attribute("email", email, case_sensitive=case_sensitive) def create_partner( partner_data: PartnerSchema, ) -> dict: """Create a new partner and add it to the database using Pydantic schema for validation. :param partner_data: Partner data validated by Pydantic schema. :return: JSON representation of the created partner. """ new_partner = PartnerTable(**partner_data.model_dump()) db.session.add(new_partner) db.session.commit() return new_partner.__json__() def edit_partner( partner_data: ModifiedPartnerSchema, ) -> PartnerTable: """Edit an existing partner and update it in the database.""" partner = get_partner_by_id(partner_id=partner_data.partner_id) if partner_data.name: partner.name = partner_data.name if partner_data.email: partner.email = partner_data.email partner.updated_at = datetime.now().astimezone() db.session.commit() return partner def delete_partner(partner_id: str) -> None: """Delete an existing partner from the database.""" partner = get_partner_by_id(partner_id=partner_id) db.session.delete(partner) db.session.commit()