Skip to content
Snippets Groups Projects
test_interface_stats.py 8.23 KiB
import datetime
from functools import partial
import itertools
import re
from typing import Iterable
from unittest.mock import MagicMock, call, patch

from brian_polling_manager.interface_stats.services.writers import (
    InfluxPointWriter,
    PointWriter,
)
import jsonschema
import pytest
from brian_polling_manager import influx
from brian_polling_manager.interface_stats import brian, cli, errors, vendors
from brian_polling_manager.interface_stats.services.netconf import (
    JuniperNetconfProvider,
    get_netconf_provider,
)
from brian_polling_manager.interface_stats.vendors import juniper
from ncclient.operations.rpc import RPCReply


def test_sanity_check_snapshot_data(polled_interfaces, all_juniper_routers):
    """
    verify that all routers with interfaces to be polled
    are in the test data set
    :return:
    """
    missing_routers = set(polled_interfaces.keys()) - set(all_juniper_routers)
    assert len(missing_routers) == 0


def _is_enabled(ifc_name, ifc_doc):
    """
    check whether the physical interface behind ``ifc_name`` is up

    :param ifc_name: interface name; for a dotted (logical) name only the
        part before the first ``.`` is used as the physical interface name
    :param ifc_doc: parsed netconf interface-information document that
        supports ``.xpath()`` (presumably an lxml element - confirm)
    :return: True if both admin-status and oper-status are "up"
    """
    m = re.match(r"^([^\.]+)", ifc_name)
    assert m, f"unexpected interface name: {ifc_name!r}"  # sanity check
    phy_name = m.group(1)

    # bind the name as an XPath variable rather than interpolating it into
    # the expression, so quote characters in a name cannot break the query
    matches = ifc_doc.xpath(
        "//interface-information/physical-interface"
        "[normalize-space(name)=$phy_name]",
        phy_name=phy_name,
    )
    assert matches, f"physical interface {phy_name!r} not found in netconf doc"
    phy = matches[0]

    admin_status = phy.xpath("./admin-status/text()")[0].strip()
    oper_status = phy.xpath("./oper-status/text()")[0].strip()

    return admin_status == "up" and oper_status == "up"


def test_verify_all_interfaces_present(
    single_router_fqdn, get_netconf, polled_interfaces
):
    """
    check that the interfaces we expect to poll exist in the netconf data

    the expected set comes from a snapshot of inventory /poller/interfaces
    and is compared with a snapshot of the router's netconf document (both
    snapshots were taken around the same time); an expected interface may
    only be missing from the netconf data if it is not admin/oper up
    """
    if single_router_fqdn not in polled_interfaces:
        pytest.skip(f"{single_router_fqdn} has no expected polled interfaces")

    netconf_doc = get_netconf(single_router_fqdn)
    available = {
        counters["name"]
        for counters in itertools.chain(
            juniper.physical_interface_counters(netconf_doc),
            juniper.logical_interface_counters(netconf_doc),
        )
    }

    for absent_name in polled_interfaces[single_router_fqdn] - available:
        # an absent interface is tolerated only when it is disabled
        assert not _is_enabled(absent_name, netconf_doc)


@pytest.fixture(
    params=[
        (
            juniper.physical_interface_counters,
            vendors.PHYSICAL_INTERFACE_COUNTER_SCHEMA,
        ),
        (
            juniper.logical_interface_counters,
            vendors.LOGICAL_INTERFACE_COUNTER_SCHEMA,
        ),
    ]
)
def interface_counter_calculator(request):
    """
    parametrized fixture: a ``(counter function, json schema)`` pair,
    once for physical and once for logical interface counters
    """
    counter_fun, counter_schema = request.param
    return counter_fun, counter_schema


def test_validate_interface_counters(
    interface_counter_calculator, single_router_fqdn, get_netconf
):
    """
    every interface counter dict computed from the router's netconf
    snapshot must validate against the calculator's schema
    """
    calculate, counter_schema = interface_counter_calculator
    computed = list(calculate(get_netconf(single_router_fqdn)))
    assert computed
    for counter_dict in computed:
        jsonschema.validate(counter_dict, counter_schema)


def test_interface_counters(router_fqdn, interface_counter_calculator, get_netconf):
    fun, _ = interface_counter_calculator
    doc = get_netconf(router_fqdn)
    interfaces = list(fun(doc))
    assert interfaces


@pytest.fixture(
    params=[
        (brian.counters, brian.BRIAN_COUNTER_DICT_SCHEMA),
        (errors.counters, errors.ERROR_COUNTER_DICT_SCHEMA),
    ]
)
def point_counter_calculator(request):
    """
    parametrized fixture: a ``(point counter function, json schema)`` pair,
    once for brian counters and once for error counters
    """
    point_fun, point_schema = request.param
    return point_fun, point_schema


def test_validate_point_counters(
    single_router_fqdn,
    interface_counter_calculator,
    point_counter_calculator,
    get_netconf,
):
    """
    validate every point counter dict computed from one router's netconf
    snapshot against the schema of the point calculator that produced it

    runs for each combination of interface counter calculator
    (physical / logical) and point counter calculator (brian / errors)
    """
    interface_fun, _ = interface_counter_calculator
    point_fun, schema = point_counter_calculator
    if (
        interface_fun is juniper.logical_interface_counters
        and point_fun is errors.counters
    ):
        # We should not have any error counters for logical interfaces,
        # so there is nothing to validate for this combination
        pytest.skip("no error counters are expected for logical interfaces")

    doc = get_netconf(single_router_fqdn)
    interfaces = interface_fun(doc)
    counters = list(
        point_fun(router_fqdn=single_router_fqdn, interface_counters=interfaces)
    )

    assert counters
    for ctrs in counters:
        jsonschema.validate(ctrs, schema)


def test_point_counters(
    router_fqdn,
    interface_counter_calculator,
    point_counter_calculator,
    get_netconf,
):
    """
    point counters are produced for every router, except that logical
    interfaces must never yield error counters
    """
    interface_fun, _ = interface_counter_calculator
    point_fun, _ = point_counter_calculator

    interface_counters = interface_fun(get_netconf(router_fqdn))
    points = list(
        point_fun(router_fqdn=router_fqdn, interface_counters=interface_counters)
    )

    expect_no_points = (
        interface_fun is juniper.logical_interface_counters
        and point_fun is errors.counters
    )
    if expect_no_points:
        # We should not have any error counters for logical interfaces
        assert not points
        return
    assert points


@pytest.fixture(params=[cli._brian_points, cli._error_points])
def influx_point_calculator(request):
    """parametrized fixture: the brian and the error influx point generators"""
    point_generator = request.param
    return point_generator


def test_validate_influx_points(
    single_router_fqdn, influx_point_calculator, get_netconf
):
    """
    every influx point generated for the router must validate against
    influx.INFLUX_POINT and carry at least one field
    """
    calculator_kwargs = dict(
        router_fqdn=single_router_fqdn,
        netconf_doc=get_netconf(single_router_fqdn),
        timestamp=datetime.datetime.now(),
        measurement_name="blah",
    )
    generated = list(influx_point_calculator(**calculator_kwargs))
    assert generated
    for influx_point in generated:
        jsonschema.validate(influx_point, influx.INFLUX_POINT)
        # any trivial points should already be filtered
        assert influx_point["fields"]


def test_influx_points(router_fqdn, influx_point_calculator, get_netconf):
    points = list(
        influx_point_calculator(
            router_fqdn=router_fqdn,
            netconf_doc=get_netconf(router_fqdn),
            timestamp=datetime.datetime.now(),
            measurement_name="blah",
        )
    )
    assert len(points)
    for point in points:
        assert point["fields"]  # any trivial points should already be filtered


def test_main_for_all_juniper_routers(all_juniper_routers, data_dir):
    """
    run cli.main over every juniper router snapshot in data_dir, with a
    point writer that validates each batch of points instead of writing it
    """

    class ValidatingPointWriter(PointWriter):
        """PointWriter stand-in that only validates the points it receives"""

        def __init__(self, *args, **kwargs) -> None:
            # constructor arguments are accepted and ignored
            pass

        def write_points(self, points: Iterable[dict]):
            batch = list(points)
            assert batch
            for influx_point in batch:
                jsonschema.validate(influx_point, influx.INFLUX_POINT)
                # must contain at least one field
                assert influx_point["fields"]

    influx_config = {
        "brian-counters": {"measurement": "brian"},
        "error-counters": {"measurement": "error"},
    }
    app_config = {"juniper": {"some": "params"}, "influx": influx_config}

    cli.main(
        app_config_params=app_config,
        juniper_fqdns=all_juniper_routers,
        nokia_fqdns=[],
        get_point_writer_=ValidatingPointWriter,
        get_netconf_provider_=partial(get_netconf_provider, source_dir=data_dir),
        raise_errors=True,
    )


class TestJuniperNetconfProvider:
    """JuniperNetconfProvider tests, run against a canned netconf RPC reply"""

    RAW_RESPONSE_FILE = "raw-response-sample.xml"

    @pytest.fixture(autouse=True)
    def mocked_rpc(self, data_dir):
        """
        patch the netconf module's ``_rpc`` so that it returns an RPCReply
        built from RAW_RESPONSE_FILE in the test data directory
        """
        sample_path = data_dir.joinpath(self.RAW_RESPONSE_FILE)
        canned_reply = RPCReply(sample_path.read_text())
        with patch(
            "brian_polling_manager.interface_stats.services.netconf._rpc",
            return_value=canned_reply,
        ) as rpc_mock:
            yield rpc_mock

    def test_calls_rpc_with_params(self, mocked_rpc):
        ssh_params = {}
        router_name = "some-router"

        JuniperNetconfProvider(ssh_params).get(router_name)

        positional, keywords = mocked_rpc.call_args
        assert positional[0] == router_name
        assert keywords["ssh_params"] is ssh_params
        assert keywords["command"].tag == "get-interface-information"

    def test_converts_rpc_response_to_xml(self):
        provider = JuniperNetconfProvider({})

        reply_doc = provider.get("some-router")

        assert reply_doc.tag == "rpc-reply"


def test_influx_point_writer():
    """
    InfluxPointWriter should build its client via client_factory, passing
    the influx params merged with the timeout, enter the client once and
    forward the points to the client's write_points
    """
    client_factory = MagicMock()
    writer = InfluxPointWriter(
        {"some": "param"}, timeout=10, client_factory=client_factory
    )

    writer.write_points(["point1", "point2"])

    assert client_factory.call_args == call({"timeout": 10, "some": "param"})
    # the object the factory returned (without calling the factory again)
    client = client_factory.return_value
    assert client.__enter__.call_count == 1
    assert client.write_points.call_args == call(["point1", "point2"])