Skip to content
Snippets Groups Projects
cli.py 9.74 KiB
import enum
import json
from logging import LogRecord
import logging.config
import sys
from datetime import datetime
from typing import Any, Iterable, List, Optional, Collection

import click
import jsonschema
from brian_polling_manager.influx import influx_client
from brian_polling_manager.interface_stats import vendors, config
from brian_polling_manager.interface_stats.vendors import Vendor
from brian_polling_manager.inventory import (
    INVENTORY_INTERFACES_SCHEMA,
    load_inventory_json,
)
from lxml import etree

# Logger used throughout this module. NOTE(review): this is the *root* logger
# (no name passed), not a per-module logger; possibly deliberate so that a
# logging.config setup elsewhere cannot disable it — confirm before switching
# to logging.getLogger(__name__).
logger = logging.getLogger()

# Inventory-provider endpoint used when a point group's config section has no
# "inventory-url" entry (see load_interfaces).
DEFAULT_INTERFACES_URL = "/poller/interfaces/"


class PointGroup(enum.Enum):
    """Families of points collected from each router.

    Each member's value is a ``(label, config_key, points_fn)`` triple:

    * ``label`` -- short name used in log messages (what ``str()`` returns)
    * ``config_key`` -- key of this group's section in the application config
    * ``points_fn`` -- callable that turns parsed interface counters into points
    """

    BRIAN = ("brian", "brian-counters", vendors.brian_points)
    ERRORS = ("error", "error-counters", vendors.error_points)

    def config_params(self, app_params: dict):
        """Return this group's section of the application config."""
        _label, section_key, _points_fn = self.value
        return app_params[section_key]

    @property
    def points(self):
        """The callable that builds this group's points from interface counters."""
        *_, points_fn = self.value
        return points_fn

    def __str__(self):
        label, *_ = self.value
        return label


def write_points_to_influx(
    points: Iterable[dict],
    influx_params: dict,
    timeout=5,
    batch_size=50,
):
    """Write *points* to InfluxDB through ``influx_client``.

    :param points: the points to write
    :param influx_params: parameters for ``influx_client``; a ``"timeout"`` key
        here takes precedence over the *timeout* argument
    :param timeout: client timeout used when *influx_params* does not set one
    :param batch_size: forwarded to the client's ``write_points``
    """
    client_params = dict(timeout=timeout)
    # Later keys win, so explicit influx_params override the default timeout.
    client_params.update(influx_params)
    client = influx_client(client_params)
    with client:
        client.write_points(points, batch_size=batch_size)


def write_points_to_stdout(points, influx_params, stream=None, **_):
    """Print each point as ``<measurement> - <json>`` on its own line.

    :param points: iterable of JSON-serialisable points
    :param influx_params: must contain ``"measurement"`` (used as line prefix)
        whenever at least one point is written
    :param stream: text stream to write to. Defaults to ``sys.stdout`` looked
        up at call time, so redirected/captured stdout is honoured (a default
        of ``sys.stdout`` in the signature would be frozen at import time).
    :param _: extra keyword arguments (e.g. ``timeout``) are accepted and
        ignored so every OutputMethod writer can be called the same way
    """
    if stream is None:
        stream = sys.stdout
    for point in points:
        stream.write(f"{influx_params['measurement']} - {json.dumps(point)}\n")
    stream.flush()


class OutputMethod(enum.Enum):
    """Destinations for generated points.

    Each member's value is a ``(name, writer)`` pair: ``name`` is the string
    accepted on the command line and ``writer`` is invoked as
    ``writer(points, influx_params=..., **kwargs)``.
    """

    INFLUX = ("influx", write_points_to_influx)
    STDOUT = ("stdout", write_points_to_stdout)
    NO_OUT = ("no-out", lambda *_, **__: None)  # accept anything, write nothing

    def write_points(self, points: Iterable[dict], influx_params: dict, **kwargs):
        """Hand *points* to this member's writer and return its result."""
        _name, writer = self.value
        return writer(points, influx_params=influx_params, **kwargs)

    @classmethod
    def from_string(cls, method: str):
        """Return the member whose command-line name is *method*.

        :raises KeyError: if *method* is not a known name
        """
        by_name = {str(member): member for member in cls}
        return by_name[method]

    def __str__(self):
        name, _writer = self.value
        return name


class MessageCounter(logging.NullHandler):
    """A handler that emits nothing and only counts the records it receives.

    The logger compares each record against this handler's level before
    calling ``handle``, so with ``level=logging.ERROR`` the counter tallies
    ERROR-and-above records. ``handle`` is overridden wholesale, so filters
    added to this handler are not consulted.
    """

    def __init__(self, level=logging.NOTSET) -> None:
        super(MessageCounter, self).__init__(level=level)
        self.count = 0  # number of records handled so far

    def handle(self, record: LogRecord) -> None:
        self.count = self.count + 1


def setup_logging(debug=False) -> MessageCounter:
    """
    Configure root logging for the CLI and quieten noisy ncclient loggers.

    Log lines go to stdout formatted as ``<time> - <LEVEL> - <message>``.
    INFO records logged on ``ncclient.operations.rpc`` / ``ncclient.transport.tls``
    are rewritten as DEBUG, so they only reach stdout when ``debug`` is set;
    records logged on ``ncclient.transport.ssh`` / ``ncclient.transport.parser``
    are dropped entirely.

    NOTE(review): ``logging.basicConfig`` does nothing if the root logger
    already has handlers (e.g. on a second call, or under a harness that
    installed its own), in which case the returned counter is never attached
    and stays at 0. Each call also adds another copy of the ncclient filters.

    :param debug: set log level to DEBUG, or INFO otherwise
    :returns: the MessageCounter passed to the root logger; its ``count`` is
        the number of ERROR-or-worse records seen
    """

    def demote_info(record):
        if record.levelno == logging.INFO:
            record.levelname = "DEBUG"
            record.levelno = logging.DEBUG
        return record

    def discard(record):
        # A falsy filter result makes logging drop the record.
        return None

    for noisy in ("ncclient.operations.rpc", "ncclient.transport.tls"):
        logging.getLogger(noisy).addFilter(demote_info)
    for silenced in ("ncclient.transport.ssh", "ncclient.transport.parser"):
        logging.getLogger(silenced).addFilter(discard)

    threshold = logging.DEBUG if debug else logging.INFO
    counter = MessageCounter(level=logging.ERROR)
    console = logging.StreamHandler(sys.stdout)
    console.setLevel(threshold)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=threshold,
        handlers=[counter, console],
    )
    return counter


def load_interfaces(
    router_fqdn: str, interfaces: Any, app_config_params: dict, point_group: PointGroup
):
    """Decide which interfaces of *router_fqdn* to produce points for.

    :param interfaces: an explicit selection (returned unchanged), or the
        ``ALL_`` sentinel meaning "no explicit selection"
    :param app_config_params: application config; its optional ``"inventory"``
        entry lists inventory-provider hosts
    :param point_group: selects the config section that may override the
        inventory URL via ``"inventory-url"``
    :returns: the explicit selection; else the router's interfaces according
        to the inventory provider, when one is configured; else None, meaning
        points are written for all interfaces found on the router
    """
    if interfaces is not ALL_:
        return interfaces

    inprov_hosts = app_config_params.get("inventory")
    group_config = point_group.config_params(app_config_params)
    if inprov_hosts is None:
        return None

    inventory_url = group_config.get("inventory-url", DEFAULT_INTERFACES_URL)
    return _get_interfaces_for_router(
        router_fqdn, inprov_hosts=inprov_hosts, url=inventory_url
    )


def _get_interfaces_for_router(
    router: str, inprov_hosts: List[str], url: str
) -> List[str]:
    """Return the names of *router*'s interfaces known to the inventory provider.

    :param router: router FQDN to match against each entry's ``"router"`` field
    :param inprov_hosts: inventory-provider hosts to query
    :param url: path of the interfaces endpoint on those hosts
    :raises click.ClickException: if no interface entries match *router*
    """
    logger.info(
        f"Fetching interfaces from inventory provider: {inprov_hosts} using url '{url}'"
    )

    inventory = load_inventory_json(url, inprov_hosts, INVENTORY_INTERFACES_SCHEMA)
    on_router = (entry for entry in inventory if entry["router"] == router)
    names = [entry["name"] for entry in on_router]

    if names:
        return names
    raise click.ClickException(f"No interfaces found for router {router}")


def process_router(
    router_fqdn: str,
    vendor: Vendor,
    document: Any,
    timestamp: datetime,
    interfaces: Optional[List[str]],
    app_config_params: dict,
    output: OutputMethod,
    point_group: PointGroup,
):
    """Build one point group's points for a router, log them and write them out.

    :param document: the router's netconf document, passed to the vendor parser
    :param interfaces: interface names to restrict to, or None for no restriction
    :param output: where the points go (influx, stdout, or nowhere)
    :param point_group: which counters to turn into points; also selects the
        config section whose ``"influx"`` entry supplies the write parameters
        (including the ``"measurement"`` name)
    """
    group_config = point_group.config_params(app_config_params)
    influx_params = group_config["influx"]

    generated = _points(
        router_fqdn=router_fqdn,
        netconf_doc=document,
        interfaces=interfaces,
        timestamp=timestamp,
        measurement_name=influx_params["measurement"],
        vendor=vendor,
        point_group=point_group,
    )
    # Materialise once: the same points are both logged and written.
    points = list(generated)

    _log_interface_points_sorted(points, point_kind=str(point_group))
    output.write_points(points, influx_params=influx_params)


def _points(
    router_fqdn: str,
    netconf_doc: etree.Element,
    interfaces: Optional[List[str]],
    timestamp: datetime,
    measurement_name: str,
    vendor: Vendor,
    point_group: PointGroup,
):
    """Yield *point_group*'s points for one router's netconf document.

    This is a generator: the vendor parses counters only once iteration starts.

    :param interfaces: passed through to the vendor's counter parser
        (None for no restriction)
    """
    counters = vendor.interface_counters(netconf_doc, interfaces=interfaces)
    build_points = point_group.points
    for point in build_points(router_fqdn, counters, timestamp, measurement_name):
        yield point


def _log_interface_points_sorted(
    points: Collection[dict], point_kind="", *, n_columns: int = 5
):
    """Log a summary line, then the interface names of *points* in sorted columns.

    :param points: points carrying ``point["tags"]["interface_name"]``
    :param point_kind: optional label for the summary line (e.g. ``"brian"``)
    :param n_columns: how many interface names to put on each log line
    :raises ValueError: if *n_columns* is less than 1
    """
    if n_columns < 1:
        raise ValueError("n_columns must be at least 1")

    num_points = len(points)
    kind_prefix = point_kind + " " if point_kind else ""
    # Only end the summary with a colon when a listing follows it.
    colon = ":" if num_points else ""
    logger.info(f"Found {kind_prefix}points for {num_points} interfaces{colon}")

    if not points:
        return

    names = sorted(p["tags"]["interface_name"] for p in points)
    width = max(len(name) for name in names)
    # Pad every name to the longest one so the columns line up across rows.
    for start in range(0, len(names), n_columns):
        row = names[start : start + n_columns]
        logger.info("    ".join(name.ljust(width) for name in row))


# Sentinel for "no explicit interface list was given". load_interfaces checks
# for it by identity and then asks the inventory provider (if configured) or
# falls back to every interface on the router. A fresh object() cannot collide
# with any real argument value, including None or an empty tuple.
ALL_ = object()


def main(
    app_config_params: dict,
    router_fqdn: str,
    vendor: Vendor,
    output: OutputMethod = OutputMethod.INFLUX,
    interfaces=ALL_,
):
    """Collect and write interface points for a single router.

    The router's netconf document is fetched once; then, for every PointGroup,
    the interfaces to report on are resolved and that group's points are built
    and written with *output*.

    :param app_config_params: application config; must hold a non-empty entry
        named after the vendor (its ssh params)
    :param router_fqdn: FQDN of the router to poll
    :param vendor: which vendor implementation to use for the router
    :param output: destination for the points
    :param interfaces: explicit interface names, or ``ALL_`` (default) to let
        load_interfaces choose
    :raises ValueError: if the vendor's ssh params are missing from the config
    """
    vendor_name = str(vendor)
    logger.info(f"Processing {vendor_name.capitalize()} router {router_fqdn}")

    if not app_config_params.get(vendor_name):
        raise ValueError(f"'{vendor_name}' ssh params are required")

    netconf = vendor.get_netconf(
        router_name=router_fqdn, ssh_params=vendor.config_params(app_config_params)
    )
    # One timestamp shared by every point group of this run.
    # NOTE(review): naive local time — confirm the point builders / influx
    # client interpret it as intended (some clients treat naive values as UTC).
    timestamp = datetime.now()

    for point_group in PointGroup:
        logger.info(f"Processing {str(point_group).capitalize()} points...")
        process_router(
            router_fqdn=router_fqdn,
            vendor=vendor,
            document=netconf,
            timestamp=timestamp,
            interfaces=load_interfaces(
                router_fqdn=router_fqdn,
                interfaces=interfaces,
                app_config_params=app_config_params,
                point_group=point_group,
            ),
            app_config_params=app_config_params,
            output=output,
            point_group=point_group,
        )


def validate_config(_unused_ctx, _unused_param, file):
    """Click callback for ``--config``: load and validate the config file.

    :param file: open text file handle supplied by ``click.File("r")``
    :returns: whatever ``config.load`` returns for the file
    :raises click.BadParameter: if the file is not valid JSON or fails schema
        validation
    """
    try:
        return config.load(file)
    except json.JSONDecodeError as e:
        raise click.BadParameter("config file is not valid json") from e
    except jsonschema.ValidationError as e:
        # ClickException stores the message and returns it from __str__, so it
        # must be a str, not the exception object itself. str(e) is the same
        # text click would have displayed anyway.
        raise click.BadParameter(str(e)) from e


@click.command()
@click.option(
    "--config",
    "app_config_params",
    required=True,
    type=click.File("r"),
    help="config filename",
    callback=validate_config,
)
@click.option("--juniper", help="A Juniper router fqdn")
@click.option("--nokia", help="A Nokia router fqdn")
@click.option(
    "-o",
    "--output",
    type=click.Choice(["influx", "stdout", "no-out"], case_sensitive=False),
    default="influx",
    help="Choose an output method. Default: influx",
)
@click.option(
    "--all",
    "all_",
    is_flag=True,
    default=False,
    help=(
        "Write points for all interfaces found in inventory provider for this router."
        " Do not use this flag when supplying a list of interfaces"
    ),
)
@click.option(
    "-v", "--verbose", is_flag=True, default=False, help="Run with verbose output"
)
@click.argument("interfaces", nargs=-1)
def cli(
    app_config_params: dict,
    juniper: Optional[str],
    nokia: Optional[str],
    output: str,
    all_: bool,
    verbose: bool,
    interfaces: Collection[str],
):
    # Command entry point. It validates the arguments, configures logging,
    # runs main() for the selected router, and fails with a ClickException if
    # any ERROR-level message was logged along the way.
    #
    # --juniper / --nokia carry the router FQDN as a string (None when
    # omitted); click delivers the variadic INTERFACES argument as a tuple.
    # (Deliberately no docstring: click would turn it into the --help text.)
    if not (interfaces or all_):
        # Do nothing if no interfaces are specified
        return

    if interfaces and all_:
        raise click.BadParameter("Do not supply both 'interfaces' and '--all'")

    if not (juniper or nokia) or (juniper and nokia):
        raise click.BadParameter(
            "Supply either a '--juniper' or '--nokia' router, but not both"
        )
    router_fqdn = juniper or nokia
    vendor = Vendor.JUNIPER if juniper else Vendor.NOKIA

    error_counter = setup_logging(debug=verbose)

    try:
        main(
            app_config_params=app_config_params,
            router_fqdn=router_fqdn,
            vendor=vendor,
            output=OutputMethod.from_string(output.lower()),
            interfaces=interfaces if interfaces else ALL_,
        )
    except Exception:
        # Log instead of crashing. logger.exception logs at ERROR, which the
        # MessageCounter counts, so the check below still turns this into a
        # failure exit.
        logger.exception(
            f"Error while processing {str(vendor).capitalize()} router {router_fqdn}"
        )

    if error_counter.count:
        raise click.ClickException(
            "Errors were encountered while processing interface stats"
        )


# Entry point when this module is executed directly as a script.
if __name__ == "__main__":
    cli()