From e406647496c4705bfd8ce2b1a759b66785ba5eaa Mon Sep 17 00:00:00 2001
From: Erik Reid <erik.reid@geant.org>
Date: Thu, 17 Feb 2022 10:50:47 +0100
Subject: [PATCH] catch some missed pyez errors

---
 inventory_provider/juniper.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/inventory_provider/juniper.py b/inventory_provider/juniper.py
index 19761764..3f1f94a3 100644
--- a/inventory_provider/juniper.py
+++ b/inventory_provider/juniper.py
@@ -1,3 +1,4 @@
+import contextlib
 import logging
 import re
 import ipaddress
@@ -114,6 +115,7 @@ class NetconfHandlingError(Exception):
     pass
 
 
+@contextlib.contextmanager
 def _rpc(hostname, ssh):
     dev = Device(
         host=hostname,
@@ -121,9 +123,9 @@ def _rpc(hostname, ssh):
         ssh_private_key_file=ssh['private-key'])
     try:
         dev.open()
-    except (EzErrors.ConnectError, EzErrors.RpcError) as e:
-        raise ConnectionError(str(e))
-    return dev.rpc
+        yield dev.rpc
+    finally:
+        dev.close()
 
 
 def validate_netconf_config(config_doc):
@@ -166,14 +168,18 @@ def load_config(hostname, ssh_params, validate=True):
     :param ssh_params: 'ssh' config element(cf. config.py:CONFIG_SCHEMA)
     :param validate: whether or not to validate netconf data (default True)
     :return:
-    :raises: NetconfHandlingError from validate_netconf_config
+    :raises: NetconfHandlingError or ConnectionError
     """
     logger = logging.getLogger(__name__)
     logger.info("capturing netconf data for '%s'" % hostname)
-    config = _rpc(hostname, ssh_params).get_config()
-    if validate:
-        validate_netconf_config(config)
-    return config
+    try:
+        with _rpc(hostname, ssh_params) as router:
+            config = router.get_config()
+        if validate:
+            validate_netconf_config(config)
+        return config
+    except (EzErrors.ConnectError, EzErrors.RpcError) as e:
+        raise ConnectionError(str(e))
 
 
 def list_interfaces(netconf_config):
-- 
GitLab