From 99254b27c10dab8c28df699bb636e64401c1acff Mon Sep 17 00:00:00 2001
From: Erik Reid <erik.reid@geant.org>
Date: Sun, 14 Jul 2019 14:26:20 +0200
Subject: [PATCH] use get_next/current_redis

---
 inventory_provider/routes/classifier.py | 16 +++----
 inventory_provider/routes/data.py       |  4 +-
 inventory_provider/routes/poller.py     |  2 +-
 inventory_provider/routes/testing.py    | 22 ++++-----
 inventory_provider/tasks/worker.py      | 60 +++++++++----------------
 5 files changed, 42 insertions(+), 62 deletions(-)

diff --git a/inventory_provider/routes/classifier.py b/inventory_provider/routes/classifier.py
index fb09dad0..ccd0ac12 100644
--- a/inventory_provider/routes/classifier.py
+++ b/inventory_provider/routes/classifier.py
@@ -50,7 +50,7 @@ def base_interface_name(interface):
 
 
 def related_interfaces(hostname, interface):
-    r = common.get_redis()
+    r = common.get_current_redis()
     prefix = 'netconf-interfaces:%s:' % hostname
     for k in r.keys(prefix + base_interface_name(interface) + '.*'):
         k = k.decode('utf-8')
@@ -63,7 +63,7 @@ def related_interfaces(hostname, interface):
               methods=['GET', 'POST'])
 @common.require_accepts_json
 def get_juniper_link_info(source_equipment, interface):
-    r = common.get_redis()
+    r = common.get_current_redis()
 
     cache_key = 'classifier-cache:juniper:%s:%s' % (
         source_equipment, interface)
@@ -137,7 +137,7 @@ def ix_peering_info(peer_info):
     protocol = type(address).__name__
     keyword = description.split(' ')[0]  # regex needed??? (e.g. tabs???)
 
-    r = common.get_redis()
+    r = common.get_current_redis()
 
     for k in r.keys('ix_public_peer:*'):
         other = r.get(k.decode('utf-8')).decode('utf-8')
@@ -165,7 +165,7 @@ def find_interfaces(address):
     :param address: an ipaddress object
     :return:
     """
-    r = common.get_redis()
+    r = common.get_current_redis()
     for k in r.keys('reverse_interface_addresses:*'):
         info = r.get(k.decode('utf-8')).decode('utf-8')
         info = json.loads(info)
@@ -187,7 +187,7 @@ def find_interfaces_and_services(address_str):
         raise ClassifierProcessingError(
             'unable to parse %r as an ip address' % address_str)
 
-    r = common.get_redis()
+    r = common.get_current_redis()
     for interface in find_interfaces(address):
 
         services = r.get(
@@ -209,7 +209,7 @@ def find_interfaces_and_services(address_str):
 @common.require_accepts_json
 def peer_info(address):
 
-    r = common.get_redis()
+    r = common.get_current_redis()
 
     cache_key = 'classifier-cache:peer:%s' % address
 
@@ -257,7 +257,7 @@ def get_trap_metadata(source_equipment, interface, circuit_id):
     cache_key = 'classifier-cache:infinera:%s:%s' % (
         source_equipment, interface)
 
-    r = common.get_redis()
+    r = common.get_current_redis()
     result = r.get(cache_key)
 
     if result:
@@ -294,7 +294,7 @@ def get_trap_metadata(source_equipment, interface, circuit_id):
               methods=['GET', 'POST'])
 @common.require_accepts_json
 def get_coriant_info(equipment_name, entity_string):
-    r = common.get_redis()
+    r = common.get_current_redis()
 
     cache_key = 'classifier-cache:coriant:%s:%s' % (
         equipment_name, entity_string)
diff --git a/inventory_provider/routes/data.py b/inventory_provider/routes/data.py
index 905b2cde..f5f780ce 100644
--- a/inventory_provider/routes/data.py
+++ b/inventory_provider/routes/data.py
@@ -18,7 +18,7 @@ def after_request(resp):
 @routes.route("/routers", methods=['GET', 'POST'])
 @common.require_accepts_json
 def routers():
-    r = common.get_redis()
+    r = common.get_current_redis()
     result = []
     for k in r.keys('netconf:*'):
         m = re.match('^netconf:(.+)$', k.decode('utf-8'))
@@ -30,7 +30,7 @@ def routers():
 @routes.route("/interfaces/<hostname>", methods=['GET', 'POST'])
 @common.require_accepts_json
 def router_interfaces(hostname):
-    r = common.get_redis()
+    r = common.get_current_redis()
     interfaces = []
     for k in r.keys('netconf-interfaces:%s:*' % hostname):
         ifc = r.get(k.decode('utf-8'))
diff --git a/inventory_provider/routes/poller.py b/inventory_provider/routes/poller.py
index a3f890e5..8497b2cc 100644
--- a/inventory_provider/routes/poller.py
+++ b/inventory_provider/routes/poller.py
@@ -16,7 +16,7 @@ def after_request(resp):
 @routes.route('/interfaces/<hostname>', methods=['GET', 'POST'])
 @common.require_accepts_json
 def poller_interface_oids(hostname):
-    r = common.get_redis()
+    r = common.get_current_redis()
 
     netconf_string = r.get('netconf:' + hostname)
     if not netconf_string:
diff --git a/inventory_provider/routes/testing.py b/inventory_provider/routes/testing.py
index 2ee9da12..66d38912 100644
--- a/inventory_provider/routes/testing.py
+++ b/inventory_provider/routes/testing.py
@@ -14,7 +14,7 @@ routes = Blueprint("inventory-data-testing-support-routes", __name__)
 
 @routes.route("flushdb", methods=['GET', 'POST'])
 def flushdb():
-    common.get_redis().flushdb()
+    common.get_current_redis().flushdb()
     return Response('OK')
 
 
@@ -46,7 +46,7 @@ def update_interface_statuses():
 @common.require_accepts_json
 def juniper_addresses():
     # TODO: this route (and corant, infinera routes) can be removed
-    r = common.get_redis()
+    r = common.get_current_redis()
     routers = []
     for k in r.keys('junosspace:*'):
         info = r.get(k.decode('utf-8'))
@@ -58,7 +58,7 @@ def juniper_addresses():
 
 @routes.route("opsdb/interfaces")
 def get_all_interface_details():
-    r = common.get_redis()
+    r = common.get_current_redis()
     result = collections.defaultdict(list)
     for k in r.keys('opsdb:interface_services:*'):
         m = re.match(
@@ -71,7 +71,7 @@ def get_all_interface_details():
 
 @routes.route("opsdb/interfaces/<equipment_name>")
 def get_interface_details_for_equipment(equipment_name):
-    r = common.get_redis()
+    r = common.get_current_redis()
     result = []
     for k in r.keys('opsdb:interface_services:%s:*' % equipment_name):
         m = re.match(
@@ -84,7 +84,7 @@ def get_interface_details_for_equipment(equipment_name):
 
 @routes.route("opsdb/interfaces/<equipment_name>/<path:interface>")
 def get_interface_details(equipment_name, interface):
-    r = common.get_redis()
+    r = common.get_current_redis()
     key = 'opsdb:interface_services:%s:%s' % (equipment_name, interface)
     # TODO: handle None (return 404)
     return jsonify(json.loads(r.get(key).decode('utf-8')))
@@ -92,7 +92,7 @@ def get_interface_details(equipment_name, interface):
 
 @routes.route("opsdb/equipment-location")
 def get_all_equipment_locations():
-    r = common.get_redis()
+    r = common.get_current_redis()
     result = {}
     for k in r.keys('opsdb:location:*'):
         k = k.decode('utf-8')
@@ -104,7 +104,7 @@ def get_all_equipment_locations():
 
 @routes.route("opsdb/equipment-location/<path:equipment_name>")
 def get_equipment_location(equipment_name):
-    r = common.get_redis()
+    r = common.get_current_redis()
     result = r.get('opsdb:location:' + equipment_name)
     # TODO: handle None (return 404)
     return jsonify(json.loads(result.decode('utf-8')))
@@ -112,7 +112,7 @@ def get_equipment_location(equipment_name):
 
 @routes.route("opsdb/circuit-hierarchy/children/<int:parent_id>")
 def get_children(parent_id):
-    r = common.get_redis()
+    r = common.get_current_redis()
     result = r.get('opsdb:services:children:%d' % parent_id)
     # TODO: handle None (return 404)
     return jsonify(json.loads(result.decode('utf-8')))
@@ -120,7 +120,7 @@ def get_children(parent_id):
 
 @routes.route("opsdb/circuit-hierarchy/parents/<int:child_id>")
 def get_parents(child_id):
-    r = common.get_redis()
+    r = common.get_current_redis()
     result = r.get('opsdb:services:parents:%d' % child_id)
     # TODO: handle None (return 404)
     return jsonify(json.loads(result.decode('utf-8')))
@@ -129,7 +129,7 @@ def get_parents(child_id):
 @routes.route("bgp/<hostname>", methods=['GET', 'POST'])
 @common.require_accepts_json
 def bgp_configs(hostname):
-    r = common.get_redis()
+    r = common.get_current_redis()
     netconf_string = r.get('netconf:' + hostname)
     if not netconf_string:
         return Response(
@@ -151,7 +151,7 @@ def bgp_configs(hostname):
 @routes.route("snmp/<hostname>", methods=['GET', 'POST'])
 @common.require_accepts_json
 def snmp_ids(hostname):
-    r = common.get_redis()
+    r = common.get_next_redis()
     ifc_data_string = r.get('snmp-interfaces:' + hostname)
     ifc_data = json.loads(ifc_data_string.decode('utf-8'))
     return jsonify(ifc_data)
diff --git a/inventory_provider/tasks/worker.py b/inventory_provider/tasks/worker.py
index ccf9edb7..26d44863 100644
--- a/inventory_provider/tasks/worker.py
+++ b/inventory_provider/tasks/worker.py
@@ -10,7 +10,7 @@ from collections import defaultdict
 from lxml import etree
 
 from inventory_provider.tasks.app import app
-from inventory_provider.tasks.common import get_redis
+from inventory_provider.tasks.common import get_next_redis
 from inventory_provider import config
 from inventory_provider import environment
 from inventory_provider.db import db, opsdb
@@ -55,38 +55,16 @@ class InventoryTask(Task):
         super().update_state(**kwargs)
 
 
-def _save_value(key, value):
-    assert isinstance(value, str), \
-        "sanity failure: expected string data as value"
-    r = get_redis(InventoryTask.config)
-    r.set(name=key, value=value)
-    # InventoryTask.logger.debug("saved %s" % key)
-    return "OK"
-
-
-def _save_value_json(key, data_obj):
-    _save_value(
-        key,
-        json.dumps(data_obj))
-
-
-def _save_value_etree(key, xml_doc):
-    _save_value(
-        key,
-        etree.tostring(xml_doc, encoding='unicode'))
-
-
 @app.task
 def snmp_refresh_interfaces(hostname, community):
     logger = logging.getLogger(__name__)
     logger.debug(
         '>>> snmp_refresh_interfaces(%r, %r)' % (hostname, community))
 
-    _save_value_json(
-        'snmp-interfaces:' + hostname,
-        list(snmp.get_router_snmp_indexes(
-            hostname,
-            community)))
+    value = list(snmp.get_router_snmp_indexes(hostname, community))
+
+    r = get_next_redis(InventoryTask.config)
+    r.set('snmp-interfaces:' + hostname, json.dumps(value))
 
     logger.debug(
         '<<< snmp_refresh_interfaces(%r, %r)' % (hostname, community))
@@ -97,9 +75,11 @@ def netconf_refresh_config(hostname):
     logger = logging.getLogger(__name__)
     logger.debug('>>> netconf_refresh_config(%r)' % hostname)
 
-    _save_value_etree(
-        'netconf:' + hostname,
-        juniper.load_config(hostname, InventoryTask.config["ssh"]))
+    netconf_doc = juniper.load_config(hostname, InventoryTask.config["ssh"])
+    netconf_str = etree.tostring(netconf_doc, encoding='unicode')
+
+    r = get_next_redis(InventoryTask.config)
+    r.set('netconf:' + hostname, netconf_str)
 
     logger.debug('<<< netconf_refresh_config(%r)' % hostname)
 
@@ -116,7 +96,7 @@ def update_interfaces_to_services():
                 service['equipment'], service['interface_name'])
             interface_services[equipment_interface].append(service)
 
-    r = get_redis(InventoryTask.config)
+    r = get_next_redis(InventoryTask.config)
     for key in r.scan_iter('opsdb:interface_services:*'):
         r.delete(key)
     for equipment_interface, services in interface_services.items():
@@ -132,7 +112,7 @@ def update_equipment_locations():
     logger = logging.getLogger(__name__)
     logger.debug('>>> update_equipment_locations')
 
-    r = get_redis(InventoryTask.config)
+    r = get_next_redis(InventoryTask.config)
     for k in r.scan_iter('opsdb:location:*'):
         r.delete(k)
 
@@ -161,7 +141,7 @@ def update_circuit_hierarchy():
             parent_to_children[parent_id].append(relation)
             child_to_parents[child_id].append(relation)
 
-        r = get_redis(InventoryTask.config)
+        r = get_next_redis(InventoryTask.config)
         for key in r.scan_iter('opsdb:services:parents:*'):
             r.delete(key)
         for cid, parents in child_to_parents.items():
@@ -180,7 +160,7 @@ def update_geant_lambdas():
     logger = logging.getLogger(__name__)
     logger.debug('>>> update_geant_lambdas')
 
-    r = get_redis(InventoryTask.config)
+    r = get_next_redis(InventoryTask.config)
     for key in r.scan_iter('opsdb:geant_lambdas:*'):
         r.delete(key)
     with db.connection(InventoryTask.config["ops-db"]) as cx:
@@ -203,7 +183,7 @@ def update_junosspace_device_list(self):
             'message': 'querying junosspace for managed routers'
         })
 
-    r = get_redis(InventoryTask.config)
+    r = get_next_redis(InventoryTask.config)
 
     routers = {}
     for d in juniper.load_routers_from_junosspace(
@@ -237,7 +217,7 @@ def load_netconf_data(hostname):
     :param hostname:
     :return:
     """
-    r = get_redis(InventoryTask.config)
+    r = get_next_redis(InventoryTask.config)
     netconf = r.get('netconf:' + hostname)
     if not netconf:
         raise InventoryTaskError('no netconf data found for %r' % hostname)
@@ -252,7 +232,7 @@ def clear_cached_classifier_responses(hostname=None):
     else:
         logger.debug('removing all cached classifier responses')
 
-    r = get_redis(InventoryTask.config)
+    r = get_next_redis(InventoryTask.config)
 
     def _hostname_keys():
         for k in r.keys('classifier-cache:juniper:%s:*' % hostname):
@@ -281,7 +261,7 @@ def _refresh_peers(hostname, key_base, peers):
     logger = logging.getLogger(__name__)
     logger.debug(
         'removing cached %s for %r' % (key_base, hostname))
-    r = get_redis(InventoryTask.config)
+    r = get_next_redis(InventoryTask.config)
     for k in r.keys(key_base + ':*'):
         # potential race condition: another proc could have
         # delete this element between the time we read the
@@ -325,7 +305,7 @@ def refresh_juniper_interface_list(hostname, netconf):
     logger.debug(
         'removing cached netconf-interfaces for %r' % hostname)
 
-    r = get_redis(InventoryTask.config)
+    r = get_next_redis(InventoryTask.config)
     for k in r.keys('netconf-interfaces:%s:*' % hostname):
         r.delete(k)
 
@@ -433,7 +413,7 @@ def reload_router_config(self, hostname):
 
 def _derive_router_hostnames(config):
     logger = logging.getLogger(__name__)
-    r = get_redis(config)
+    r = get_next_redis(config)
     junosspace_equipment = set()
     for k in r.keys('junosspace:*'):
         m = re.match('^junosspace:(.*)$', k.decode('utf-8'))
-- 
GitLab