Skip to content
Snippets Groups Projects
cli.py 8.08 KiB
import json
import logging.config
import os
import pathlib
from datetime import datetime
from functools import partial
from typing import Callable, Iterable, List

import click

from brian_polling_manager.interface_stats.click_helpers import (
    validate_config,
    validate_hostname,
)
from brian_polling_manager.interface_stats.services.netconf import (
    NetconfProvider,
    get_netconf_provider,
)
from brian_polling_manager.interface_stats.services.writers import (
    PointWriter,
    get_point_writer,
)
from brian_polling_manager.interface_stats.vendors import Vendor, juniper
from brian_polling_manager.interface_stats import vendors

# Module-level logger used by the functions in this file.
# NOTE(review): the logger is named after this file's path (__file__) rather
# than the dotted module name (__name__), so it presumably is not a child of
# the "brian_polling_manager" logger configured below - confirm this is
# intended.
logger = logging.getLogger(__file__)

# Fallback logging configuration in logging.config.dictConfig format.
# setup_logging() applies it when the LOGGING_CONFIG environment variable
# is not set.
LOGGING_DEFAULT_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        # timestamp, logger name, source line number, level, message
        "simple": {
            "format": "%(asctime)s - %(name)s "
            "(%(lineno)d) - %(levelname)s - %(message)s"
        }
    },
    "handlers": {
        # single handler: everything goes to stdout using the "simple" format
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "simple",
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        # package logger: DEBUG and above; propagate is off so records are
        # not handed on to the root logger's handlers as well
        "brian_polling_manager": {
            "level": "DEBUG",
            "handlers": ["console"],
            "propagate": False,
        }
    },
    # every other logger: INFO and above, to the same console handler
    "root": {"level": "INFO", "handlers": ["console"]},
}


def setup_logging():
    """
    configure the logging module through logging.config.dictConfig

    if LOGGING_CONFIG is defined in the environment, its value is taken as
    the name of a json file holding the configuration, otherwise
    LOGGING_DEFAULT_CONFIG is applied
    """
    config_filename = os.environ.get("LOGGING_CONFIG")
    if config_filename is None:
        chosen_config = LOGGING_DEFAULT_CONFIG
    else:
        with open(config_filename) as config_file:
            chosen_config = json.load(config_file)

    # TODO: a disabled mac workaround used to be sketched here (pointing the
    # 'syslog_handler' address at '/var/run/syslog' when running on Darwin);
    # it should be removed for good

    logging.config.dictConfig(chosen_config)


def process_juniper_router(
    router_fqdn: str,
    all_influx_params: dict,
    netconf_provider: NetconfProvider,
    get_point_writer_: Callable[[dict], PointWriter],
):
    """
    Fetch the netconf document of one Juniper router and write its points.

    Two sets of points are produced from the same document and the same
    timestamp, first the "brian-counters" then the "error-counters". Each set
    is written with its own writer, created from the matching entry of
    ``all_influx_params``.

    :param router_fqdn: fqdn of the router to process
    :param all_influx_params: dict with a "brian-counters" and an
        "error-counters" entry, each containing at least a "measurement" key
    :param netconf_provider: object whose ``get(router_fqdn)`` returns the
        netconf document for the router
    :param get_point_writer_: callable creating a point writer from one
        entry of ``all_influx_params``
    """
    logger.info(f"processing Juniper router {router_fqdn}")

    netconf_doc = netconf_provider.get(router_fqdn)
    polled_at = datetime.now()

    # (config key, point generator) pairs, handled strictly in this order
    point_sources = (
        ("brian-counters", _juniper_brian_points),
        ("error-counters", _juniper_error_points),
    )
    for params_key, build_points in point_sources:
        params = all_influx_params[params_key]
        measurement_points = build_points(
            router_fqdn=router_fqdn,
            netconf_doc=netconf_doc,
            timestamp=polled_at,
            measurement_name=params["measurement"],
        )
        get_point_writer_(params).write_points(measurement_points)


def _juniper_brian_points(router_fqdn, netconf_doc, timestamp, measurement_name):
    """
    Generate the brian points of a Juniper netconf document.

    Yields the points built by ``vendors.brian_points`` for the physical
    interface counters first, followed by those for the logical interface
    counters.
    """
    counter_extractors = (
        juniper.physical_interface_counters,
        juniper.logical_interface_counters,
    )
    for extract_counters in counter_extractors:
        yield from vendors.brian_points(
            router_fqdn, extract_counters(netconf_doc), timestamp, measurement_name
        )


def _juniper_error_points(router_fqdn, netconf_doc, timestamp, measurement_name):
    """
    Generate the error points of a Juniper netconf document.

    Yields the points built by ``vendors.error_points`` for the physical
    interface counters first, followed by those for the logical interface
    counters.
    """
    # [2024-03-21] There is currently no definition for error points on logical
    # interfaces, so the second pass is essentially a no-op. It is kept in case a
    # definition for errors on logical interfaces arrives in the future
    counter_extractors = (
        juniper.physical_interface_counters,
        juniper.logical_interface_counters,
    )
    for extract_counters in counter_extractors:
        yield from vendors.error_points(
            router_fqdn, extract_counters(netconf_doc), timestamp, measurement_name
        )


def process_nokia_router(
    router_fqdn: str,
    all_influx_params: dict,
    get_netconf: NetconfProvider,
    writer_factory: Callable[[dict], PointWriter] = get_point_writer,
):
    """
    Placeholder for processing a Nokia router.

    Nothing is polled or written: the only effect is a warning log entry
    stating that the router is skipped. ``all_influx_params``,
    ``get_netconf`` and ``writer_factory`` are accepted but not used.

    :param router_fqdn: fqdn of the router, used in the log message only
    """
    skip_notice = f"skipping Nokia router {router_fqdn}"
    logger.warning(skip_notice)


def get_routers_from_inventory_provider(vendor: Vendor, hosts: Iterable[str]):
    """
    Placeholder for looking up the routers of ``vendor`` via the inventory
    provider ``hosts``.

    Not implemented: both arguments are ignored and a new empty list is
    returned on every call.
    """
    no_routers = []
    return no_routers


def main(
    app_config_params: dict,
    juniper_fqdns: List[str],
    nokia_fqdns: List[str],
    get_netconf_provider_: Callable[[Vendor, dict], NetconfProvider],
    get_point_writer_: Callable[[dict], PointWriter],
    raise_errors=False,
):
    """
    Process every requested Juniper and Nokia router.

    When neither ``juniper_fqdns`` nor ``nokia_fqdns`` contains a router, both
    lists are obtained from ``get_routers_from_inventory_provider`` using the
    "inventory" entry of ``app_config_params``.

    A failure while processing one router is logged and counted, after which
    the remaining routers are still processed (unless ``raise_errors`` is set).

    :param app_config_params: application config; the keys read here are
        "inventory", "juniper", "nokia" and "influx"
    :param juniper_fqdns: fqdns of the Juniper routers to process
    :param nokia_fqdns: fqdns of the Nokia routers to process
    :param get_netconf_provider_: callable creating a netconf provider from a
        vendor and that vendor's ssh params
    :param get_point_writer_: callable creating a point writer from influx
        params
    :param raise_errors: re-raise the first processing error instead of
        continuing with the next router
    :return: the number of routers for which processing raised an exception
    :raises ValueError: if no routers are given and the config has no
        inventory provider, or if the ssh params for a vendor that has
        routers to process are missing from the config
    """

    if not juniper_fqdns and not nokia_fqdns:
        inprov_hosts = app_config_params.get("inventory", [])
        if not inprov_hosts:
            raise ValueError("Must supply at least one inventory provider in config")
        juniper_fqdns = get_routers_from_inventory_provider(
            Vendor.JUNIPER, inprov_hosts
        )
        nokia_fqdns = get_routers_from_inventory_provider(Vendor.NOKIA, inprov_hosts)

    # A provider is only created (and its ssh params only required) for a
    # vendor that actually has routers to process
    if juniper_fqdns:
        juniper_ssh_params = app_config_params.get("juniper")
        if not juniper_ssh_params:
            raise ValueError("'juniper' ssh params are required")

        juniper_netconf_provider = get_netconf_provider_(
            Vendor.JUNIPER, juniper_ssh_params
        )

    if nokia_fqdns:
        nokia_ssh_params = app_config_params.get("nokia")
        if not nokia_ssh_params:
            raise ValueError("'nokia' ssh params are required")

        nokia_netconf_provider = get_netconf_provider_(Vendor.NOKIA, nokia_ssh_params)

    error_count = 0
    # `or ()` tolerates a vendor list passed as None
    for router in juniper_fqdns or ():
        try:
            process_juniper_router(
                router,
                app_config_params["influx"],
                netconf_provider=juniper_netconf_provider,
                get_point_writer_=get_point_writer_,
            )
        except Exception:
            # count the failure so that the caller can tell that something
            # went wrong even though processing carries on
            error_count += 1
            logger.exception(f"Error while processing juniper {router}")
            if raise_errors:
                raise
    for router in nokia_fqdns or ():
        try:
            # process_nokia_router names its writer parameter `writer_factory`
            process_nokia_router(
                router,
                app_config_params["influx"],
                get_netconf=nokia_netconf_provider,
                writer_factory=get_point_writer_,
            )
        except Exception:
            error_count += 1
            logger.exception(f"Error while processing nokia {router}")
            if raise_errors:
                raise
    return error_count


@click.command()
@click.option(
    "--config",
    "app_config_params",
    required=True,
    type=click.File("r"),
    help="config filename",
    callback=validate_config,
)
@click.option(
    "--juniper",
    "juniper_fqdns",
    multiple=True,
    type=click.STRING,
    help="juniper router fqdn(s)",
    callback=validate_hostname,
)
@click.option(
    "--nokia",
    "nokia_fqdns",
    multiple=True,
    type=click.STRING,
    help="nokia router fqdn(s)",
    callback=validate_hostname,
)
@click.option(
    "--source-dir",
    type=click.Path(
        exists=True,
        file_okay=False,
        dir_okay=True,
        readable=True,
        path_type=pathlib.Path,
    ),
    default=None,
    help="Read from a snapshot directory instead of querying the routers",
)
@click.option(
    "--dry-run",
    is_flag=True,
    help="Perform a dry run, do not write to influx but dump to stdout",
)
def cli(
    app_config_params: dict,
    juniper_fqdns: List[str],
    nokia_fqdns: List[str],
    source_dir: pathlib.Path,
    dry_run: bool,
):
    """
    Collect interface stats from routers and write them to influx.

    Exits with an error if processing failed for one or more routers.
    """
    # NOTE: click also uses the docstring above as the help text of the command

    # --dry-run swaps the influx writer for one that dumps to stdout
    writer_factory = partial(get_point_writer, kind="stdout" if dry_run else "influx")
    # --source-dir is handed on to every netconf provider that gets created
    netconf_provider_factory = partial(get_netconf_provider, source_dir=source_dir)

    setup_logging()

    error_count = main(
        app_config_params=app_config_params,
        juniper_fqdns=juniper_fqdns,
        nokia_fqdns=nokia_fqdns,
        get_netconf_provider_=netconf_provider_factory,
        get_point_writer_=writer_factory,
    )
    if error_count:
        raise click.ClickException(
            "Errors were encountered while processing interface stats"
        )


# Run the click command when this file is executed as a script
if __name__ == "__main__":
    cli()