From 6eebf4c30abe62a2ca769f2e7b600e8af6b781b6 Mon Sep 17 00:00:00 2001
From: Samuel Roberts <sam.roberts@geant.org>
Date: Tue, 24 Oct 2023 15:27:56 +0100
Subject: [PATCH] fix bugs, make the update chords work as intended

---
 inventory_provider/juniper.py      | 14 +++++++++++---
 inventory_provider/tasks/worker.py | 22 +++++++++++++---------
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/inventory_provider/juniper.py b/inventory_provider/juniper.py
index 09da2a06..935d293e 100644
--- a/inventory_provider/juniper.py
+++ b/inventory_provider/juniper.py
@@ -11,6 +11,7 @@ from jnpr.junos import exception as EzErrors
 from lxml import etree
 import netifaces
 import requests
+from ncclient.xml_ import to_xml
 
 CONFIG_SCHEMA = """<?xml version="1.1" encoding="UTF-8" ?>
 <xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
@@ -119,6 +120,10 @@ class NetconfHandlingError(Exception):
     pass
 
 
+class TimeoutError(Exception):
+    pass
+
+
 TIMEOUT = 10.0
 
 
@@ -135,8 +140,10 @@ def _nc_connection(host_params, ssh_params):
 
     try:
         yield conn  # wait here until caller context ends
+    except EzErrors.ConnectTimeoutError:
+        raise TimeoutError
     finally:
-        conn.close()
+        conn.close_session()
 
 
 def _raw_rpc(router, command):
@@ -144,7 +151,8 @@ def _raw_rpc(router, command):
     # query for router configs
     # this is needed for querying for other things eg. interface speeds
     obj = router.rpc(command)
-    return obj.reply
+    xml = obj.tostring
+    return xml
 
 
 @contextlib.contextmanager
@@ -487,7 +495,7 @@ def get_interface_info_for_router(hostname, ssh_config):
             host_params=host_params,
             ssh_params=ssh_config) as router:
         reply = _raw_rpc(router, etree.Element('get-interface-information'))
-        return reply.xml
+        return reply
 
 
 def get_interface_speeds(interface_info):
diff --git a/inventory_provider/tasks/worker.py b/inventory_provider/tasks/worker.py
index 91c83f14..b12548c5 100644
--- a/inventory_provider/tasks/worker.py
+++ b/inventory_provider/tasks/worker.py
@@ -451,16 +451,18 @@ def update_entry_point(self):
         chord(
             (
                 ims_task.s().on_error(task_error_handler.s()),
-                chord(
-                    (reload_router_config_chorded.s(r) for r in routers),
-                    empty_task.si('router tasks complete')
+                chord((
+                    chord(
+                        (reload_router_config_chorded.s(r) for r in routers),
+                        empty_task.si('router tasks complete')
+                    ),
+                    chord(
+                        (reload_lab_router_config_chorded.s(r)
+                         for r in lab_routers),
+                        empty_task.si('lab router tasks complete')
+                    )),
+                    collate_netconf_interfaces_all_cache.si().on_error(task_error_handler.s())
                 ),
-                chord(
-                    (reload_lab_router_config_chorded.s(r)
-                     for r in lab_routers),
-                    empty_task.si('lab router tasks complete')
-                ),
-                collate_netconf_interfaces_all_cache.s().on_error(task_error_handler.s())
             ),
             final_task.si().on_error(task_error_handler.s())
         )()
@@ -1462,6 +1464,8 @@ def collate_netconf_interfaces_all_cache(warning_callback=lambda s: None):
         for k in r.scan_iter(key_pattern, count=1000):
             key = k.decode('utf-8')
             doc_str = r.get(key).decode('utf-8')
+            doc = json.loads(doc_str)
+            doc['hostname'] = k.split(':')[1]  # get hostname part of key
             yield json.loads(doc_str)
 
     netconf_all_key = 'netconf-interfaces:all'
-- 
GitLab