Skip to content
Snippets Groups Projects
helpers.py 10.7 KiB
Newer Older
"""Helper methods that are used across :term:`GSO`."""

import ipaddress
import re
from enum import StrEnum
from ipaddress import IPv4Address
from uuid import UUID

import pycountry
from orchestrator.types import UUIDstr
from pydantic import BaseModel, validator
from pydantic.fields import ModelField
from pydantic_forms.validators import Choice

from gso import settings
from gso.products.product_blocks.iptrunk import IptrunkInterfaceBlock
from gso.products.product_blocks.site import SiteTier
from gso.products.product_types.router import Router
from gso.services.netbox_client import NetboxClient
from gso.services.subscriptions import get_active_subscriptions_by_field_and_value
from gso.utils.shared_choices import Vendor


class LAGMember(BaseModel):
    """A single member interface of a :term:`LAG`, identified by its name and its description."""

    interface_name: str
    interface_description: str

    def __hash__(self) -> int:
        """Hash on the ``(name, description)`` pair, so that identical members can be told apart from distinct ones."""
        identity = (self.interface_name, self.interface_description)
        return hash(identity)


class SNMPVersion(StrEnum):
    """An enumerator for the two relevant versions of :term:`SNMP`: v2c and 3."""

    V2C = "v2c"
    V3 = "v3"


def available_interfaces_choices(router_id: UUID, speed: str) -> Choice | None:
    """Return a choice of the available interfaces for a given router and speed.

    For Nokia routers, return a :class:`Choice` of the interfaces that Netbox reports as available.
    For any other vendor, return ``None``.

    :param router_id: The :term:`UUID` of the router subscription.
    :param speed: The interface speed, passed through to Netbox when listing available interfaces.
    :return: A :class:`Choice` of available interfaces, or ``None`` for non-Nokia routers.
    """
    if get_router_vendor(router_id) != Vendor.NOKIA:
        return None
    interfaces = {
        interface["name"]: f"{interface['name']}  {interface['description']}"
        for interface in NetboxClient().get_available_interfaces(router_id, speed)
    }
    # NOTE(review): each entry is ``(name, (name, label))`` -- presumably ``Choice`` unpacks the tuple into the
    # member's value and display label; confirm against pydantic_forms.
    return Choice("ae member", zip(interfaces.keys(), interfaces.items(), strict=True))  # type: ignore[arg-type]


def available_interfaces_choices_including_current_members(
    router_id: UUID,
    speed: str,
    interfaces: list[IptrunkInterfaceBlock],
) -> Choice | None:
    """Return a choice of available interfaces for a given router and speed, including the current members.

    For Nokia routers, return a :class:`Choice` containing the interfaces that Netbox reports as available, plus
    every interface in ``interfaces``, each looked up in Netbox by name on this router.
    For any other vendor, return ``None``.

    :param router_id: The :term:`UUID` of the router subscription.
    :param speed: The interface speed, passed through to Netbox when listing available interfaces.
    :param interfaces: The current member interfaces, which are added to the available ones.
    :return: A :class:`Choice` of interfaces, or ``None`` for non-Nokia routers.
    """
    if get_router_vendor(router_id) != Vendor.NOKIA:
        return None

    # The Netbox client and the router's FQDN do not change per member, so resolve them once instead of
    # re-creating the client and re-loading the router subscription for every interface.
    netbox = NetboxClient()
    router_fqdn = Router.from_subscription(router_id).router.router_fqdn

    available_interfaces = list(netbox.get_available_interfaces(router_id, speed))
    available_interfaces.extend(
        netbox.get_interface_by_name_and_device(interface.interface_name, router_fqdn) for interface in interfaces
    )
    # Keyed by name, so an interface appearing in both sources ends up only once.
    options = {
        interface["name"]: f"{interface['name']}  {interface['description']}" for interface in available_interfaces
    }
    return Choice("ae member", zip(options.keys(), options.items(), strict=True))  # type: ignore[arg-type]


def available_lags_choices(router_id: UUID) -> Choice | None:
    """Return a choice of the available :term:`LAG` interfaces for a given router.

    For Nokia routers, return a :class:`Choice` of the available :term:`LAG` names.
    For any other vendor, return ``None``.

    :param router_id: The :term:`UUID` of the router subscription.
    :return: A :class:`Choice` of available :term:`LAG` names, or ``None`` for non-Nokia routers.
    """
    if get_router_vendor(router_id) != Vendor.NOKIA:
        return None
    # Pair each name with itself explicitly. ``zip(lags, lags)`` only works for a re-iterable container; given a
    # one-shot iterator it would pair consecutive elements instead.
    lag_options = [(lag_name, lag_name) for lag_name in NetboxClient().get_available_lags(router_id)]
    return Choice("ae iface", lag_options)  # type: ignore[arg-type]


def get_router_vendor(router_id: UUID) -> Vendor:
    """Retrieve the vendor of a router.

    :param router_id: The :term:`UUID` of the router subscription.
    :type router_id: :class:`uuid.UUID`
    :return: The vendor of the router.
    :rtype: :class:`Vendor`
    """
    return Router.from_subscription(router_id).router.vendor


def iso_from_ipv4(ipv4_address: IPv4Address) -> str:
    """Calculate an :term:`ISO` address, based on an IPv4 address.

    Each octet is zero-padded to three digits, the resulting twelve digits are concatenated, and then
    regrouped into dot-separated blocks of four (e.g. ``10.0.0.1`` gives ``0100.0000.0001``).

    :param IPv4Address ipv4_address: The address that's to be converted
    :returns: An :term:`ISO`-formatted address.
    """
    # NOTE(review): as visible here, nothing is returned after ``re_split`` is computed, so the function
    # returns ``None`` despite the ``-> str`` annotation. The final ``return`` (which presumably wraps
    # ``re_split`` in the ISO area prefix/suffix) appears to be missing -- TODO restore it.
    padded_octets = [f"{x:>03}" for x in str(ipv4_address).split(".")]
    joined_octets = "".join(padded_octets)
    # Split the 12-digit string into groups of four and join them with dots.
    re_split = ".".join(re.findall("....", joined_octets))
def validate_router_in_netbox(subscription_id: UUIDstr) -> UUIDstr:
    """Verify if a device exists in Netbox.

    Only Nokia routers are checked. For any other vendor, the subscription ID is returned without a Netbox lookup.

    :param subscription_id: The :term:`UUID` of the router subscription.
    :type subscription_id: :class:`UUIDstr`
    :raises ValueError: If the router is a Nokia router and no device with its :term:`FQDN` exists in Netbox.
    :return: The :term:`UUID` of the router subscription.
    :rtype: :class:`UUIDstr`
    """
    router_type = Router.from_subscription(subscription_id)
    if router_type.router.vendor == Vendor.NOKIA:
        device = NetboxClient().get_device_by_name(router_type.router.router_fqdn)
        if not device:
            msg = "The selected router does not exist in Netbox."
            raise ValueError(msg)
    return subscription_id


def validate_iptrunk_unique_interface(interfaces: list[LAGMember]) -> list[LAGMember]:
    """Ensure that no interface name appears more than once.

    Only the ``interface_name`` of each member is compared; descriptions are not taken into account.

    :param interfaces: The :term:`LAG` members to check.
    :type interfaces: list[:class:`LAGMember`]
    :raises ValueError: If two or more members share the same interface name.
    :return: The given list of interfaces, unchanged.
    :rtype: list[:class:`LAGMember`]
    """
    names = [lag_member.interface_name for lag_member in interfaces]
    # A set can only shrink the collection, so a smaller set means at least one name was repeated.
    if len(set(names)) < len(names):
        msg = "Interfaces must be unique."
        raise ValueError(msg)
    return interfaces


def validate_site_fields_is_unique(field_name: str, value: str | int) -> str | int:
    """Ensure no active subscription already uses ``value`` for the site field ``field_name``.

    The value is converted to a string before being passed to the subscription lookup.

    :raises ValueError: If the lookup returns at least one subscription.
    :return: The given value, unchanged.
    """
    existing = get_active_subscriptions_by_field_and_value(field_name, str(value))
    if len(existing) == 0:
        return value
    msg = f"{field_name} must be unique"
    raise ValueError(msg)


def validate_ipv4_or_ipv6(value: str) -> str:
    """Check that ``value`` parses as either an IPv4 or an IPv6 address.

    :raises ValueError: With a user-facing message, chained to the parser's error, if parsing fails.
    :return: The given value, unchanged (it is not normalised).
    """
    try:
        ipaddress.ip_address(value)
    except ValueError as parse_error:
        msg = "Enter a valid IPv4 or IPv6 address."
        raise ValueError(msg) from parse_error
    return value


def validate_country_code(country_code: str) -> str:
    """Check that ``country_code`` is a recognised country code.

    ``"UK"`` is accepted as-is; every other value must be resolvable through ``pycountry.countries.lookup``.

    :raises ValueError: If the pycountry lookup fails.
    :return: The given country code, unchanged.
    """
    # pycountry knows the United Kingdom as "GB", so "UK" has to be let through before the lookup.
    if country_code == "UK":
        return country_code
    try:
        pycountry.countries.lookup(country_code)
    except LookupError as e:
        msg = "Invalid or non-existent country code, it must be in ISO 3166-1 alpha-2 format."
        raise ValueError(msg) from e
    return country_code


def validate_site_name(site_name: str) -> str:
    """Validate the site name.

    The site name must consist of exactly three uppercase letters (A-Z), optionally followed by a single digit (0-9).

    :param str site_name: The site name to validate.
    :raises ValueError: If ``site_name`` does not have the expected format.
    :return str: The site name, unchanged.
    """
    # ``fullmatch`` instead of ``match`` with ``^...$``: ``$`` also matches just before a trailing newline,
    # which would have let e.g. ``"AMS\n"`` through.
    if not re.fullmatch(r"[A-Z]{3}[0-9]?", site_name):
        msg = (
            "Enter a valid site name. It must consist of three uppercase letters (A-Z), followed by an optional single "
            f"digit (0-9). Received: {site_name}"
        )
        raise ValueError(msg)
    return site_name


class BaseSiteValidatorModel(BaseModel):
    """A base site validator model extended by create site and by import site.

    All validators use ``check_fields=False`` because some of the fields they target (``site_country_code``,
    ``site_name``) are not declared on this class; presumably the subclasses declare them.
    """

    site_bgp_community_id: int
    site_internal_id: int
    site_tier: SiteTier
    site_ts_address: str

    @validator("site_ts_address", check_fields=False, allow_reuse=True)
    def validate_ts_address(cls, site_ts_address: str) -> str:
        """Validate that a terminal server address is a valid IPv4 or IPv6 address."""
        validate_ipv4_or_ipv6(site_ts_address)
        return site_ts_address

    @validator("site_country_code", check_fields=False, allow_reuse=True)
    def country_code_must_exist(cls, country_code: str) -> str:
        """Validate that the country code exists."""
        validate_country_code(country_code)
        return country_code

    @validator(
        "site_ts_address",
        "site_internal_id",
        "site_bgp_community_id",
        "site_name",
        check_fields=False,
        allow_reuse=True,
    )
    def validate_unique_fields(cls, value: str | int, field: ModelField) -> str | int:
        """Validate that the terminal server address, internal ID, :term:`BGP` community ID and site name are unique.

        ``field`` is injected by pydantic; its name is used to look up existing subscriptions with the same value.
        """
        return validate_site_fields_is_unique(field.name, value)

    @validator("site_name", check_fields=False, allow_reuse=True)
    def site_name_must_be_valid(cls, site_name: str) -> str:
        """Validate the site name.

        The site name must consist of three uppercase letters, followed by an optional single digit.
        """
        validate_site_name(site_name)
        return site_name


def validate_interface_name_list(interface_name_list: list, vendor: str) -> list:
    """Validate that the provided interface names match the expected pattern.

    Nokia routers are not checked; the list is returned as-is. For any other vendor, the ``interface_name`` of every
    entry must be one of 'ge', 'et', 'xe' followed by a dash '-', then a digit between 0 and 9, a forward slash '/',
    another digit between 0 and 9, another forward slash '/', and end with a digit between 0 and 9.
    For example: 'xe-1/0/0'.

    :param list interface_name_list: List of interfaces to validate. Each entry must expose an ``interface_name``
        attribute (e.g. a :class:`LAGMember`).
    :param str vendor: Router vendor, used to decide whether the interface names are checked.

    :raises ValueError: If any interface name does not match the expected pattern.
    :return list: The list of interfaces, unchanged, if all names matched.
    """
    if vendor == Vendor.NOKIA:
        return interface_name_list
    # Matched with ``fullmatch``: an anchored ``match`` with ``$`` would also accept a trailing newline.
    pattern = re.compile(r"(ge|et|xe)-[0-9]/[0-9]/[0-9]")
    for interface in interface_name_list:
        if not pattern.fullmatch(interface.interface_name):
            error_msg = (
                f"Invalid interface name. The interface name should be of format: xe-1/0/0. "
                f"Got: [{interface.interface_name}]"
            )
            raise ValueError(error_msg)
    return interface_name_list


def validate_tt_number(tt_number: str) -> str:
    """Validate a string to match a specific pattern.

    This method checks that the input string is 'TT#' followed by exactly 16 ASCII digits (0-9),
    with nothing before or after.

    :param str tt_number: The TT number as string to validate

    :raises ValueError: If the TT number does not match the pattern.
    :return str: The TT number string if TT number match was successful.
    """
    # ``fullmatch`` with ``[0-9]`` rather than ``match`` with ``^TT#\d{16}$``: ``$`` also matches before a
    # trailing newline, and ``\d`` accepts any Unicode decimal digit (e.g. Arabic-Indic digits).
    if not re.fullmatch(r"TT#[0-9]{16}", tt_number):
        err_msg = (
            f"The given TT number: {tt_number} is not valid. "
            " A valid TT number starts with 'TT#' followed by 16 digits."
        )
        raise ValueError(err_msg)

    return tt_number


def generate_fqdn(hostname: str, site_name: str, country_code: str) -> str:
    """Build an :term:`FQDN` from a hostname, a site name and a country code.

    The result is ``<hostname>.<site name>.<country code>`` followed directly by ``IPAM.LO.domain_name`` from the
    OSS parameters; the site name and country code are lower-cased, the hostname is used as-is.
    NOTE(review): no separator is inserted before the domain name, so it presumably starts with a dot -- confirm.
    """
    oss_params = settings.load_oss_params()
    domain_suffix = oss_params.IPAM.LO.domain_name
    return f"{hostname}.{site_name.lower()}.{country_code.lower()}{domain_suffix}"