Skip to content
Snippets Groups Projects
cli.py 9.20 KiB
import json
import logging.config
import os
import socket
from datetime import datetime
from typing import Iterable, List, Sequence

import click
import jsonschema
from brian_polling_manager.interface_stats import vendors
from brian_polling_manager.interface_stats.services import (
    write_points_to_influx,
    write_points_to_stdout,
)
from brian_polling_manager.interface_stats.vendors import Vendor, juniper, nokia
from brian_polling_manager.inventory import load_interfaces

from . import config

# Name the logger after the dotted module path (not the file path) so that,
# when imported as part of the brian_polling_manager package, it sits under
# the "brian_polling_manager" logger configured below / via LOGGING_CONFIG.
logger = logging.getLogger(__name__)

# Fallback logging configuration used by setup_logging() when the
# LOGGING_CONFIG environment variable is not set: records at INFO and above
# are written to stdout as "<time> - <level> - <message>".
LOGGING_DEFAULT_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {"simple": {"format": "%(asctime)s - %(levelname)s - %(message)s"}},
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "INFO",
            "formatter": "simple",
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        # propagate=False: package records are emitted once, here, rather
        # than a second time by the root logger's console handler.
        "brian_polling_manager": {
            "level": "INFO",
            "handlers": ["console"],
            "propagate": False,
        }
    },
    "root": {"level": "INFO", "handlers": ["console"]},
}

# TODO: (smell) this makes the methods that use it stateful/non-functional (ER)
# Application config installed by set_app_params(); read by write_points()
# and get_netconf().
_APP_CONFIG_PARAMS = {}


def set_app_params(params: dict):
    """Install ``params`` as the module-wide application configuration.

    The dict is stored by reference (not copied) in ``_APP_CONFIG_PARAMS``,
    which ``write_points`` and ``get_netconf`` read on every call.
    """
    global _APP_CONFIG_PARAMS
    _APP_CONFIG_PARAMS = params


def setup_logging():
    """
    Configure logging for the CLI.

    First, the ncclient ssh-transport and rpc loggers get a filter that
    demotes their INFO records to DEBUG. Then, if the LOGGING_CONFIG
    environment variable is set, the JSON file it names is loaded and passed
    to ``logging.config.dictConfig``; otherwise LOGGING_DEFAULT_CONFIG is used.
    """

    def _demote_info_to_debug(record):
        # Rewrites the level in place and keeps the record (truthy return);
        # handlers whose level is INFO or higher will then skip it.
        if record.levelno == logging.INFO:
            record.levelno = logging.DEBUG
            record.levelname = "DEBUG"
        return record

    for chatty_name in ("ncclient.transport.ssh", "ncclient.operations.rpc"):
        logging.getLogger(chatty_name).addFilter(_demote_info_to_debug)

    chosen_config = LOGGING_DEFAULT_CONFIG
    config_path = os.environ.get("LOGGING_CONFIG")
    if config_path is not None:
        with open(config_path) as config_file:
            chosen_config = json.load(config_file)

    logging.config.dictConfig(chosen_config)


def write_points(points: Iterable[dict], influx_params: dict, **kwargs):
    """Send ``points`` to the sink selected by the application config.

    The ``testing`` section of the config installed via ``set_app_params``
    decides where the points go:

    * ``testing["no-out"]`` truthy  -> points are discarded, returns ``None``
    * ``testing["dry_run"]`` truthy -> ``write_points_to_stdout``
    * otherwise                     -> ``write_points_to_influx``

    ``influx_params`` and any extra keyword arguments are forwarded to the
    chosen writer, and its return value is returned.
    """
    testing_opts = _APP_CONFIG_PARAMS.get("testing", {})
    # NOTE(review): "no-out" is hyphenated while "dry_run" uses an
    # underscore; confirm both spellings match the config schema.
    if testing_opts.get("no-out"):
        return None

    writer = (
        write_points_to_stdout
        if testing_opts.get("dry_run")
        else write_points_to_influx
    )
    return writer(points, influx_params=influx_params, **kwargs)


def get_netconf(router_name, vendor=Vendor.JUNIPER, **kwargs):
    """Return the netconf interface-info document for ``router_name``.

    If ``testing["netconf-source-dir"]`` is set in the application config,
    the document is read from that directory (``kwargs`` are not used on
    this path). Otherwise it is fetched from the router using the ssh params
    stored under ``vendor.value`` in the application config.

    ``Vendor.JUNIPER`` uses the ``juniper`` module; every other vendor uses
    the ``nokia`` module.

    :raises KeyError: if no source dir is configured and the application
        config has no entry for ``vendor.value``
    """
    source_dir = _APP_CONFIG_PARAMS.get("testing", {}).get("netconf-source-dir")
    vendor_module = juniper if vendor == Vendor.JUNIPER else nokia

    if source_dir:
        return vendor_module.get_netconf_interface_info_from_source_dir(
            router_name, source_dir
        )

    ssh_params = _APP_CONFIG_PARAMS[vendor.value]
    return vendor_module.get_netconf_interface_info(router_name, ssh_params, **kwargs)


def validate_router_hosts(
    hostnames: List[str],
    vendor: Vendor,
    inprov_hosts=None,
    load_interfaces_=load_interfaces,
):
    """Check that every hostname is a known router of ``vendor``.

    Validation is skipped when ``inprov_hosts`` is ``None``. Otherwise the
    allowed set is the ``"router"`` value of every interface returned by
    ``load_interfaces_(inprov_hosts)``. For ``Vendor.NOKIA``, the allowed set
    is empty.

    :param hostnames: router fqdns to check
    :param vendor: vendor the routers are claimed to be
    :param inprov_hosts: inventory provider hosts, or ``None`` to skip
    :param load_interfaces_: loader for the inventory interfaces (injectable)
    :returns: ``True`` if validation passed or was skipped
    :raises ValueError: listing, in sorted order, every hostname that is not
        in the allowed set
    """
    if inprov_hosts is None:
        return True

    logger.info(
        f"Validating hosts {' '.join(hostnames)} using providers {inprov_hosts}"
    )
    if vendor == Vendor.NOKIA:
        # NOTE(review): no Nokia router is ever accepted here, so any Nokia
        # hostname fails validation whenever an inventory is configured —
        # confirm this is intended while Nokia processing is a stub.
        all_fqdns = set()
    else:
        all_fqdns = {ifc["router"] for ifc in load_interfaces_(inprov_hosts)}

    # Sorted so the error message is reproducible: set iteration order of
    # strings varies between runs with hash randomization.
    extra_fqdns = sorted(set(hostnames) - all_fqdns)
    if extra_fqdns:
        raise ValueError(
            f"Routers are not in inventory provider or not {vendor.value}: "
            f"{' '.join(extra_fqdns)}"
        )
    return True


def process_router(
    router_fqdn: str,
    vendor: Vendor,
    all_influx_params: dict,
):
    """Hand ``router_fqdn`` to the vendor-specific processing function.

    ``Vendor.JUNIPER`` goes to ``process_juniper_router``; any other vendor
    goes to ``process_nokia_router``. That function's return value is
    returned unchanged.
    """
    vendor_processor = (
        process_juniper_router
        if vendor == Vendor.JUNIPER
        else process_nokia_router
    )
    return vendor_processor(router_fqdn, all_influx_params)


def process_juniper_router(
    router_fqdn: str,
    all_influx_params: dict,
):
    """Poll one Juniper router and write its BRIAN and error counter points.

    The netconf document is fetched once, and both point sets share a single
    timestamp. Each set uses its own entry of ``all_influx_params``:
    ``"brian-counters"`` first, then ``"error-counters"``. That entry supplies
    the ``"measurement"`` name and is passed to ``write_points``.
    """
    logger.info(f"Processing Juniper router {router_fqdn}")

    netconf_doc = get_netconf(router_fqdn, vendor=Vendor.JUNIPER)
    polled_at = datetime.now()

    # (influx params key, progress message, point generator, log label)
    point_sets = (
        (
            "brian-counters",
            "Processing Brian points...",
            _juniper_brian_points,
            "",
        ),
        (
            "error-counters",
            "Processing Error points...",
            _juniper_error_points,
            "error",
        ),
    )
    for params_key, progress_msg, make_points, point_kind in point_sets:
        influx_params = all_influx_params[params_key]
        logger.info(progress_msg)
        points = list(
            make_points(
                router_fqdn=router_fqdn,
                netconf_doc=netconf_doc,
                timestamp=polled_at,
                measurement_name=influx_params["measurement"],
            )
        )
        _log_interface_points_sorted(points, point_kind=point_kind)
        write_points(points, influx_params=influx_params)


def _log_interface_points_sorted(points: Sequence[dict], point_kind=""):
    """Log how many interfaces produced points, then their sorted names.

    Interface names come from ``point["tags"]["interface_name"]``. They are
    logged in rows of ``N_COLUMNS`` left-aligned, equal-width columns.

    :param points: the points produced for one router
    :param point_kind: optional label for the summary line, e.g. ``"error"``
        gives "Found error points for N interfaces"
    """
    N_COLUMNS = 5
    num_points = len(points)
    kind_prefix = point_kind + " " if point_kind else ""
    colon = ":" if num_points else ""
    logger.info(f"Found {kind_prefix}points for {num_points} interfaces{colon}")

    if not points:
        return

    interfaces = sorted(p["tags"]["interface_name"] for p in points)
    width = max(len(name) for name in interfaces)
    # Stepping by N_COLUMNS gives exactly ceil(len / N_COLUMNS) rows, so no
    # empty trailing row is logged when len is a multiple of N_COLUMNS.
    for start in range(0, len(interfaces), N_COLUMNS):
        row = interfaces[start : start + N_COLUMNS]
        logger.info("    ".join(name.ljust(width) for name in row))


def _juniper_brian_points(router_fqdn, netconf_doc, timestamp, measurement_name):
    """Yield BRIAN counter points for the interfaces in ``netconf_doc``.

    Counters are extracted with ``juniper.interface_counters`` and converted
    by ``vendors.brian_points``. As this is a generator, no work happens
    until it is first iterated.
    """
    counters = juniper.interface_counters(netconf_doc)
    yield from vendors.brian_points(router_fqdn, counters, timestamp, measurement_name)


def _juniper_error_points(router_fqdn, netconf_doc, timestamp, measurement_name):
    """Yield error counter points for the interfaces in ``netconf_doc``.

    Counters are extracted with ``juniper.interface_counters`` and converted
    by ``vendors.error_points``. As this is a generator, no work happens
    until it is first iterated.
    """
    counters = juniper.interface_counters(netconf_doc)
    yield from vendors.error_points(router_fqdn, counters, timestamp, measurement_name)


def process_nokia_router(
    router_fqdn: str,
    all_influx_params: dict,
):
    """Placeholder for Nokia processing: logs a warning and returns None.

    ``all_influx_params`` is accepted (and ignored) so the signature matches
    ``process_juniper_router``.
    """
    skip_message = f"skipping Nokia router {router_fqdn}"
    logger.warning(skip_message)


def main(
    app_config_params: dict,
    router_fqdns: List[str],
    vendor: Vendor,
    raise_errors=False,
):
    """Process every router in ``router_fqdns``; return how many failed.

    The hostnames are first checked with ``validate_router_hosts`` against
    ``app_config_params["inventory"]`` (skipped if absent). Non-empty ssh
    params are then required under ``app_config_params[vendor.value]``.
    Finally ``process_router`` runs for each router with
    ``app_config_params["influx"]``.

    An exception from one router is logged and counted, and the loop moves
    on. If ``raise_errors`` is true, the exception is re-raised instead.

    NOTE(review): the per-router helpers (e.g. ``get_netconf``,
    ``write_points``) read the module-level config installed by
    ``set_app_params`` rather than ``app_config_params``; callers presumably
    must call ``set_app_params`` first — confirm.

    :returns: number of routers whose processing raised
    :raises ValueError: from ``validate_router_hosts``, or when the vendor's
        ssh params are missing or empty
    """
    vendor_name = vendor.value

    validate_router_hosts(
        router_fqdns,
        vendor=vendor,
        inprov_hosts=app_config_params.get("inventory"),
    )

    if not app_config_params.get(vendor_name):
        raise ValueError(f"'{vendor_name}' ssh params are required")

    failures = 0
    for fqdn in router_fqdns:
        try:
            process_router(
                router_fqdn=fqdn,
                vendor=vendor,
                all_influx_params=app_config_params["influx"],
            )
        except Exception as exc:
            logger.exception(
                f"Error while processing {vendor_name} {fqdn}", exc_info=exc
            )
            failures += 1
            if raise_errors:
                raise

    return failures


def validate_config(_unused_ctx, _unused_param, file):
    """Click callback for ``--config``: load and validate the config file.

    :param file: the opened config file supplied by click
    :returns: whatever ``config.load(file)`` returns
    :raises click.BadParameter: if the file is not valid JSON or fails
        schema validation; the original error is chained as the cause
    """
    try:
        return config.load(file)
    except json.JSONDecodeError as e:
        raise click.BadParameter("config file is not valid json") from e
    except jsonschema.ValidationError as e:
        # ClickException stores the message and returns it from __str__, so
        # it must be a str; passing the exception object made str() fail.
        raise click.BadParameter(str(e)) from e


def validate_hostname(_unused_ctx, _unused_param, hostname_or_names):
    """Click callback: check that each given hostname resolves.

    :param hostname_or_names: a single hostname, or a list/tuple of them
        (click passes a tuple for ``nargs=-1`` arguments)
    :returns: ``hostname_or_names`` unchanged
    :raises click.BadParameter: for the first hostname that cannot be
        resolved, with the resolver error chained as the cause
    """
    hostnames = (
        hostname_or_names
        if isinstance(hostname_or_names, (list, tuple))
        else [hostname_or_names]
    )
    for _h in hostnames:
        try:
            # getaddrinfo, unlike gethostbyname, also resolves IPv6-only names.
            socket.getaddrinfo(_h, None)
        except (OSError, UnicodeError) as exc:
            # OSError covers socket.gaierror; UnicodeError is raised by the
            # idna encoding step for malformed names such as "a..b".
            raise click.BadParameter(f"{_h} is not resolveable") from exc
    return hostname_or_names


# Command-line entry point. No docstring on purpose: click would turn it
# into the command's --help text.
@click.command()
@click.option(
    "--config",
    "app_config_params",
    required=True,
    type=click.File("r"),
    help="config filename",
    callback=validate_config,
)
@click.option(
    "--juniper",
    is_flag=True,
    help="The given router fqdns are juniper routers",
)
@click.option(
    "--nokia",
    is_flag=True,
    help="The given router fqdns are nokia routers",
)
@click.argument("router-fqdn", nargs=-1, callback=validate_hostname)
def cli(app_config_params: dict, juniper: bool, nokia: bool, router_fqdn):
    # No routers given: nothing to do (and no vendor flag is required).
    if not router_fqdn:
        return

    # Both flags are booleans, so equality means "neither" or "both".
    if juniper == nokia:
        raise click.BadParameter("Set either '--juniper' or '--nokia', but not both")

    set_app_params(app_config_params)
    setup_logging()

    failures = main(
        app_config_params=app_config_params,
        router_fqdns=router_fqdn,
        vendor=Vendor.JUNIPER if juniper else Vendor.NOKIA,
    )
    if failures:
        raise click.ClickException(
            "Errors were encountered while processing interface stats"
        )


if __name__ == "__main__":
    cli()