From cf4c8d3d0ab21d47c24b23ef02694facd2a69eb4 Mon Sep 17 00:00:00 2001
From: Erik Reid <erik.reid@geant.org>
Date: Wed, 14 Nov 2018 12:39:18 +0100
Subject: [PATCH] mock redis for testing

---
 inventory_provider/__init__.py          |   2 +-
 inventory_provider/data_routes.py       |  15 +++-
 inventory_provider/router_details.py    | 113 ++++++++++++++++++++++++
 inventory_provider/router_interfaces.py | 107 +---------------------
 requirements.txt                        |   1 +
 test/test_data_routes.py                |  32 ++++++-
 6 files changed, 162 insertions(+), 108 deletions(-)
 create mode 100644 inventory_provider/router_details.py

diff --git a/inventory_provider/__init__.py b/inventory_provider/__init__.py
index 2918d68e..44220ddd 100644
--- a/inventory_provider/__init__.py
+++ b/inventory_provider/__init__.py
@@ -36,7 +36,7 @@ def create_app():
     from inventory_provider import config
     with open(app.config["INVENTORY_PROVIDER_CONFIG_FILENAME"]) as f:
         # test the config file can be loaded
-        config.load(f)
+        app.config["INVENTORY_PROVIDER_CONFIG"] = config.load(f)
 
     logging.debug(app.config)
 
diff --git a/inventory_provider/data_routes.py b/inventory_provider/data_routes.py
index 3c3238ed..4496c185 100644
--- a/inventory_provider/data_routes.py
+++ b/inventory_provider/data_routes.py
@@ -1,8 +1,9 @@
 import functools
 import json
 
-from flask import Blueprint, request, Response
+from flask import Blueprint, request, Response, current_app
 # render_template, url_for
+import redis
 
 routes = Blueprint("python-utils-ui-routes", __name__)
 
@@ -38,3 +39,15 @@ def version():
         json.dumps(VERSION),
         mimetype="application/json"
     )
+
+
+@routes.route("/abc", methods=['GET', 'POST'])
+@require_accepts_json
+def abc():
+    redis_config = current_app.config["INVENTORY_PROVIDER_CONFIG"]["redis"]
+    r = redis.StrictRedis(
+        host=redis_config["hostname"],
+        port=redis_config["port"])
+    return Response(
+        json.dumps(r.keys("*")),
+        mimetype="application/json")
diff --git a/inventory_provider/router_details.py b/inventory_provider/router_details.py
new file mode 100644
index 00000000..fe38136e
--- /dev/null
+++ b/inventory_provider/router_details.py
@@ -0,0 +1,113 @@
+import json
+import logging
+from multiprocessing import Process, Queue
+
+import redis
+
+from inventory_provider import constants
+from inventory_provider import snmp
+from inventory_provider import juniper
+
+
+def get_router_interfaces_q(router, params, q):
+    threading_logger = logging.getLogger(constants.THREADING_LOGGER_NAME)
+    threading_logger.debug("[ENTER>>] get_router_interfaces_q: %r" % router)
+    q.put(list(snmp.get_router_interfaces(router, params)))
+    threading_logger.debug("[<<EXIT]  get_router_interfaces_q: %r" % router)
+
+
+def ssh_exec_commands_q(hostname, ssh_params, commands, q):
+    threading_logger = logging.getLogger(constants.THREADING_LOGGER_NAME)
+    threading_logger.debug("[ENTER>>] exec_router_commands_q: %r" % hostname)
+    q.put(list(juniper.ssh_exec_commands(hostname, ssh_params, commands)))
+    threading_logger.debug("[<<EXIT] exec_router_commands_q: %r" % hostname)
+
+
+def get_router_details(router, params):
+
+    threading_logger = logging.getLogger(constants.THREADING_LOGGER_NAME)
+
+    threading_logger.debug("[ENTER>>]get_router_details: %r" % router)
+
+    commands = list(juniper.shell_commands())
+
+    snmpifc_proc_queue = Queue()
+    snmpifc_proc = Process(
+        target=get_router_interfaces_q,
+        args=(router, params, snmpifc_proc_queue))
+    snmpifc_proc.start()
+
+    commands_proc_queue = Queue()
+    commands_proc = Process(
+        target=ssh_exec_commands_q,
+        args=(
+            router["hostname"],
+            params["ssh"],
+            [c["command"] for c in commands],
+            commands_proc_queue))
+    commands_proc.start()
+
+    threading_logger.debug("waiting for commands result: %r" % router)
+    command_output = commands_proc_queue.get()
+    assert len(command_output) == len(commands)
+
+    r = redis.StrictRedis(
+        host=params["redis"]["hostname"],
+        port=params["redis"]["port"])
+    for c, o in zip(commands, command_output):
+        if c["key"]:
+            r.hset(
+                name=router["hostname"],
+                key=c["key"],
+                value=json.dumps(c["parser"](o)))
+    commands_proc.join()
+    threading_logger.debug("... got commands result & joined: %r" % router)
+
+    threading_logger.debug("waiting for snmp ifc results: %r" % router)
+    r.hset(
+        name=router["hostname"],
+        key="snmp-interfaces",
+        value=json.dumps(snmpifc_proc_queue.get()))
+    snmpifc_proc.join()
+    threading_logger.debug("... got snmp ifc result & joined: %r" % router)
+
+    threading_logger.debug("[<<EXIT]get_router_details: %r" % router)
+
+
+def update_network_details(params):
+
+    threading_logger = logging.getLogger(constants.THREADING_LOGGER_NAME)
+
+    processes = []
+    for r in params["routers"]:
+        p = Process(target=get_router_details, args=(r, params))
+        p.start()
+        processes.append({"router": r, "process": p})
+
+    result = {}
+    for p in processes:
+        threading_logger.debug(
+            "waiting for get_router_details result: %r" % p["router"])
+        p["process"].join()
+        threading_logger.debug(
+            "got result and joined get_router_details proc: %r" % p["router"])
+
+    return result
+
+
+def load_network_details(redis_params):
+
+    r = redis.StrictRedis(
+        host=redis_params["hostname"],
+        port=redis_params["port"])
+
+    result = {}
+    for hostname in r.keys():
+        host = {}
+        for key in r.hkeys(name=hostname):
+            host[key.decode("utf-8")] = json.loads(
+                r.hget(hostname, key).decode("utf-8")
+            )
+        result[hostname.decode("utf-8")] = host
+
+    return result
diff --git a/inventory_provider/router_interfaces.py b/inventory_provider/router_interfaces.py
index 95a09be5..431ed641 100644
--- a/inventory_provider/router_interfaces.py
+++ b/inventory_provider/router_interfaces.py
@@ -1,102 +1,13 @@
 import json
 import logging
-from multiprocessing import Process, Queue
 
 import click
-import redis
 
 from inventory_provider import constants
-from inventory_provider import snmp
-from inventory_provider import juniper
+from inventory_provider import router_details
 from inventory_provider import config
 
 
-def get_router_interfaces_q(router, params, q):
-    threading_logger = logging.getLogger(constants.THREADING_LOGGER_NAME)
-    threading_logger.debug("[ENTER>>] get_router_interfaces_q: %r" % router)
-    q.put(list(snmp.get_router_interfaces(router, params)))
-    threading_logger.debug("[<<EXIT]  get_router_interfaces_q: %r" % router)
-
-
-def ssh_exec_commands_q(hostname, ssh_params, commands, q):
-    threading_logger = logging.getLogger(constants.THREADING_LOGGER_NAME)
-    threading_logger.debug("[ENTER>>] exec_router_commands_q: %r" % hostname)
-    q.put(list(juniper.ssh_exec_commands(hostname, ssh_params, commands)))
-    threading_logger.debug("[<<EXIT] exec_router_commands_q: %r" % hostname)
-
-
-def get_router_details(router, params):
-
-    threading_logger = logging.getLogger(constants.THREADING_LOGGER_NAME)
-
-    threading_logger.debug("[ENTER>>]get_router_details: %r" % router)
-
-    commands = list(juniper.shell_commands())
-
-    snmpifc_proc_queue = Queue()
-    snmpifc_proc = Process(
-        target=get_router_interfaces_q,
-        args=(router, params, snmpifc_proc_queue))
-    snmpifc_proc.start()
-
-    commands_proc_queue = Queue()
-    commands_proc = Process(
-        target=ssh_exec_commands_q,
-        args=(
-            router["hostname"],
-            params["ssh"],
-            [c["command"] for c in commands],
-            commands_proc_queue))
-    commands_proc.start()
-
-    threading_logger.debug("waiting for commands result: %r" % router)
-    command_output = commands_proc_queue.get()
-    assert len(command_output) == len(commands)
-
-    r = redis.StrictRedis(
-        host=params["redis"]["hostname"],
-        port=params["redis"]["port"])
-    for c, o in zip(commands, command_output):
-        if c["key"]:
-            r.hset(
-                name=router["hostname"],
-                key=c["key"],
-                value=json.dumps(c["parser"](o)))
-    commands_proc.join()
-    threading_logger.debug("... got commands result & joined: %r" % router)
-
-    threading_logger.debug("waiting for snmp ifc results: %r" % router)
-    r.hset(
-        name=router["hostname"],
-        key="snmp-interfaces",
-        value=json.dumps(snmpifc_proc_queue.get()))
-    snmpifc_proc.join()
-    threading_logger.debug("... got snmp ifc result & joined: %r" % router)
-
-    threading_logger.debug("[<<EXIT]get_router_details: %r" % router)
-
-
-def load_network_details(params):
-
-    threading_logger = logging.getLogger(constants.THREADING_LOGGER_NAME)
-
-    processes = []
-    for r in params["routers"]:
-        p = Process(target=get_router_details, args=(r, params))
-        p.start()
-        processes.append({"router": r, "process": p})
-
-    result = {}
-    for p in processes:
-        threading_logger.debug(
-            "waiting for get_router_details result: %r" % p["router"])
-        p["process"].join()
-        threading_logger.debug(
-            "got result and joined get_router_details proc: %r" % p["router"])
-
-    return result
-
-
 def _validate_config(ctx, param, value):
     return config.load(value)
 
@@ -110,20 +21,8 @@ def _validate_config(ctx, param, value):
     default=open("config.json"),
     callback=_validate_config)
 def cli(params):
-    load_network_details(params)
-
-    r = redis.StrictRedis(
-        host=params["redis"]["hostname"],
-        port=params["redis"]["port"])
-
-    result = {}
-    for hostname in r.keys():
-        host = {}
-        for key in r.hkeys(name=hostname):
-            host[key.decode("utf-8")] = json.loads(
-                r.hget(hostname, key).decode("utf-8")
-            )
-        result[hostname.decode("utf-8")] = host
+    router_details.update_network_details(params)
+    result = router_details.load_network_details(params["redis"])
 
     filename = "/tmp/router-info.json"
     logging.debug("writing output to: " + filename)
diff --git a/requirements.txt b/requirements.txt
index 9d2440c7..5ae83dfd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ flask
 redis
 
 pytest
+pytest-mock
diff --git a/test/test_data_routes.py b/test/test_data_routes.py
index 811adb0d..9cbaae59 100644
--- a/test/test_data_routes.py
+++ b/test/test_data_routes.py
@@ -135,8 +135,6 @@ def app_config():
 @pytest.fixture
 def client(app_config):
     os.environ["SETTINGS_FILENAME"] = app_config
-    # with release_webservice.create_app().test_client() as c:
-    #         yield c
     with inventory_provider.create_app().test_client() as c:
         yield c
 
@@ -166,3 +164,33 @@ def test_version_request(client):
     jsonschema.validate(
         json.loads(rv.data.decode("utf-8")),
         version_schema)
+
+
+class MockedRedis(object):
+
+    db = {}
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def set(self, key, value):
+        MockedRedis.db[key] = value
+
+    def keys(self, *args, **kwargs):
+        return ["a", "b", "c"]
+
+
+def test_abc(mocker, client):
+    mocker.patch(
+        'inventory_provider.router_details.redis.StrictRedis',
+        MockedRedis)
+    mocker.patch(
+        'inventory_provider.data_routes.redis.StrictRedis',
+        MockedRedis)
+    rv = client.post(
+        "data/abc",
+        headers=DEFAULT_REQUEST_HEADERS)
+    assert rv.status_code == 200
+
+    rsp = rv.data.decode("utf-8")
+    print(rsp)
-- 
GitLab