Skip to content
Snippets Groups Projects
subscriptions.py 2.98 KiB
Newer Older
from uuid import UUID

from asyncio_redis import Subscription
from orchestrator.db import (
    ProductTable,
    ResourceTypeTable,
    SubscriptionInstanceTable,
    SubscriptionInstanceValueTable,
    SubscriptionTable,
)

from gso.schemas.enums import ProductType, SubscriptionStatus


def get_active_subscriptions(
    product_type: str,
    fields: list[str],
) -> list[Subscription]:
    """Fetch selected columns of every active subscription of a given product type.

    Args:
    ----
    product_type (str): Value compared against ``ProductTable.product_type``.
    fields (list[str]): Names of ``SubscriptionTable`` attributes to select. Each name is
        resolved with ``getattr`` on ``SubscriptionTable``, so an unknown name raises
        ``AttributeError`` before any query is built.

    Returns:
    -------
    list[Subscription]: One result row per active subscription of that product type,
    holding only the requested ``fields``. Because the query goes through
    ``with_entities``, these are column rows rather than full ``SubscriptionTable``
    objects.

    NOTE(review): ``Subscription`` in the annotation is imported from ``asyncio_redis``,
    which does not match what this query returns — confirm the intended type.
    """
    # Resolve the column attributes up front, in the same order as ``fields``.
    selected_columns = tuple(getattr(SubscriptionTable, name) for name in fields)

    active_of_type = SubscriptionTable.query.join(ProductTable).filter(
        ProductTable.product_type == product_type,
        SubscriptionTable.status == SubscriptionStatus.ACTIVE,
    )
    return active_of_type.with_entities(*selected_columns).all()


def get_active_site_subscriptions(fields: list[str]) -> list[Subscription]:
    """Return the requested columns of all active site subscriptions.

    Convenience wrapper around ``get_active_subscriptions`` with the product type fixed
    to ``ProductType.SITE``.

    Args:
    ----
    fields (list[str]): Names of ``SubscriptionTable`` attributes to select.

    Returns:
    -------
    list[Subscription]: Rows for the active site subscriptions, limited to ``fields``.
    """
    site_type = ProductType.SITE
    return get_active_subscriptions(product_type=site_type, fields=fields)


def get_active_router_subscriptions(fields: list[str]) -> list[Subscription]:
    """Return the requested columns of all active router subscriptions.

    Convenience wrapper around ``get_active_subscriptions`` with the product type fixed
    to ``ProductType.ROUTER``.

    Args:
    ----
    fields (list[str]): Names of ``SubscriptionTable`` attributes to select.

    Returns:
    -------
    list[Subscription]: Rows for the active router subscriptions, limited to ``fields``.
    """
    router_type = ProductType.ROUTER
    return get_active_subscriptions(router_type, fields)


def get_product_id_by_name(product_name: ProductType) -> UUID:
    """Retrieve the {term}`UUID` of a product by its name.

    Args:
    ----
    product_name (ProductType): The name of the product, matched against ``ProductTable.name``.

    Returns:
    -------
    UUID: The ``product_id`` of the first product whose name matches ``product_name``.

    Raises:
    ------
    ValueError: If no product with the given name exists.
    """
    product = ProductTable.query.filter_by(name=product_name).first()
    if product is None:
        # ``.first()`` returns None when nothing matches; without this check the caller
        # would get an opaque "'NoneType' object has no attribute 'product_id'" error.
        msg = f"No product found with name {product_name!r}"
        raise ValueError(msg)
    return product.product_id


def get_active_site_subscription_by_name(site_name: str) -> Subscription:
    """Retrieve an active subscription for a site by the site's name.

    A subscription matches when it is active and one of its subscription instance values
    has resource type ``"site_name"`` and a value equal to ``site_name``.

    Args:
    ----
    site_name (str): The name of the site for which to retrieve the subscription.

    Returns:
    -------
    Subscription: The first matching ``SubscriptionTable`` row, or ``None`` if no active
    subscription carries this site name.

    NOTE(review): ``Subscription`` in the annotation is imported from ``asyncio_redis``,
    not the ORM row type this query returns — confirm the intended type. No ordering is
    applied before ``.first()``, so if several subscriptions match, which one is returned
    is up to the database.
    """
    # One ``.join()`` per target: passing several targets to a single ``Query.join()``
    # is the legacy chained form, deprecated in SQLAlchemy 1.4 and removed in 2.0
    # (where the second positional argument is interpreted as the ON clause).
    return (
        SubscriptionTable.query.join(ProductTable)
        .join(SubscriptionInstanceTable)
        .join(SubscriptionInstanceValueTable)
        .join(ResourceTypeTable)
        .filter(
            SubscriptionInstanceValueTable.value == site_name,
            ResourceTypeTable.resource_type == "site_name",
            SubscriptionTable.status == SubscriptionStatus.ACTIVE,
        )
        .first()
    )