Skip to content
Snippets Groups Projects
gap.py 4.94 KiB
import logging
from collections.abc import Callable
from typing import Any

import requests
from requests import Session

from mapping_provider.dependencies import config_dep

# Per-module logger; records carry this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)

# Form parameters for the OAuth2 client-credentials token request (see get_token).
GRANT_TYPE: str = "client_credentials"
SCOPE: str = "openid profile email aarc"


def get_token_endpoint(
    discovery_endpoint_url: str,
    session: Session,
    *,
    timeout: float | None = 30.0,
) -> str:
    """
    Fetch the token endpoint URL from the identity provider's discovery document.

    :param discovery_endpoint_url: URL of the discovery document to GET.
    :param session: HTTP session used for the request.
    :param timeout: Seconds to wait for the server before giving up; ``None``
        waits indefinitely.
    :return: The document's ``token_endpoint`` value, converted to ``str``.
    :raises requests.HTTPError: If the server answers with an error status.
    :raises requests.Timeout: If the server does not respond within ``timeout``.
    :raises KeyError: If the document has no ``token_endpoint`` entry.
    """
    # Without an explicit timeout ``requests`` waits forever on a stalled server.
    response = session.get(discovery_endpoint_url, timeout=timeout)
    response.raise_for_status()
    data: dict[str, Any] = response.json()
    return str(data["token_endpoint"])


def get_token(config: dict[str, Any], session: Session, *, timeout: float | None = 30.0) -> str:
    """
    Retrieve an access token using the OAuth2 client-credentials flow.

    The token endpoint is looked up from ``config["aai"]["discovery_endpoint_url"]``.
    ``config["aai"]["client_id"]`` and ``config["aai"]["secret"]`` are then POSTed
    as form fields together with ``GRANT_TYPE`` and ``SCOPE``.

    :param config: Configuration mapping; only its ``aai`` section is read.
    :param session: HTTP session used for both the discovery and the token request.
    :param timeout: Seconds to wait for the token endpoint before giving up;
        ``None`` waits indefinitely. This applies only to the token POST; the
        discovery lookup is governed by ``get_token_endpoint``.
    :return: The ``access_token`` from the token response, converted to ``str``.
    :raises requests.HTTPError: If either request answers with an error status.
    :raises KeyError: If a required config key or the ``access_token`` field is missing.
    """
    aai_config = config["aai"]
    token_endpoint = get_token_endpoint(aai_config["discovery_endpoint_url"], session)
    response = session.post(
        token_endpoint,
        data={
            "grant_type": GRANT_TYPE,
            "scope": SCOPE,
            "client_id": aai_config["client_id"],
            "client_secret": aai_config["secret"],
        },
        # Without an explicit timeout ``requests`` waits forever on a stalled server.
        timeout=timeout,
    )
    response.raise_for_status()
    data: dict[str, Any] = response.json()
    return str(data["access_token"])


def make_request(
    query: str,
    token: str,
    config: dict[str, Any],
    session: Session,
    *,
    timeout: float | None = 60.0,
) -> dict[Any, Any]:
    """
    Send a GraphQL query to the orchestrator API and return the decoded JSON body.

    The query is POSTed to ``<config["orchestrator"]["url"]>/api/graphql`` with
    ``token`` as a bearer token. A trailing slash on the configured base URL is
    dropped so the endpoint path never contains ``//``.

    :param query: GraphQL query document to send.
    :param token: Access token placed in the ``Authorization`` header.
    :param config: Configuration mapping; only ``config["orchestrator"]["url"]`` is read.
    :param session: HTTP session used for the request.
    :param timeout: Seconds to wait for the server before giving up; ``None``
        waits indefinitely.
    :return: The decoded JSON response body.
    :raises requests.HTTPError: If the server answers with an error status.
    :raises requests.Timeout: If the server does not respond within ``timeout``.
    :raises ValueError: If the response body contains an ``errors`` member.
    """
    base_url = str(config["orchestrator"]["url"]).rstrip("/")
    api_url = f"{base_url}/api/graphql"
    headers = {"Authorization": f"Bearer {token}"}
    # Without an explicit timeout ``requests`` waits forever on a stalled server.
    response = session.post(api_url, headers=headers, json={"query": query}, timeout=timeout)
    response.raise_for_status()

    data: dict[str, Any] = response.json()
    if "errors" in data:
        # Lazy %-style arguments: formatting only happens if the record is emitted.
        logger.error("GraphQL query returned errors: %s. Query: %s", data["errors"], query)
        raise ValueError(f"GraphQL query errors: {data['errors']}")

    return data


def extract_router(product_block_instances: list[dict[str, Any]]) -> str | None:
    """Extract router FQDNs from productBlockInstances."""
    for instance in product_block_instances:
        for value in instance.get("productBlockInstanceValues", []):
            if value.get("field") == "routerFqdn" and value.get("value"):
                return str(value["value"])
    return None


def extract_trunk(product_block_instances: list[dict[str, Any]]) -> list[str] | None:
    """Extract trunks from productBlockInstances."""
    fqdns = []
    for instance in product_block_instances:
        for value in instance.get("productBlockInstanceValues", []):
            if value.get("field") == "routerFqdn" and value.get("value"):
                fqdns.append(value["value"])
                break

    if len(fqdns) >= 2:
        return [fqdns[0], fqdns[1]]
    return None


def _subscriptions_query(tag: str, end_cursor: Any) -> str:
    """Render the GraphQL query for one page of ``tag`` subscriptions after ``end_cursor``."""
    return f"""
           query {{
               subscriptions(
                   filterBy: {{field: "status", value: "PROVISIONING|ACTIVE"}},
                   first: 100,
                   after: {end_cursor},
                   query: "tag:({tag})"
               ) {{
                   pageInfo {{
                       hasNextPage
                       endCursor
                   }}
                   page {{
                       subscriptionId
                       product {{
                           tag
                       }}
                       productBlockInstances {{
                           productBlockInstanceValues
                       }}
                   }}
               }}
           }}
           """


def load_inventory(
    config: dict[str, Any],
    token: str,
    tag: str,
    session: Session,
    extractor: Callable[[list[dict[str, Any]]], Any | None],
) -> list[Any]:
    """
    Collect one extracted item per matching subscription, across all result pages.

    Subscriptions with status PROVISIONING or ACTIVE that match ``tag`` are
    requested 100 at a time. Paging continues while the server reports
    ``hasNextPage``, and each following request starts after the previous
    page's ``endCursor``. Every subscription's ``productBlockInstances`` list
    is passed to ``extractor``; truthy results are kept in the order received,
    falsy results are dropped.
    """
    collected: list[Any] = []
    cursor: Any = 0
    more_pages = True

    while more_pages:
        response = make_request(_subscriptions_query(tag, cursor), token, config, session)
        subscriptions = response.get("data", {}).get("subscriptions", {})
        page = subscriptions.get("page", [])
        page_info = subscriptions.get("pageInfo", {})

        for subscription in page:
            item = extractor(subscription.get("productBlockInstances", []))
            if item:
                collected.append(item)

        more_pages = page_info.get("hasNextPage", False)
        cursor = page_info.get("endCursor", 0)

    return collected


def load_routers(config: config_dep) -> list[str]:
    """
    Fetch router FQDNs (graph nodes) from the orchestrator.

    A dedicated HTTP session is opened, used to obtain an access token and to
    page through every subscription tagged ``RTR``, then closed again.
    """
    with requests.Session() as http:
        access_token = get_token(config, http)
        routers = load_inventory(
            config,
            access_token,
            tag="RTR",
            session=http,
            extractor=extract_router,
        )
    return routers


def load_trunks(config: config_dep) -> list[list[str]]:
    """
    Fetch trunks (graph edges) from the orchestrator as pairs of router FQDNs.

    A dedicated HTTP session is opened, used to obtain an access token and to
    page through every subscription tagged ``IPTRUNK``, then closed again.
    """
    with requests.Session() as http:
        access_token = get_token(config, http)
        trunks = load_inventory(
            config,
            access_token,
            tag="IPTRUNK",
            session=http,
            extractor=extract_trunk,
        )
    return trunks