Skip to content
Snippets Groups Projects
Commit cdf5ac6a authored by Pelle Koster's avatar Pelle Koster
Browse files

add check to inventory provider, add more tests

parent 026bcfa2
Branches
Tags
No related merge requests found
...@@ -4,8 +4,12 @@ import os ...@@ -4,8 +4,12 @@ import os
import pathlib import pathlib
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Callable, Iterable, List, Sequence, Union from typing import Callable, List
from brian_polling_manager.interface_stats.services.validation import (
RouterListValidator,
get_router_list_validator,
)
import click import click
from brian_polling_manager.interface_stats.click_helpers import ( from brian_polling_manager.interface_stats.click_helpers import (
...@@ -69,6 +73,20 @@ def setup_logging(): ...@@ -69,6 +73,20 @@ def setup_logging():
logging.config.dictConfig(logging_config) logging.config.dictConfig(logging_config)
def process_router(
router_fqdn: str,
vendor: Vendor,
all_influx_params: dict,
netconf_provider: NetconfProvider,
get_point_writer_: Callable[[dict], PointWriter],
):
if vendor == Vendor.JUNIPER:
process_func = process_juniper_router
else:
process_func = process_nokia_router
process_func(router_fqdn, all_influx_params, netconf_provider, get_point_writer_)
def process_juniper_router( def process_juniper_router(
router_fqdn: str, router_fqdn: str,
all_influx_params: dict, all_influx_params: dict,
...@@ -139,86 +157,41 @@ def process_nokia_router( ...@@ -139,86 +157,41 @@ def process_nokia_router(
logger.warning(f"skipping Nokia router {router_fqdn}") logger.warning(f"skipping Nokia router {router_fqdn}")
def get_routers_from_inventory_provider(vendor: Vendor, hosts: Iterable[str]):
return []
def validate_hosts(
router_fqdn: Sequence[str],
vendor: Vendor,
inprov_hosts: Union[str, Sequence[str], None] = None,
):
if inprov_hosts is None:
return
all_fqdns = get_routers_from_inventory_provider(vendor, inprov_hosts)
extra_fqdns = set(router_fqdn) - set(all_fqdns)
if extra_fqdns:
raise ValueError(
f"Routers are not in inventory provider or not {vendor.value}: "
f"{' '.join(extra_fqdns)}"
)
def main( def main(
app_config_params: dict, app_config_params: dict,
juniper_fqdns: List[str], router_fqdns: List[str],
nokia_fqdns: List[str], vendor: Vendor,
get_netconf_provider_: Callable[[Vendor, dict], NetconfProvider], get_netconf_provider_: Callable[[Vendor, dict], NetconfProvider],
get_point_writer_: Callable[[dict], PointWriter], get_point_writer_: Callable[[dict], PointWriter],
host_validator: RouterListValidator,
raise_errors=False, raise_errors=False,
): ):
vendor_str = vendor.value
if not juniper_fqdns and not nokia_fqdns: host_validator.validate_hosts(router_fqdns, vendor=vendor)
raise click.ClickException(
"At least one --juniper or --nokia router is required"
)
inprov_hosts = app_config_params.get("inventory", None)
if juniper_fqdns:
validate_hosts(juniper_fqdns, vendor=Vendor.JUNIPER, inprov_hosts=inprov_hosts)
juniper_ssh_params = app_config_params.get(Vendor.JUNIPER.value)
if not juniper_ssh_params:
raise ValueError(f"'{Vendor.JUNIPER.value}' ssh params are required")
juniper_netconf_provider = get_netconf_provider_(
Vendor.JUNIPER, juniper_ssh_params
)
if nokia_fqdns:
validate_hosts(juniper_fqdns, vendor=Vendor.NOKIA, inprov_hosts=inprov_hosts)
nokia_ssh_params = app_config_params.get(Vendor.NOKIA.value) ssh_params = app_config_params.get(vendor_str)
if not nokia_ssh_params: if not ssh_params:
raise ValueError(f"'{Vendor.JUNIPER.value}' ssh params are required") raise ValueError(f"'{vendor_str}' ssh params are required")
netconf_provider = get_netconf_provider_(vendor, ssh_params)
nokia_netconf_provider = get_netconf_provider_(Vendor.NOKIA, nokia_ssh_params)
error_count = 0 error_count = 0
for router in juniper_fqdns: for router in router_fqdns:
try: try:
process_juniper_router( process_router(
router, router_fqdn=router,
app_config_params["influx"], vendor=vendor,
netconf_provider=juniper_netconf_provider, all_influx_params=app_config_params["influx"],
netconf_provider=netconf_provider,
get_point_writer_=get_point_writer_, get_point_writer_=get_point_writer_,
) )
except Exception as e: except Exception as e:
logger.exception(f"Error while processing juniper {router}", exc_info=e) logger.exception(
if raise_errors: f"Error while processing {vendor_str} {router}", exc_info=e
raise
for router in nokia_fqdns:
try:
process_nokia_router(
router,
app_config_params["influx"],
get_netconf=nokia_netconf_provider,
get_point_writer_=get_point_writer_,
) )
except Exception as e:
logger.exception(f"Error while processing nokia {router}", exc_info=e)
if raise_errors: if raise_errors:
raise raise
return error_count return error_count
...@@ -233,19 +206,13 @@ def main( ...@@ -233,19 +206,13 @@ def main(
) )
@click.option( @click.option(
"--juniper", "--juniper",
"juniper_fqdns", is_flag=True,
multiple=True, help="The given router fqdns are juniper routers",
type=click.STRING,
help="juniper router fqdn(s)",
callback=validate_hostname,
) )
@click.option( @click.option(
"--nokia", "--nokia",
"nokia_fqdns", is_flag=True,
multiple=True, help="The given router fqdns are nokia routers",
type=click.STRING,
help="nokia router fqdn(s)",
callback=validate_hostname,
) )
@click.option( @click.option(
"--source-dir", "--source-dir",
...@@ -264,25 +231,49 @@ def main( ...@@ -264,25 +231,49 @@ def main(
is_flag=True, is_flag=True,
help="Perform a dry run, do not write to influx but dump to stdout", help="Perform a dry run, do not write to influx but dump to stdout",
) )
@click.option(
"--no-out",
is_flag=True,
help="Perform a dry run, but do not write any points to stdout",
)
@click.argument(
"router-fqdn", nargs=-1, callback=validate_hostname
)
def cli( def cli(
app_config_params: dict, app_config_params: dict,
juniper_fqdns: List[str], juniper: bool,
nokia_fqdns: List[str], nokia: bool,
source_dir: pathlib.Path, source_dir: pathlib.Path,
dry_run: bool, dry_run: bool,
no_out: bool,
router_fqdn,
): ):
if not router_fqdn:
# Do nothing if no routers are specified
return
if not (juniper ^ nokia):
raise click.BadParameter("Set either '--juniper' or '--nokia', but not both")
vendor = Vendor.JUNIPER if juniper else Vendor.NOKIA
writer_factory = partial(get_point_writer, kind="stdout" if dry_run else "influx") writer_kind = "influx"
if dry_run:
writer_kind = "stdout"
if no_out:
writer_kind = "noout"
writer_factory = partial(get_point_writer, kind=writer_kind)
netconf_provider_factory = partial(get_netconf_provider, source_dir=source_dir) netconf_provider_factory = partial(get_netconf_provider, source_dir=source_dir)
router_validator = get_router_list_validator(app_config_params)
setup_logging() setup_logging()
error_count = main( error_count = main(
app_config_params=app_config_params, app_config_params=app_config_params,
juniper_fqdns=juniper_fqdns, router_fqdns=router_fqdn,
nokia_fqdns=nokia_fqdns, vendor=vendor,
get_netconf_provider_=netconf_provider_factory, get_netconf_provider_=netconf_provider_factory,
get_point_writer_=writer_factory, get_point_writer_=writer_factory,
host_validator=router_validator,
) )
if error_count: if error_count:
raise click.ClickException( raise click.ClickException(
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
"username": "bogus-user", "username": "bogus-user",
"password": "bogus-password" "password": "bogus-password"
}, },
"inventory": ["http://blah"], "inventory": ["https://test-inprov01.geant.org"],
"influx": { "influx": {
"brian-counters": { "brian-counters": {
"hostname": "hostname", "hostname": "hostname",
......
...@@ -27,7 +27,7 @@ def get_netconf_provider( ...@@ -27,7 +27,7 @@ def get_netconf_provider(
class NetconfProvider(abc.ABC): class NetconfProvider(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def get(self, router_name: str) -> etree.Element: def get(self, router_name: str) -> etree.Element: # pragma: no cover
pass pass
......
import abc
import logging
from typing import Sequence
from brian_polling_manager.interface_stats.vendors import Vendor
from brian_polling_manager.inventory import load_interfaces
logger = logging.getLogger(__file__)
def get_router_list_validator(app_config_params):
inprov_hosts = app_config_params.get("inventory", None)
if inprov_hosts:
return InprovRouterListValidator(inprov_hosts)
else:
return DummyRouterListValidator()
class RouterListValidator(abc.ABC):
def validate_hosts(
self, hostnames: Sequence[str], vendor: Vendor
) -> bool: # pragma: no cover
pass
class InprovRouterListValidator(RouterListValidator):
def __init__(self, hostnames, load_interfaces_=load_interfaces) -> None:
self.hostnames = hostnames
self.load_interfaces = load_interfaces_
def _get_routers(self, vendor: Vendor):
if vendor == Vendor.NOKIA:
return []
return {ifc["router"] for ifc in self.load_interfaces(self.hostnames)}
def validate_hosts(self, hostnames: Sequence[str], vendor: Vendor):
logger.info(
f"Validating hosts {' '.join(hostnames)} using "
f"inventory providers {self.hostnames}"
)
all_fqdns = self._get_routers(vendor)
extra_fqdns = set(hostnames) - set(all_fqdns)
if extra_fqdns:
raise ValueError(
f"Routers are not in inventory provider or not {vendor.value}: "
f"{' '.join(extra_fqdns)}"
)
return True
class DummyRouterListValidator:
def validate_hosts(self, hostnames: Sequence[str], vendor: Vendor) -> bool:
return True
...@@ -15,11 +15,13 @@ def get_point_writer(influx_params: dict, kind="influx") -> WritePoints: ...@@ -15,11 +15,13 @@ def get_point_writer(influx_params: dict, kind="influx") -> WritePoints:
return InfluxPointWriter(influx_params=influx_params) return InfluxPointWriter(influx_params=influx_params)
if kind == "stdout": if kind == "stdout":
return StreamPointWriter(measurement=influx_params["measurement"]) return StreamPointWriter(measurement=influx_params["measurement"])
if kind == "noout":
return NothingWriter()
raise ValueError(f"unsupported writer kind '{kind}'") raise ValueError(f"unsupported writer kind '{kind}'")
class PointWriter(abc.ABC): class PointWriter(abc.ABC):
def write_points(self, points: Iterable[dict]): def write_points(self, points: Iterable[dict]): # pragma: no cover
pass pass
...@@ -51,3 +53,8 @@ class StreamPointWriter(PointWriter): ...@@ -51,3 +53,8 @@ class StreamPointWriter(PointWriter):
for point in points: for point in points:
self.stream.write(f"{self.measurement} - {json.dumps(point)}\n") self.stream.write(f"{self.measurement} - {json.dumps(point)}\n")
self.stream.flush() self.stream.flush()
class NothingWriter(PointWriter):
def write_points(self, points: Iterable[dict]):
pass
...@@ -21,7 +21,7 @@ setup( ...@@ -21,7 +21,7 @@ setup(
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'brian-polling-manager=brian_polling_manager.main:cli', 'brian-polling-manager=brian_polling_manager.main:cli',
'get-interface-stats=brian_polling_manager.interface_stats.cli:main' 'get-interface-stats=brian_polling_manager.interface_stats.cli:cli'
] ]
}, },
include_package_data=True, include_package_data=True,
......
...@@ -3,7 +3,10 @@ from functools import partial ...@@ -3,7 +3,10 @@ from functools import partial
import itertools import itertools
import re import re
from typing import Iterable from typing import Iterable
from unittest.mock import MagicMock, call, patch from unittest.mock import MagicMock, call
from brian_polling_manager.interface_stats.services.validation import (
DummyRouterListValidator,
)
from lxml import etree from lxml import etree
from brian_polling_manager.interface_stats.services.writers import ( from brian_polling_manager.interface_stats.services.writers import (
InfluxPointWriter, InfluxPointWriter,
...@@ -14,11 +17,9 @@ import pytest ...@@ -14,11 +17,9 @@ import pytest
from brian_polling_manager import influx from brian_polling_manager import influx
from brian_polling_manager.interface_stats import cli from brian_polling_manager.interface_stats import cli
from brian_polling_manager.interface_stats.services.netconf import ( from brian_polling_manager.interface_stats.services.netconf import (
JuniperNetconfProvider,
get_netconf_provider, get_netconf_provider,
) )
from brian_polling_manager.interface_stats.vendors import common, juniper from brian_polling_manager.interface_stats.vendors import Vendor, common, juniper
from ncclient.operations.rpc import RPCReply
def test_sanity_check_snapshot_data(polled_interfaces, all_juniper_routers): def test_sanity_check_snapshot_data(polled_interfaces, all_juniper_routers):
...@@ -374,50 +375,17 @@ def test_main_for_all_juniper_routers(all_juniper_routers, data_dir): ...@@ -374,50 +375,17 @@ def test_main_for_all_juniper_routers(all_juniper_routers, data_dir):
writer = ValidatingPointWriter() writer = ValidatingPointWriter()
cli.main( cli.main(
app_config_params=config, app_config_params=config,
juniper_fqdns=all_juniper_routers, router_fqdns=all_juniper_routers,
nokia_fqdns=[], vendor=Vendor.JUNIPER,
get_point_writer_=lambda _: writer, get_point_writer_=lambda _: writer,
get_netconf_provider_=partial(get_netconf_provider, source_dir=data_dir), get_netconf_provider_=partial(get_netconf_provider, source_dir=data_dir),
host_validator=DummyRouterListValidator(),
raise_errors=True, raise_errors=True,
) )
assert writer.calls == 104 assert writer.calls == 104
assert writer.total_points == 6819 assert writer.total_points == 6819
class TestJuniperNetconfProvider:
RAW_RESPONSE_FILE = "raw-response-sample.xml"
@pytest.fixture(autouse=True)
def mocked_rpc(self, data_dir):
raw_response = data_dir.joinpath(self.RAW_RESPONSE_FILE).read_text()
with patch(
"brian_polling_manager.interface_stats.services.netconf._rpc",
return_value=RPCReply(raw_response),
) as mock:
yield mock
def test_calls_rpc_with_params(self, mocked_rpc):
ssh_params = {}
router_name = "some-router"
provider = JuniperNetconfProvider(ssh_params)
provider.get(router_name)
call_args = mocked_rpc.call_args
assert call_args[0][0] == router_name
assert call_args[1]["ssh_params"] is ssh_params
assert call_args[1]["command"].tag == "get-interface-information"
def test_converts_rpc_response_to_xml(self):
ssh_params = {}
router_name = "some-router"
provider = JuniperNetconfProvider(ssh_params)
result = provider.get(router_name)
assert result.tag == "rpc-reply"
def test_influx_point_writer(): def test_influx_point_writer():
client_factory = MagicMock() client_factory = MagicMock()
influx_params = {"some": "param"} influx_params = {"some": "param"}
......
...@@ -236,7 +236,6 @@ def app_config_params(free_host_port): ...@@ -236,7 +236,6 @@ def app_config_params(free_host_port):
"ssh_config": f.name, "ssh_config": f.name,
"hostkey_verify": False, "hostkey_verify": False,
}, },
"inventory": ["http://blah"],
"influx": { "influx": {
"brian-counters": { "brian-counters": {
"hostname": "localhost", "hostname": "localhost",
...@@ -277,9 +276,14 @@ def test_e2e( ...@@ -277,9 +276,14 @@ def test_e2e(
all potential counter fields are populated all potential counter fields are populated
""" """
cli_args = ["--config", app_config_filename, "--source-dir", str(data_dir)] cli_args = [
for router_fqdn in all_juniper_routers: "--config",
cli_args.extend(["--juniper", router_fqdn]) app_config_filename,
"--source-dir",
str(data_dir),
"--juniper",
*all_juniper_routers,
]
runner = CliRunner() runner = CliRunner()
result = runner.invoke(cli.cli, cli_args) result = runner.invoke(cli.cli, cli_args)
assert result.exit_code == 0, str(result) assert result.exit_code == 0, str(result)
......
from unittest.mock import Mock, call, patch
from brian_polling_manager.interface_stats.services.netconf import (
JuniperNetconfProvider,
)
from brian_polling_manager.interface_stats.services.validation import (
InprovRouterListValidator,
)
from brian_polling_manager.interface_stats.services.writers import StreamPointWriter
from brian_polling_manager.interface_stats.vendors import Vendor
from ncclient.operations.rpc import RPCReply
import pytest
class TestJuniperNetconfProvider:
RAW_RESPONSE_FILE = "raw-response-sample.xml"
@pytest.fixture(autouse=True)
def mocked_rpc(self, data_dir):
raw_response = data_dir.joinpath(self.RAW_RESPONSE_FILE).read_text()
with patch(
"brian_polling_manager.interface_stats.services.netconf._rpc",
return_value=RPCReply(raw_response),
) as mock:
yield mock
def test_calls_rpc_with_params(self, mocked_rpc):
ssh_params = {}
router_name = "some-router"
provider = JuniperNetconfProvider(ssh_params)
provider.get(router_name)
call_args = mocked_rpc.call_args
assert call_args[0][0] == router_name
assert call_args[1]["ssh_params"] is ssh_params
assert call_args[1]["command"].tag == "get-interface-information"
def test_converts_rpc_response_to_xml(self):
ssh_params = {}
router_name = "some-router"
provider = JuniperNetconfProvider(ssh_params)
result = provider.get(router_name)
assert result.tag == "rpc-reply"
class TestInprovRouterListValidator:
@pytest.fixture
def validator(self):
validator = InprovRouterListValidator(
hostnames=object(), load_interfaces_=Mock()
)
validator.load_interfaces.return_value = [
{"router": "host1"},
{"router": "host2"},
{"router": "host3"},
]
return validator
def test_validate_valid_hosts(self, validator):
assert validator.validate_hosts(("host1", "host2"), vendor=Vendor.JUNIPER)
def test_validate_invalid_hosts(self, validator):
with pytest.raises(ValueError):
validator.validate_hosts(("host1", "invalid"), vendor=Vendor.JUNIPER)
class TestStreamPointWriter:
@pytest.fixture
def writer(self):
return StreamPointWriter(measurement="measurement", stream=Mock())
def test_writes_points_to_stream(self, writer):
writer.write_points([{"point": "one"}, {"point": "two"}])
assert writer.stream.write.call_args_list == [
call('measurement - {"point": "one"}\n'),
call('measurement - {"point": "two"}\n'),
]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment