# gap.py
import concurrent.futures
import logging
import socket
from typing import Dict, Optional

import requests
from requests.adapters import HTTPAdapter
from urllib3.poolmanager import PoolManager

logger = logging.getLogger(__name__)

GRANT_TYPE = 'client_credentials'
SCOPE = 'openid profile email aarc'


class IPv4Adapter(HTTPAdapter):
    """Transport adapter intended to force connections over IPv4.

    The orchestrator does not support IPv6, so as a temporary workaround every connection pool built by this adapter
    is given a single ``IPPROTO_IP``-level socket option (``IP_TOS`` set to 0) in place of whatever
    ``socket_options`` the caller supplied. Remove this adapter once the orchestrator supports IPv6.

    NOTE(review): that this option makes IPv6 connection attempts fail (so that IPv4 is used) is the original
    author's stated intent; confirm on the target platform.
    """

    def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
        """Build ``self.poolmanager`` with the IPv4-level socket option applied."""
        ipv4_kwargs = {**pool_kwargs, 'socket_options': [(socket.IPPROTO_IP, socket.IP_TOS, 0)]}
        self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize, block=block, **ipv4_kwargs)

    def proxy_manager_for(self, proxy, **proxy_kwargs):
        """Return the parent's proxy manager for ``proxy``, created with the same socket option."""
        ipv4_kwargs = {**proxy_kwargs, 'socket_options': [(socket.IPPROTO_IP, socket.IP_TOS, 0)]}
        return super().proxy_manager_for(proxy, **ipv4_kwargs)


def get_token_endpoint(discovery_endpoint_url: str, timeout: float = 30.0) -> str:
    """Fetch the discovery document and return its ``token_endpoint`` entry.

    :param discovery_endpoint_url: URL of the discovery document, fetched with an HTTP GET.
    :param timeout: seconds to wait for the server before giving up. Requests applies no timeout by default, so
        without this a stalled server would block the caller indefinitely.
    :return: the value stored under ``token_endpoint`` in the JSON response.
    :raises requests.HTTPError: if the response carries an error status code.
    :raises requests.Timeout: if the server does not answer within ``timeout`` seconds.
    :raises KeyError: if the JSON response has no ``token_endpoint`` entry.
    """
    response = requests.get(discovery_endpoint_url, timeout=timeout)
    response.raise_for_status()
    return response.json()['token_endpoint']


def get_token(aai_config: dict, timeout: float = 30.0) -> str:
    """Get an access token using the given configuration.

    The token endpoint is looked up from ``aai_config['discovery_endpoint_url']`` and then called with the
    client-credentials grant, using the ``inventory_provider`` client id and secret from the configuration.

    :param aai_config: configuration with the keys ``discovery_endpoint_url`` and ``inventory_provider``
        (the latter holding ``client_id`` and ``secret``).
    :param timeout: seconds to wait for the token endpoint before giving up. Requests applies no timeout by
        default, so without this a stalled server would block the caller indefinitely.
    :return: the ``access_token`` entry of the token endpoint's JSON response.
    :raises requests.HTTPError: if the token endpoint answers with an error status code.
    :raises requests.Timeout: if the token endpoint does not answer within ``timeout`` seconds.
    :raises KeyError: if a configuration key or the ``access_token`` entry is missing.
    """
    token_endpoint = get_token_endpoint(aai_config['discovery_endpoint_url'])
    credentials = aai_config['inventory_provider']
    response = requests.post(
        token_endpoint,
        data={
            'grant_type': GRANT_TYPE,
            'scope': SCOPE,
            'client_id': credentials['client_id'],
            'client_secret': credentials['secret'],
        },
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()['access_token']


def make_request(body: dict, token: str, app_config: dict, timeout: float = 60.0) -> Dict:
    """Make a request to the orchestrator using the given body.

    The body is POSTed as JSON to ``<orchestrator url>/api/graphql`` with the token sent as a bearer token.

    :param body: JSON-serialisable request body (the callers in this module pass ``{'query': ...}``).
    :param token: access token placed in the ``Authorization`` header.
    :param app_config: application configuration; ``app_config['orchestrator']['url']`` is the API base URL.
    :param timeout: seconds to wait for the orchestrator before giving up. Requests applies no timeout by
        default, so without this a stalled server would block the caller indefinitely.
    :return: the decoded JSON response.
    :raises requests.HTTPError: if the orchestrator answers with an error status code.
    :raises requests.Timeout: if the orchestrator does not answer within ``timeout`` seconds.
    :raises ValueError: if the decoded response contains a non-empty ``errors`` entry.
    """
    api_url = f'{app_config["orchestrator"]["url"]}/api/graphql'
    headers = {'Authorization': f'Bearer {token}'}
    # The session is used as a context manager so its connection pool is released after every call
    # instead of lingering until garbage collection.
    with requests.Session() as session:
        # Mount the adapter to force IPv4
        # This should be removed once the orchestrator supports IPv6
        # See the docstring of the IPv4Adapter class for more info
        adapter = IPv4Adapter()
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        response = session.post(api_url, headers=headers, json=body, timeout=timeout)
        response.raise_for_status()
        # Decode once and reuse the result for both the error check and the return value.
        payload = response.json()
    # The graphql API returns a 200 status code even if there are errors in the response
    errors = payload.get('errors')
    if errors:
        err_msg = f'GraphQL query returned errors: {errors}'
        logger.error(err_msg)
        raise ValueError(err_msg)
    return payload


def extract_router_info(device: dict, token: str, app_config: dict) -> Optional[dict]:
    """Look up the FQDN and vendor of a single device subscription.

    The device's product tag selects the name of the FQDN field (``routerFqdn``, ``officeRouterFqdn`` or
    ``superPopSwitchFqdn``). The subscription is then queried from the orchestrator and the FQDN and ``vendor``
    values are read from the first product block instance of the first returned subscription.

    :param device: subscription entry with a ``subscriptionId`` and a ``product`` holding a ``tag``.
    :param token: access token passed on to ``make_request``.
    :param app_config: application configuration passed on to ``make_request``.
    :return: ``{'fqdn': ..., 'vendor': ...}``, or ``None`` (after logging a warning) when the tag is unknown, the
        subscription id is missing, the orchestrator returns no data, or the FQDN or vendor is absent.
    """
    tag_to_key_map = {
        "RTR": "router",
        "OFFICE_ROUTER": "officeRouter",
        "Super_POP_SWITCH": "superPopSwitch"
    }

    # `or {}` also covers an explicit null product, which a plain .get() default would let through.
    tag = (device.get("product") or {}).get("tag")
    key = tag_to_key_map.get(tag)
    subscription_id = device.get("subscriptionId")

    if key is None or subscription_id is None:
        logger.warning(f"Skipping device with invalid tag or subscription ID: {device}")
        return None

    query = f"""
    query {{
            subscriptions(
                filterBy: {{ field: "subscriptionId", value: "{subscription_id}" }}
            ) {{
                page {{
                    subscriptionId
                    productBlockInstances {{
                        productBlockInstanceValues
                    }}
                }}
            }}
        }}
        """

    response = make_request(body={'query': query}, token=token, app_config=app_config)
    page_data = ((response.get('data') or {}).get('subscriptions') or {}).get('page')

    if not page_data:
        logger.warning(f"No data for subscription ID: {subscription_id}")
        return None

    # An empty or null list of product block instances is treated like a missing one, so such a device is
    # skipped with a warning below instead of raising IndexError/TypeError.
    block_instances = page_data[0].get('productBlockInstances') or [{}]
    instance_values = block_instances[0].get('productBlockInstanceValues') or []

    fqdn = next((item.get('value') for item in instance_values if item.get('field') == f'{key}Fqdn'), None)
    vendor = next((item.get('value') for item in instance_values if item.get('field') == 'vendor'), None)

    if fqdn and vendor:
        return {'fqdn': fqdn, 'vendor': vendor}

    logger.warning(f"Skipping device with missing FQDN or vendor: {device}")
    return None


def load_routers_from_orchestrator(app_config: dict) -> Dict:
    """Gets devices from the orchestrator and returns a dictionary of FQDNs and vendors.

    Subscriptions are fetched page by page (100 per page). For each page, the details of every device are looked
    up concurrently with ``extract_router_info``; devices for which it returns ``None`` are left out.

    :param app_config: application configuration; ``app_config['aai']`` is used to obtain the access token and
        the whole configuration is passed on to ``make_request`` and ``extract_router_info``.
    :return: mapping of router FQDN to vendor. If a page response does not have the expected structure, the
        problem is logged and the routers collected so far are returned.
    :raises Exception: any exception raised by a token request, a page request or a device lookup propagates.
    """
    token = get_token(app_config['aai'])
    routers = {}
    end_cursor = 0
    has_next_page = True

    while has_next_page:
        query = f"""
        {{
            subscriptions(
                filterBy: {{field: "status", value: "PROVISIONING|ACTIVE"}},
                first: 100,
                after: {end_cursor},
                query: "tag:(RTR|OFFICE_ROUTER|Super_POP_SWITCH)"
            ) {{
                pageInfo {{
                    hasNextPage
                    endCursor
                }}
                page {{
                    subscriptionId
                    product {{
                        tag
                    }}
                }}
            }}
        }}
        """

        response = make_request(body={'query': query}, token=token, app_config=app_config)
        try:
            subscriptions = response['data']['subscriptions']
            # A null page is treated as an empty one so the loop below does not fail on it.
            devices = subscriptions['page'] or []
            page_info = subscriptions['pageInfo']
            end_cursor = page_info['endCursor']
            has_next_page = page_info['hasNextPage']
        except (TypeError, KeyError):
            # Best effort: keep what was collected so far, but leave a trace that the result may be incomplete.
            logger.exception('Malformed subscriptions response from the orchestrator; stopping pagination')
            devices = []
            has_next_page = False

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(extract_router_info, device, token, app_config) for device in devices]
            for future in concurrent.futures.as_completed(futures):
                # result() re-raises any exception from the worker, aborting the whole load.
                router_info = future.result()
                if router_info is not None:
                    routers[router_info['fqdn']] = router_info['vendor']

    return routers