diff --git a/inventory_provider/juniper.py b/inventory_provider/juniper.py index 1976176493c60059a3b0a969039da2d102f69117..3f1f94a3860a18dfd7ed0aba68b69d1718421ac5 100644 --- a/inventory_provider/juniper.py +++ b/inventory_provider/juniper.py @@ -1,3 +1,4 @@ +import contextlib import logging import re import ipaddress @@ -114,6 +115,7 @@ class NetconfHandlingError(Exception): pass +@contextlib.contextmanager def _rpc(hostname, ssh): dev = Device( host=hostname, @@ -121,9 +123,9 @@ def _rpc(hostname, ssh): ssh_private_key_file=ssh['private-key']) try: dev.open() - except (EzErrors.ConnectError, EzErrors.RpcError) as e: - raise ConnectionError(str(e)) - return dev.rpc + yield dev.rpc + finally: + dev.close() def validate_netconf_config(config_doc): @@ -166,14 +168,18 @@ def load_config(hostname, ssh_params, validate=True): :param ssh_params: 'ssh' config element(cf. config.py:CONFIG_SCHEMA) :param validate: whether or not to validate netconf data (default True) :return: - :raises: NetconfHandlingError from validate_netconf_config + :raises: NetconfHandlingError or ConnectionError """ logger = logging.getLogger(__name__) logger.info("capturing netconf data for '%s'" % hostname) - config = _rpc(hostname, ssh_params).get_config() - if validate: - validate_netconf_config(config) - return config + try: + with _rpc(hostname, ssh_params) as router: + config = router.get_config() + if validate: + validate_netconf_config(config) + return config + except (EzErrors.ConnectError, EzErrors.RpcError) as e: + raise ConnectionError(str(e)) def list_interfaces(netconf_config):