# test_interface_stats.py
import itertools
import re
from datetime import datetime
from unittest.mock import MagicMock, Mock, call, patch

import jsonschema
import pytest
from brian_polling_manager.interface_stats import cli, services
from brian_polling_manager.interface_stats.vendors import Vendor, common, juniper
from lxml import etree
from ncclient.operations.rpc import RPCReply


def test_sanity_check_snapshot_data(polled_interfaces, all_juniper_routers):
    """
    verify that all routers with interfaces to be polled
    are in the test data set
    :return:
    """
    missing_routers = set(polled_interfaces.keys()) - set(all_juniper_routers)
    assert len(missing_routers) == 0


def test_verify_all_interfaces_present(single_router_fqdn, polled_interfaces):
    """
    verify that all the interfaces we expect to poll
    are available in the netconf data

    compares a snapshot of all netconf docs with a
    snapshot of inventory /poller/interfaces
    (the snapshots were all taken around the same time)

    any expected interface that is missing from the parsed counters
    must belong to a physical interface that is not both admin-up
    and oper-up
    """

    def _is_enabled(ifc_name, ifc_doc):
        # strip any unit suffix (e.g. "ae12.1" -> "ae12"): the status check
        # is done on the parent physical interface
        m = re.match(r"^([^\.]+)\.?.*", ifc_name)
        assert m, f"unexpected interface name format: {ifc_name!r}"
        phy_name = m.group(1)
        matches = ifc_doc.xpath(
            "//interface-information/physical-interface"
            f'[normalize-space(name)="{phy_name}"]'
        )
        # a physical interface absent from the netconf doc is a real
        # discrepancy: report it explicitly instead of raising IndexError
        assert matches, f"physical interface {phy_name!r} not found in netconf doc"
        phy = matches[0]
        admin_status = phy.xpath("./admin-status/text()")[0].strip()
        oper_status = phy.xpath("./oper-status/text()")[0].strip()

        return admin_status == "up" and oper_status == "up"

    if single_router_fqdn not in polled_interfaces:
        pytest.skip(f"{single_router_fqdn} has no expected polled interfaces")

    doc = cli.get_netconf(single_router_fqdn, ssh_params=None)
    phy = juniper._physical_interface_counters(doc)
    log = juniper._logical_interface_counters(doc)
    interfaces = {x["name"] for x in itertools.chain(phy, log)}
    missing_interfaces = polled_interfaces[single_router_fqdn] - interfaces
    # sorted so that the first reported failure is deterministic
    for ifc_name in sorted(missing_interfaces):
        # verify that any missing interfaces are admin/oper disabled
        assert not _is_enabled(ifc_name, doc), (
            f"{ifc_name} is expected to be polled and is up, "
            "but is missing from the parsed counters"
        )


class TestParseCounters:
    """unit tests for juniper._parse_interface_xml"""

    @staticmethod
    def _parse(xml, struct):
        # parse the xml text and feed the root element to the juniper parser
        return juniper._parse_interface_xml(etree.fromstring(xml), struct)

    def test_parse_counters(self):
        struct = {"something": {"path": "./ab", "transform": int}}
        parsed = self._parse("""<root><ab>42</ab></root>""", struct)
        assert parsed == {"something": 42}

    def test_parse_counters_multiple_path(self):
        # a list of paths: the one that exists ("./a") supplies the value
        struct = {"something": {"path": ["./b", "./a"], "transform": str}}
        parsed = self._parse("""<root><a>This is something</a></root>""", struct)
        assert parsed == {"something": "This is something"}

    def test_parse_counters_nested(self):
        struct = {"something": {"nested": {"path": "./a", "transform": str}}}
        parsed = self._parse("""<root><a>This is something</a></root>""", struct)
        assert parsed == {"something": {"nested": "This is something"}}

    def test_skips_unavailable_field(self):
        # "./b" does not exist, so "something_else" is left out of the result
        struct = {
            "something": {"path": "./a", "transform": str},
            "something_else": {"nested": {"path": "./b", "transform": str}},
        }
        parsed = self._parse("""<root><a>This is something</a></root>""", struct)
        assert parsed == {"something": "This is something"}

    def test_logs_on_missing_required_field(self, caplog):
        struct = {
            "something": {"path": "./a", "transform": str, "required": True},
        }
        parsed = self._parse("""<root></root>""", struct)
        assert parsed is None
        first_record = caplog.records[0]
        assert first_record.levelname == "ERROR"
        assert "required path ./a" in first_record.message

    def test_logs_on_double_entry(self, caplog):
        # duplicate elements: the first one wins and a warning is logged
        struct = {
            "something": {"path": "./a", "transform": str},
        }
        parsed = self._parse(
            """<root><a>This is something</a><a>Something Else</a></root>""", struct
        )
        assert parsed == {"something": "This is something"}
        first_record = caplog.records[0]
        assert first_record.levelname == "WARNING"
        assert "found more than one ./a" in first_record.message


@pytest.fixture
def juniper_router_doc_containing_every_field():
    """
    A small juniper ``interface-information`` document meant to populate
    every counter field: one physical interface (ae12) that contains one
    logical interface (ae12.1).

    Every counter holds a distinct value, so the counter tests can tell
    which XML element ended up in which output field.

    :return: the parsed root element (``<interface-information>``)
    """
    doc = """\
<interface-information>
    <physical-interface>
        <name>ae12</name>
        <admin-status>up</admin-status>
        <oper-status>up</oper-status>
        <traffic-statistics>
            <input-bytes>1</input-bytes>
            <input-packets>2</input-packets>
            <output-bytes>3</output-bytes>
            <output-packets>4</output-packets>
            <ipv6-transit-statistics>
                <input-bytes>5</input-bytes>
                <input-packets>6</input-packets>
                <output-bytes>7</output-bytes>
                <output-packets>8</output-packets>
            </ipv6-transit-statistics>
        </traffic-statistics>
        <input-error-list>
            <input-errors>11</input-errors>
            <input-discards>12</input-discards>
            <input-fifo-errors>13</input-fifo-errors>
            <input-drops>14</input-drops>
            <framing-errors>15</framing-errors>
            <input-resource-errors>16</input-resource-errors>
        </input-error-list>
        <output-error-list>
            <output-errors>21</output-errors>
            <output-drops>22</output-drops>
            <output-resource-errors>23</output-resource-errors>
            <output-fifo-errors>24</output-fifo-errors>
            <output-collisions>25</output-collisions>
        </output-error-list>
        <ethernet-mac-statistics>
            <input-crc-errors>31</input-crc-errors>
            <output-crc-errors>32</output-crc-errors>
            <input-total-errors>33</input-total-errors>
            <output-total-errors>34</output-total-errors>
        </ethernet-mac-statistics>
        <ethernet-pcs-statistics>
            <bit-error-seconds>41</bit-error-seconds>
            <errored-blocks-seconds>42</errored-blocks-seconds>
        </ethernet-pcs-statistics>
        <logical-interface>
            <name>ae12.1</name>
            <traffic-statistics>
                <input-bytes>51</input-bytes>
                <input-packets>52</input-packets>
                <output-bytes>53</output-bytes>
                <output-packets>54</output-packets>
                <ipv6-transit-statistics>
                    <input-bytes>55</input-bytes>
                    <input-packets>56</input-packets>
                    <output-bytes>57</output-bytes>
                    <output-packets>58</output-packets>
                </ipv6-transit-statistics>
            </traffic-statistics>
        </logical-interface>
    </physical-interface>
</interface-information>
"""
    return etree.fromstring(doc)


def test_physical_interface_counters(juniper_router_doc_containing_every_field):
    """the single physical interface maps onto the expected brian/error fields"""
    expected_brian = {
        "ingressOctets": 1,
        "ingressPackets": 2,
        "egressOctets": 3,
        "egressPackets": 4,
        "ingressOctetsv6": 5,
        "ingressPacketsv6": 6,
        "egressOctetsv6": 7,
        "egressPacketsv6": 8,
        "ingressErrors": 11,
        "ingressDiscards": 12,
        "egressErrors": 21,
    }
    expected_errors = {
        "input_discards": 12,
        "input_fifo_errors": 13,
        "input_drops": 14,
        "input_framing_errors": 15,
        "input_resource_errors": 16,
        "output_drops": 22,
        "output_resource_errors": 23,
        "output_fifo_errors": 24,
        "output_collisions": 25,
        "input_crc_errors": 31,
        "output_crc_errors": 32,
        "input_total_errors": 33,
        "output_total_errors": 34,
        "bit_error_seconds": 41,
        "errored_blocks_seconds": 42,
    }
    counters = [
        *juniper._physical_interface_counters(
            juniper_router_doc_containing_every_field
        )
    ]
    assert len(counters) == 1
    assert counters[0] == {
        "name": "ae12",
        "brian": expected_brian,
        "errors": expected_errors,
    }


def test_logical_interface_counters(juniper_router_doc_containing_every_field):
    """the single logical interface maps onto the expected brian fields only"""
    expected_brian = {
        "ingressOctets": 51,
        "ingressPackets": 52,
        "egressOctets": 53,
        "egressPackets": 54,
        "ingressOctetsv6": 55,
        "ingressPacketsv6": 56,
        "egressOctetsv6": 57,
        "egressPacketsv6": 58,
    }
    counters = [
        *juniper._logical_interface_counters(
            juniper_router_doc_containing_every_field
        )
    ]
    assert len(counters) == 1
    assert counters[0] == {"name": "ae12.1", "brian": expected_brian}


@pytest.fixture(
    params=[
        juniper._physical_interface_counters,
        juniper._logical_interface_counters,
    ]
)
def generate_interface_counters(request):
    """parametrized fixture: provides each juniper counter generator in turn"""
    counter_generator = request.param
    return counter_generator


def test_router_docs_do_not_generate_errors(
    router_fqdn, generate_interface_counters, caplog
):
    """
    parsing a router's netconf snapshot yields counters and
    logs nothing at WARNING or ERROR level
    """
    netconf_doc = cli.get_netconf(router_fqdn, ssh_params=None)
    parsed = [*generate_interface_counters(netconf_doc)]
    assert parsed
    noisy_records = [
        rec for rec in caplog.records if rec.levelname in ("ERROR", "WARNING")
    ]
    assert not noisy_records


@pytest.fixture(
    params=[
        (common.brian_points, common.BRIAN_POINT_FIELDS_SCHEMA),
        (common.error_points, common.ERROR_POINT_FIELDS_SCHEMA),
    ]
)
def generate_points_with_schema(request):
    """
    parametrized fixture providing (point generator, fields schema) pairs
    for the brian and error measurements

    NOTE(review): no test in this part of the file requests this fixture;
    confirm it is still used before relying on it
    """
    generator_and_schema = request.param
    return generator_and_schema


def test_brian_point_counters():
    """brian_points emits one point per interface, carrying its 'brian' fields"""
    now = datetime.now()
    expected_time = now.strftime("%Y-%m-%dT%H:%M:%SZ")

    def _expected_point(interface_name, fields):
        # builds the point we expect for one interface
        return {
            "time": expected_time,
            "measurement": "blah",
            "tags": {"hostname": "some.router", "interface_name": interface_name},
            "fields": fields,
        }

    generated = common.brian_points(
        router_fqdn="some.router",
        interfaces=[
            {"name": "ae12", "brian": {"some": "fields"}},
            {"name": "ae12.1", "brian": {"more": "other fields"}},
        ],
        timestamp=now,
        measurement_name="blah",
    )
    points = list(generated)

    assert points == [
        _expected_point("ae12", {"some": "fields"}),
        _expected_point("ae12.1", {"more": "other fields"}),
    ]


def test_error_point_counters():
    """error_points emits a point carrying an interface's 'errors' fields"""
    now = datetime.now()
    generated = common.error_points(
        router_fqdn="some.router",
        interfaces=[{"name": "ae12", "errors": {"some": "fields"}}],
        timestamp=now,
        measurement_name="blah-errors",
    )
    points = list(generated)

    expected_point = {
        "time": now.strftime("%Y-%m-%dT%H:%M:%SZ"),
        "measurement": "blah-errors",
        "tags": {"hostname": "some.router", "interface_name": "ae12"},
        "fields": {"some": "fields"},
    }
    assert points == [expected_point]


def test_no_error_point_counters():
    """interfaces without an 'errors' entry produce no error point"""
    now = datetime.now()
    without_errors = {"name": "ae12.1"}
    with_errors = {"name": "ae12", "errors": {"some": "fields"}}
    generated = common.error_points(
        router_fqdn="some.router",
        interfaces=[without_errors, with_errors],
        timestamp=now,
        measurement_name="blah-errors",
    )
    points = list(generated)

    expected_point = {
        "time": now.strftime("%Y-%m-%dT%H:%M:%SZ"),
        "measurement": "blah-errors",
        "tags": {"hostname": "some.router", "interface_name": "ae12"},
        "fields": {"some": "fields"},
    }
    assert points == [expected_point]


@patch.object(cli, "write_points")
def test_main_for_all_juniper_routers(write_points, all_juniper_routers):
    """
    run cli.main over every juniper router with write_points patched out,
    validating every point that would have been written
    """
    config = {
        "juniper": {"some": "params"},
        "influx": {
            "brian-counters": {"measurement": "brian"},
            "error-counters": {"measurement": "error"},
        },
    }
    # mutable tally shared with the side-effect closure below
    tally = {"calls": 0, "points": 0}

    def validate(points, *_, **__):
        batch = list(points)
        tally["calls"] += 1
        assert batch
        for point in batch:
            tally["points"] += 1
            jsonschema.validate(point, services.INFLUX_POINT)
            assert point["fields"]  # must contain at least one field

    write_points.side_effect = validate

    cli.main(
        app_config_params=config,
        router_fqdns=all_juniper_routers,
        vendor=Vendor.JUNIPER,
        raise_errors=True,
    )
    # NOTE(review): these totals presumably reflect the snapshot test data;
    # they will need updating whenever that data changes
    assert tally["calls"] == 104
    assert tally["points"] == 6819


class TestGetJuniperNetConnf:
    """tests for cli.get_netconf with the juniper netconf RPC patched out"""

    RAW_RESPONSE_FILE = "raw-response-sample.xml"

    @pytest.fixture(autouse=True)
    def app_params(self):
        """install juniper ssh params for every test in this class"""
        cli.set_app_params({"juniper": {"some": "param"}})

    @pytest.fixture(autouse=True)
    def mocked_rpc(self, data_dir):
        """patch juniper._rpc to return a canned RPCReply read from data_dir"""
        response_text = data_dir.joinpath(self.RAW_RESPONSE_FILE).read_text()
        canned_reply = RPCReply(response_text)
        with patch.object(juniper, "_rpc", return_value=canned_reply) as rpc_mock:
            yield rpc_mock

    def test_calls_rpc_with_params(self, mocked_rpc):
        router = "some-router"
        cli.get_netconf(router, vendor=Vendor.JUNIPER)

        # call_args unpacks into (positional args, keyword args)
        positional, keywords = mocked_rpc.call_args
        assert positional[0] == router
        assert keywords["ssh_params"] == {"some": "param"}
        assert keywords["command"].tag == "get-interface-information"

    def test_converts_rpc_response_to_xml(self):
        parsed = cli.get_netconf("some-router", vendor=Vendor.JUNIPER)

        assert parsed.tag == "rpc-reply"


@pytest.fixture
def load_interfaces():
    """mock interface loader reporting interfaces on host1, host2 and host3"""
    known_hosts = ("host1", "host2", "host3")
    return Mock(return_value=[{"router": host} for host in known_hosts])


def test_validate_valid_hosts(load_interfaces):
    """hosts present in the loaded interfaces pass validation"""
    hosts = ("host1", "host2")
    is_valid = cli.validate_router_hosts(
        hosts,
        vendor=Vendor.JUNIPER,
        inprov_hosts=["some_host"],
        load_interfaces_=load_interfaces,
    )
    assert is_valid
    assert load_interfaces.called


def test_validate_invalid_hosts(load_interfaces):
    """a host absent from the loaded interfaces raises ValueError"""
    hosts = ("host1", "invalid")
    with pytest.raises(ValueError):
        cli.validate_router_hosts(
            hosts,
            vendor=Vendor.JUNIPER,
            inprov_hosts=["some_host"],
            load_interfaces_=load_interfaces,
        )
    assert load_interfaces.called


def test_doesnt_validate_without_inprov_hosts(load_interfaces):
    """without inprov hosts, validation is skipped and the loader is never used"""
    hosts = ("host1", "invalid")
    is_valid = cli.validate_router_hosts(
        hosts,
        vendor=Vendor.JUNIPER,
        inprov_hosts=None,
        load_interfaces_=load_interfaces,
    )
    assert is_valid
    assert not load_interfaces.called


def test_write_points_to_influx():
    """
    write_points builds a client from the influx params (plus a default
    timeout of 5), enters it as a context manager and writes the points
    with batch_size=50
    """
    cli.set_app_params({})
    influx_factory = MagicMock()
    points = [{"point": "one"}, {"point": "two"}]
    influx_params = {"influx": "param"}
    cli.write_points(
        points=points, influx_params=influx_params, client_factory=influx_factory
    )
    assert influx_factory.call_args == call({"timeout": 5, "influx": "param"})
    # inspect the client via return_value: calling influx_factory() here would
    # record an extra call on the mock and overwrite its call_args
    client = influx_factory.return_value
    assert client.__enter__.call_count
    assert client.write_points.call_args == call(points, batch_size=50)


def test_write_points_to_stdout():
    """in dry-run mode each point is written to the stream as one line"""
    cli.set_app_params({"testing": {"dry_run": True}})
    fake_stream = Mock()
    cli.write_points(
        points=[{"point": "one"}, {"point": "two"}],
        influx_params={"measurement": "meas"},
        stream=fake_stream,
    )
    expected_lines = [
        'meas - {"point": "one"}\n',
        'meas - {"point": "two"}\n',
    ]
    assert fake_stream.write.call_args_list == [call(line) for line in expected_lines]


@pytest.mark.parametrize(
    "input_params, expected",
    [
        # plain hostname with every setting given explicitly
        (
            {
                "hostname": "localhost",
                "port": 1234,
                "ssl": True,
                "verify_ssl": True,
                "timeout": 10,
            },
            {
                "host": "localhost",
                "port": 1234,
                "ssl": True,
                "verify_ssl": True,
                "timeout": 10,
            },
        ),
        # scheme and port are taken from a URL-style hostname
        (
            {"hostname": "http://localhost:1234"},
            {"host": "localhost", "port": 1234, "ssl": False, "verify_ssl": False},
        ),
        (
            {"hostname": "https://localhost:1234"},
            {"host": "localhost", "port": 1234, "ssl": True, "verify_ssl": True},
        ),
        # explicit port/ssl settings take precedence over the URL
        (
            {
                "hostname": "http://localhost:1234",
                "port": 456,
                "ssl": True,
                "verify_ssl": True,
            },
            {"host": "localhost", "port": 456, "ssl": True, "verify_ssl": True},
        ),
        (
            {"hostname": "http://localhost", "port": 456},
            {"host": "localhost", "port": 456, "ssl": False, "verify_ssl": False},
        ),
    ],
)
def test_prepare_influx_params(input_params, expected):
    defaults = {
        "database": "counters",
        "username": "user",
        "password": "pass",
        "timeout": 5,
    }
    raw_params = dict(defaults)
    raw_params.update(input_params)
    wanted = dict(defaults)
    wanted.update(expected)
    assert services.prepare_influx_params(raw_params) == wanted