Skip to content
Snippets Groups Projects
cli.py 8.13 KiB
import json
import logging.config
import os
import socket
from datetime import datetime
from typing import Iterable, List

import click
import jsonschema
from brian_polling_manager.interface_stats import vendors
from brian_polling_manager.interface_stats.services import (
    get_juniper_netconf,
    get_netconf_from_source_dir,
    write_points_to_influx,
    write_points_to_stdout,
)
from brian_polling_manager.interface_stats.vendors import Vendor, juniper
from brian_polling_manager.inventory import load_interfaces

from . import config

# Use the dotted module name (not ``__file__``, which is a filesystem path) so
# that, when imported as part of the brian_polling_manager package, records
# from this module go through the "brian_polling_manager" logger configured
# below (or in a LOGGING_CONFIG file) instead of falling straight to root.
logger = logging.getLogger(__name__)

# Used by setup_logging() when LOGGING_CONFIG is not set in the environment.
LOGGING_DEFAULT_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "simple": {
            "format": "%(asctime)s - %(name)s "
            "(%(lineno)d) - %(levelname)s - %(message)s"
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "simple",
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        "brian_polling_manager": {
            "level": "DEBUG",
            "handlers": ["console"],
            # the console handler is attached here, so don't also emit via root
            "propagate": False,
        }
    },
    "root": {"level": "INFO", "handlers": ["console"]},
}

# Application config installed by set_app_params(); read by write_points()
# and get_netconf().
_APP_CONFIG_PARAMS = {}


def set_app_params(params: dict):
    """
    Replace the module-level application config with *params*.

    The dict is stored by reference (it is not copied); write_points() and
    get_netconf() read their settings from it.
    """
    global _APP_CONFIG_PARAMS
    _APP_CONFIG_PARAMS = params


def setup_logging():
    """
    Configure the logging package for this CLI.

    If the LOGGING_CONFIG environment variable is set to a non-empty value,
    it is treated as the path of a JSON file holding a
    logging.config.dictConfig dictionary. Otherwise LOGGING_DEFAULT_CONFIG
    is used.

    :raises OSError: if the file named by LOGGING_CONFIG cannot be opened
    :raises json.JSONDecodeError: if that file is not valid JSON
    """
    logging_config = LOGGING_DEFAULT_CONFIG
    # An empty value is treated like "unset" rather than crashing on open("").
    filename = os.environ.get("LOGGING_CONFIG")
    if filename:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(filename, encoding="utf-8") as f:
            logging_config = json.load(f)

    logging.config.dictConfig(logging_config)


def write_points(points: Iterable[dict], influx_params: dict, **kwargs):
    """
    Send *points* to the configured sink.

    The sink is chosen from the "testing" section of the app config:

    - "no-out" truthy: discard the points and return None
    - "dry_run" truthy: delegate to write_points_to_stdout
    - otherwise: delegate to write_points_to_influx

    *influx_params* and any extra keyword arguments are forwarded unchanged,
    and the delegate's return value is returned.
    """
    testing = _APP_CONFIG_PARAMS.get("testing", {})
    if testing.get("no-out"):
        return None

    writer = write_points_to_stdout if testing.get("dry_run") else write_points_to_influx
    return writer(points, influx_params=influx_params, **kwargs)


def get_netconf(router_name, vendor=Vendor.JUNIPER, **kwargs):
    """
    Return the netconf document for *router_name*.

    If the app config has "testing" -> "netconf-source-dir" set, the document
    is read from that directory by get_netconf_from_source_dir. In that case
    *vendor* and *kwargs* are not used.

    Otherwise the router is queried live with get_juniper_netconf, using the
    ssh params stored under vendor.value in the app config. *kwargs* are
    forwarded to that call.

    :raises NotImplementedError: for a live query with a non-Juniper vendor,
        because only a Juniper netconf client is used here
    :raises KeyError: if the app config has no entry for vendor.value
    """
    source_dir = _APP_CONFIG_PARAMS.get("testing", {}).get("netconf-source-dir")
    if source_dir:
        return get_netconf_from_source_dir(router_name, source_dir)

    if vendor != Vendor.JUNIPER:
        # Refuse instead of opening a Juniper netconf session to another
        # vendor's device with that vendor's ssh params.
        raise NotImplementedError(
            f"live netconf retrieval is not supported for {vendor.value}"
        )

    ssh_params = _APP_CONFIG_PARAMS[vendor.value]
    return get_juniper_netconf(router_name, ssh_params, **kwargs)


def validate_router_hosts(
    hostnames, vendor: Vendor, inprov_hosts=None, load_interfaces_=load_interfaces
):
    """
    Check that every hostname is a known router of the given vendor.

    If *inprov_hosts* is None, validation is skipped and True is returned.
    Otherwise the known router fqdns are the "router" values of the interfaces
    returned by load_interfaces_(inprov_hosts).

    For Vendor.NOKIA no inventory lookup is done and the known set is empty,
    so any non-empty set of Nokia hostnames fails validation.

    :param hostnames: iterable of router fqdns (any iterable, including a
        one-shot iterator, is accepted)
    :param vendor: vendor the hostnames are claimed to belong to
    :param inprov_hosts: inventory provider hosts, or None to skip validation
    :param load_interfaces_: interface loader; injectable for testing
    :return: True when validation passes or is skipped
    :raises ValueError: listing (sorted) the hostnames not found
    """
    if inprov_hosts is None:
        return True

    # Materialise once: both the log line and the set difference iterate the
    # hostnames, and an iterator would be exhausted by the first pass,
    # making validation pass vacuously.
    hostnames = list(hostnames)
    logger.info(
        f"Validating hosts {' '.join(hostnames)} using providers {inprov_hosts}"
    )
    if vendor == Vendor.NOKIA:
        known_fqdns = set()
    else:
        known_fqdns = {ifc["router"] for ifc in load_interfaces_(inprov_hosts)}

    extra_fqdns = set(hostnames) - known_fqdns
    if extra_fqdns:
        # sorted() keeps the message stable; set order varies with hash seeding.
        raise ValueError(
            f"Routers are not in inventory provider or not {vendor.value}: "
            f"{' '.join(sorted(extra_fqdns))}"
        )
    return True


def process_router(
    router_fqdn: str,
    vendor: Vendor,
    all_influx_params: dict,
):
    """
    Hand *router_fqdn* to the vendor-specific processor.

    Vendor.JUNIPER goes to process_juniper_router. Every other vendor value
    goes to process_nokia_router. The processor's return value is passed
    through.
    """
    handler = (
        process_juniper_router if vendor == Vendor.JUNIPER else process_nokia_router
    )
    return handler(router_fqdn, all_influx_params)


def process_juniper_router(
    router_fqdn: str,
    all_influx_params: dict,
):
    """
    Poll one Juniper router and write its counter points.

    The netconf document is fetched once. A single timestamp is taken right
    after the fetch and shared by both point sets. Brian-counter points are
    written using all_influx_params["brian-counters"], then error-counter
    points using all_influx_params["error-counters"]. Each entry's
    "measurement" key names the measurement.
    """
    logger.info(f"processing Juniper router {router_fqdn}")

    netconf_doc = get_netconf(router_fqdn, vendor=Vendor.JUNIPER)
    # NOTE(review): datetime.now() is naive local time; confirm how
    # vendors.*_points / the influx writer interpret naive timestamps.
    polled_at = datetime.now()

    targets = (
        ("brian-counters", _juniper_brian_points),
        ("error-counters", _juniper_error_points),
    )
    for influx_key, build_points in targets:
        target_params = all_influx_params[influx_key]
        write_points(
            build_points(
                router_fqdn=router_fqdn,
                netconf_doc=netconf_doc,
                timestamp=polled_at,
                measurement_name=target_params["measurement"],
            ),
            influx_params=target_params,
        )


def _juniper_brian_points(router_fqdn, netconf_doc, timestamp, measurement_name):
    """
    Lazily yield brian-counter points for *router_fqdn*.

    Physical-interface points are yielded first, then logical-interface
    points. All points share *timestamp* and *measurement_name*.
    """
    extractors = (
        juniper.physical_interface_counters,
        juniper.logical_interface_counters,
    )
    for extract_counters in extractors:
        yield from vendors.brian_points(
            router_fqdn, extract_counters(netconf_doc), timestamp, measurement_name
        )


def _juniper_error_points(router_fqdn, netconf_doc, timestamp, measurement_name):
    """
    Lazily yield error-counter points for *router_fqdn*.

    Physical-interface points are yielded first, then logical-interface
    points. All points share *timestamp* and *measurement_name*.
    """
    # [2024-03-21] We currently have no definition for error points on logical
    # interfaces, so the logical pass is essentially a no-op. It is kept so a
    # future definition of errors on logical interfaces is picked up here.
    extractors = (
        juniper.physical_interface_counters,
        juniper.logical_interface_counters,
    )
    for extract_counters in extractors:
        yield from vendors.error_points(
            router_fqdn, extract_counters(netconf_doc), timestamp, measurement_name
        )


def process_nokia_router(
    router_fqdn: str,
    all_influx_params: dict,
):
    """
    Placeholder for Nokia routers: log a warning and write nothing.

    *all_influx_params* is unused. It is accepted so this function can be
    called the same way as process_juniper_router (process_router calls both
    with the same arguments).
    """
    logger.warning(f"skipping Nokia router {router_fqdn}")


def main(
    app_config_params: dict,
    router_fqdns: List[str],
    vendor: Vendor,
    raise_errors=False,
):
    """
    Validate the requested routers, then process each one.

    :param app_config_params: loaded application config; the "inventory",
        vendor.value and "influx" entries are read here
    :param router_fqdns: routers to process
    :param vendor: vendor of all the given routers
    :param raise_errors: re-raise the first per-router exception instead of
        logging it and continuing
    :return: number of routers whose processing raised an exception
    :raises ValueError: if host validation fails or the vendor's ssh params
        are missing/empty
    """
    vendor_str = vendor.value
    inprov_hosts = app_config_params.get("inventory")

    validate_router_hosts(router_fqdns, vendor=vendor, inprov_hosts=inprov_hosts)

    ssh_params = app_config_params.get(vendor_str)
    if not ssh_params:
        raise ValueError(f"'{vendor_str}' ssh params are required")

    error_count = 0
    for router in router_fqdns:
        try:
            process_router(
                router_fqdn=router,
                vendor=vendor,
                all_influx_params=app_config_params["influx"],
            )
        except Exception:
            # Per-router boundary: one failing router must not stop the
            # others, but the failure has to be counted so the caller
            # (cli) can report it. Previously the count was never
            # incremented, so errors were silently swallowed.
            logger.exception(f"Error while processing {vendor_str} {router}")
            error_count += 1
            if raise_errors:
                raise

    return error_count


def validate_config(_unused_ctx, _unused_param, file):
    """
    Click callback for --config: parse the file with config.load.

    :param file: open text file handle supplied by click
    :return: the value returned by config.load
    :raises click.BadParameter: if the file is not valid JSON or fails
        jsonschema validation
    """
    try:
        loaded_params = config.load(file)
    except json.JSONDecodeError:
        raise click.BadParameter("config file is not valid json")
    except jsonschema.ValidationError as validation_error:
        raise click.BadParameter(validation_error)
    return loaded_params


def validate_hostname(_unused_ctx, _unused_param, hostname_or_names):
    """
    Click callback: make sure every given hostname resolves.

    :param hostname_or_names: a single hostname, or a list/tuple of them
        (click passes a tuple for nargs=-1 arguments)
    :return: *hostname_or_names*, unchanged
    :raises click.BadParameter: for the first hostname that cannot be resolved
    """
    hostnames = (
        hostname_or_names
        if isinstance(hostname_or_names, (list, tuple))
        else [hostname_or_names]
    )
    for hostname in hostnames:
        try:
            # getaddrinfo also accepts IPv6-only names and addresses, whereas
            # gethostbyname only handles IPv4.
            socket.getaddrinfo(hostname, None)
        except (OSError, UnicodeError) as err:
            # UnicodeError: the IDNA codec rejects malformed names (e.g. an
            # empty label as in "a..b") before any lookup happens; it is not
            # an OSError and previously escaped as a crash.
            raise click.BadParameter(f"{hostname} is not resolveable") from err
    return hostname_or_names


@click.command()
@click.option(
    "--config",
    "app_config_params",
    required=True,
    type=click.File("r"),
    help="config filename",
    callback=validate_config,
)
@click.option(
    "--juniper",
    is_flag=True,
    help="The given router fqdns are juniper routers",
)
@click.option(
    "--nokia",
    is_flag=True,
    help="The given router fqdns are nokia routers",
)
@click.argument("router-fqdn", nargs=-1, callback=validate_hostname)
def cli(
    app_config_params: dict,
    juniper: bool,
    nokia: bool,
    router_fqdn,
):
    # (No docstring on purpose: click would show it as the --help text.)
    # With no routers there is nothing to do; exit before any setup.
    if not router_fqdn:
        return
    # Exactly one vendor flag is required: both-set and neither-set are errors.
    if juniper == nokia:
        raise click.BadParameter("Set either '--juniper' or '--nokia', but not both")

    vendor = Vendor.NOKIA if nokia else Vendor.JUNIPER

    set_app_params(app_config_params)
    setup_logging()

    failures = main(
        app_config_params=app_config_params,
        router_fqdns=router_fqdn,
        vendor=vendor,
    )
    if failures:
        raise click.ClickException(
            "Errors were encountered while processing interface stats"
        )


if __name__ == "__main__":
    # Running the file directly hands sys.argv to the click command.
    cli()