Skip to content
Snippets Groups Projects
Commit dca90233 authored by Erik Reid's avatar Erik Reid
Browse files

moved get_redis from db.py to routes/common.py

parent 4f6e0c17
Branches
Tags
No related merge requests found
import contextlib import contextlib
import mysql.connector import mysql.connector
import redis
from flask import current_app, g
def get_redis(): # pragma: no cover
if 'redis_db' not in g:
config = current_app.config['INVENTORY_PROVIDER_CONFIG']
g.redis_db = redis.Redis(
host=config['redis']['hostname'],
port=config['redis']['port'])
return g.redis_db
@contextlib.contextmanager @contextlib.contextmanager
......
...@@ -4,7 +4,7 @@ from flask import Blueprint, request, Response, current_app, jsonify ...@@ -4,7 +4,7 @@ from flask import Blueprint, request, Response, current_app, jsonify
import json import json
import jsonschema import jsonschema
from inventory_provider import db from inventory_provider.routes import common
routes = Blueprint("inventory-data-classifier-support-routes", __name__) routes = Blueprint("inventory-data-classifier-support-routes", __name__)
...@@ -59,7 +59,7 @@ def juniper_addresses(): ...@@ -59,7 +59,7 @@ def juniper_addresses():
} }
} }
servers = db.get_redis().get('alarmsdb:juniper_servers') servers = common.get_redis().get('alarmsdb:juniper_servers')
if not servers: if not servers:
return Response( return Response(
response="no juniper server data found", response="no juniper server data found",
...@@ -77,7 +77,7 @@ def juniper_addresses(): ...@@ -77,7 +77,7 @@ def juniper_addresses():
def get_trap_metadata(trap_type, source_equipment, interface): def get_trap_metadata(trap_type, source_equipment, interface):
# todo - Move this to config # todo - Move this to config
interface_info_key = "interface_services" interface_info_key = "interface_services"
r = db.get_redis() r = common.get_redis()
# todo - Change this to a call to the yet-to-be-created one source of all # todo - Change this to a call to the yet-to-be-created one source of all
# relevant information # relevant information
......
import functools import functools
from flask import request, Response from flask import request, Response, current_app, g
import redis
def get_redis():
if 'redis_db' not in g:
config = current_app.config['INVENTORY_PROVIDER_CONFIG']
g.redis_db = redis.StrictRedis(
host=config['redis']['hostname'],
port=config['redis']['port'])
return g.redis_db
def require_accepts_json(f): def require_accepts_json(f):
......
...@@ -6,7 +6,8 @@ from flask import Blueprint, jsonify, request, Response, current_app ...@@ -6,7 +6,8 @@ from flask import Blueprint, jsonify, request, Response, current_app
from lxml import etree from lxml import etree
import redis import redis
from inventory_provider import db, juniper from inventory_provider import juniper
from inventory_provider.routes import common
routes = Blueprint("inventory-data-query-routes", __name__) routes = Blueprint("inventory-data-query-routes", __name__)
...@@ -168,7 +169,7 @@ def bgp_configs(hostname): ...@@ -168,7 +169,7 @@ def bgp_configs(hostname):
methods=['GET', 'POST']) methods=['GET', 'POST'])
@require_accepts_json @require_accepts_json
def interface_statuses(hostname, interface): def interface_statuses(hostname, interface):
r = db.get_redis() r = common.get_redis()
result = r.hget("interface_statuses", result = r.hget("interface_statuses",
"{}::{}".format(hostname, interface)) "{}::{}".format(hostname, interface))
if not result: if not result:
...@@ -182,7 +183,7 @@ def interface_statuses(hostname, interface): ...@@ -182,7 +183,7 @@ def interface_statuses(hostname, interface):
@routes.route("/services/<hostname>/<path:interface>", @routes.route("/services/<hostname>/<path:interface>",
methods=['GET', 'POST']) methods=['GET', 'POST'])
def services_for_interface(hostname, interface): def services_for_interface(hostname, interface):
r = db.get_redis() r = common.get_redis()
result = r.hget("interface_services", result = r.hget("interface_services",
"{}::{}".format(hostname, interface)) "{}::{}".format(hostname, interface))
if not result: if not result:
......
...@@ -2,7 +2,7 @@ import functools ...@@ -2,7 +2,7 @@ import functools
import json import json
from flask import Blueprint, request, Response from flask import Blueprint, request, Response
from inventory_provider import db from inventory_provider.routes import common
routes = Blueprint("inventory-opsdb-query-routes", __name__) routes = Blueprint("inventory-opsdb-query-routes", __name__)
...@@ -38,7 +38,7 @@ def _decode_utf8_dict(d): ...@@ -38,7 +38,7 @@ def _decode_utf8_dict(d):
@routes.route("/interfaces") @routes.route("/interfaces")
def get_all_interface_details(): def get_all_interface_details():
r = db.get_redis() r = common.get_redis()
result = _decode_utf8_dict( result = _decode_utf8_dict(
r.hgetall(interfaces_key)) r.hgetall(interfaces_key))
...@@ -49,7 +49,7 @@ def get_all_interface_details(): ...@@ -49,7 +49,7 @@ def get_all_interface_details():
@routes.route("/interfaces/<equipment_name>") @routes.route("/interfaces/<equipment_name>")
def get_interface_details_for_equipment(equipment_name): def get_interface_details_for_equipment(equipment_name):
r = db.get_redis() r = common.get_redis()
result = {} result = {}
for t in r.hscan_iter(interfaces_key, "{}::*".format(equipment_name)): for t in r.hscan_iter(interfaces_key, "{}::*".format(equipment_name)):
result[t[0].decode("utf8")] = json.loads(t[1]) result[t[0].decode("utf8")] = json.loads(t[1])
...@@ -61,7 +61,7 @@ def get_interface_details_for_equipment(equipment_name): ...@@ -61,7 +61,7 @@ def get_interface_details_for_equipment(equipment_name):
@routes.route("/interfaces/<equipment_name>/<path:interface>") @routes.route("/interfaces/<equipment_name>/<path:interface>")
def get_interface_details(equipment_name, interface): def get_interface_details(equipment_name, interface):
r = db.get_redis() r = common.get_redis()
return Response( return Response(
r.hget( r.hget(
interfaces_key, interfaces_key,
...@@ -71,7 +71,7 @@ def get_interface_details(equipment_name, interface): ...@@ -71,7 +71,7 @@ def get_interface_details(equipment_name, interface):
@routes.route("/equipment-location") @routes.route("/equipment-location")
def get_all_equipment_locations(): def get_all_equipment_locations():
r = db.get_redis() r = common.get_redis()
result = list( result = list(
_decode_utf8_dict( _decode_utf8_dict(
r.hgetall(equipment_locations_key)).values()) r.hgetall(equipment_locations_key)).values())
...@@ -83,7 +83,7 @@ def get_all_equipment_locations(): ...@@ -83,7 +83,7 @@ def get_all_equipment_locations():
@routes.route("/equipment-location/<path:equipment_name>") @routes.route("/equipment-location/<path:equipment_name>")
def get_equipment_location(equipment_name): def get_equipment_location(equipment_name):
r = db.get_redis() r = common.get_redis()
return Response( return Response(
r.hget(equipment_locations_key, equipment_name), r.hget(equipment_locations_key, equipment_name),
mimetype="application/json") mimetype="application/json")
...@@ -91,7 +91,7 @@ def get_equipment_location(equipment_name): ...@@ -91,7 +91,7 @@ def get_equipment_location(equipment_name):
@routes.route("/circuit-hierarchy/children/<parent_id>") @routes.route("/circuit-hierarchy/children/<parent_id>")
def get_children(parent_id): def get_children(parent_id):
r = db.get_redis() r = common.get_redis()
return Response( return Response(
r.hget( r.hget(
service_parent_to_children_key, service_parent_to_children_key,
...@@ -101,7 +101,7 @@ def get_children(parent_id): ...@@ -101,7 +101,7 @@ def get_children(parent_id):
@routes.route("/circuit-hierarchy/parents/<child_id>") @routes.route("/circuit-hierarchy/parents/<child_id>")
def get_parents(child_id): def get_parents(child_id):
r = db.get_redis() r = common.get_redis()
return Response( return Response(
r.hget( r.hget(
service_child_to_parents_key, service_child_to_parents_key,
......
...@@ -2,7 +2,6 @@ import json ...@@ -2,7 +2,6 @@ import json
from flask import Blueprint, Response, jsonify from flask import Blueprint, Response, jsonify
from lxml import etree from lxml import etree
from inventory_provider import db
from inventory_provider import juniper from inventory_provider import juniper
from inventory_provider.routes import common from inventory_provider.routes import common
...@@ -12,7 +11,7 @@ routes = Blueprint('poller-support-routes', __name__) ...@@ -12,7 +11,7 @@ routes = Blueprint('poller-support-routes', __name__)
@routes.route('/interfaces/<hostname>', methods=['GET', 'POST']) @routes.route('/interfaces/<hostname>', methods=['GET', 'POST'])
@common.require_accepts_json @common.require_accepts_json
def poller_interface_oids(hostname): def poller_interface_oids(hostname):
r = db.get_redis() r = common.get_redis()
netconf_string = r.hget(hostname, 'netconf') netconf_string = r.hget(hostname, 'netconf')
if not netconf_string: if not netconf_string:
......
...@@ -52,7 +52,7 @@ class MockedRedis(object): ...@@ -52,7 +52,7 @@ class MockedRedis(object):
@pytest.fixture @pytest.fixture
def client_with_mocked_data(mocker, client): def client_with_mocked_data(mocker, client):
mocker.patch( mocker.patch(
'inventory_provider.routes.data.redis.StrictRedis', 'inventory_provider.routes.common.redis.StrictRedis',
MockedRedis) MockedRedis)
return client return client
......
...@@ -58,8 +58,8 @@ class MockedRedis(object): ...@@ -58,8 +58,8 @@ class MockedRedis(object):
@pytest.fixture @pytest.fixture
def client_with_mocked_data(mocker, client): def client_with_mocked_data(mocker, client):
mocker.patch( mocker.patch(
'inventory_provider.routes.classifier.db.get_redis', 'inventory_provider.routes.common.redis.StrictRedis',
return_value=MockedRedis()) MockedRedis)
return client return client
......
...@@ -49,15 +49,15 @@ def test_juniper_addresses(mocker, client): ...@@ -49,15 +49,15 @@ def test_juniper_addresses(mocker, client):
] ]
class MockedRedis(): class MockedRedis():
def __init__(self): def __init__(self, *args, **kwargs):
pass pass
def get(self, ignored): def get(self, ignored):
return json.dumps(test_data).encode('utf-8') return json.dumps(test_data).encode('utf-8')
mocker.patch( mocker.patch(
'inventory_provider.routes.classifier.db.get_redis', 'inventory_provider.routes.common.redis.StrictRedis',
return_value=MockedRedis()) MockedRedis)
response_schema = { response_schema = {
"$schema": "http://json-schema.org/draft-07/schema#", "$schema": "http://json-schema.org/draft-07/schema#",
...@@ -93,7 +93,7 @@ def test_trap_metadata(mocker, client): ...@@ -93,7 +93,7 @@ def test_trap_metadata(mocker, client):
} }
] ]
mocked_redis = mocker.patch( mocked_redis = mocker.patch(
"inventory_provider.routes.classifier.db.get_redis") "inventory_provider.routes.common.get_redis")
mocked_redis.return_value.hget.return_value = json.dumps(test_data)\ mocked_redis.return_value.hget.return_value = json.dumps(test_data)\
.encode("utf-8") .encode("utf-8")
......
...@@ -9,7 +9,7 @@ DEFAULT_REQUEST_HEADERS = { ...@@ -9,7 +9,7 @@ DEFAULT_REQUEST_HEADERS = {
def test_get_one_equipment_location(mocker, client): def test_get_one_equipment_location(mocker, client):
mocked_redis = mocker.patch( mocked_redis = mocker.patch(
"inventory_provider.routes.opsdb.db.get_redis") "inventory_provider.routes.common.get_redis")
mocked_hget = mocked_redis.return_value.hget mocked_hget = mocked_redis.return_value.hget
dummy_data = { dummy_data = {
"absid": 1404, "absid": 1404,
...@@ -38,7 +38,7 @@ def test_get_one_equipment_location(mocker, client): ...@@ -38,7 +38,7 @@ def test_get_one_equipment_location(mocker, client):
def test_get_equipment_location(mocker, client): def test_get_equipment_location(mocker, client):
mocked_redis = mocker.patch( mocked_redis = mocker.patch(
"inventory_provider.routes.opsdb.db.get_redis") "inventory_provider.routes.common.get_redis")
mocked_hgetall = mocked_redis.return_value.hgetall mocked_hgetall = mocked_redis.return_value.hgetall
rv = client.get( rv = client.get(
...@@ -54,7 +54,7 @@ def test_get_equipment_location(mocker, client): ...@@ -54,7 +54,7 @@ def test_get_equipment_location(mocker, client):
def test_get_interface_info(mocker, client): def test_get_interface_info(mocker, client):
mocked_redis = mocker.patch( mocked_redis = mocker.patch(
"inventory_provider.routes.opsdb.db.get_redis") "inventory_provider.routes.common.get_redis")
mocked_hgetall = mocked_redis.return_value.hgetall mocked_hgetall = mocked_redis.return_value.hgetall
rv = client.get( rv = client.get(
...@@ -70,7 +70,7 @@ def test_get_interface_info(mocker, client): ...@@ -70,7 +70,7 @@ def test_get_interface_info(mocker, client):
def test_get_interface_info_for_equipment(mocker, client): def test_get_interface_info_for_equipment(mocker, client):
mocked_redis = mocker.patch( mocked_redis = mocker.patch(
"inventory_provider.routes.opsdb.db.get_redis") "inventory_provider.routes.common.get_redis")
mocked_hscan_iter = mocked_redis.return_value.hscan_iter mocked_hscan_iter = mocked_redis.return_value.hscan_iter
rv = client.get( rv = client.get(
...@@ -86,7 +86,7 @@ def test_get_interface_info_for_equipment(mocker, client): ...@@ -86,7 +86,7 @@ def test_get_interface_info_for_equipment(mocker, client):
def test_get_interface_info_for_equipment_and_interface(mocker, client): def test_get_interface_info_for_equipment_and_interface(mocker, client):
mocked_redis = mocker.patch( mocked_redis = mocker.patch(
"inventory_provider.routes.opsdb.db.get_redis") "inventory_provider.routes.common.get_redis")
mocked_hget = mocked_redis.return_value.hget mocked_hget = mocked_redis.return_value.hget
rv = client.get( rv = client.get(
...@@ -102,7 +102,7 @@ def test_get_interface_info_for_equipment_and_interface(mocker, client): ...@@ -102,7 +102,7 @@ def test_get_interface_info_for_equipment_and_interface(mocker, client):
def test_get_children(mocker, client): def test_get_children(mocker, client):
mocked_redis = mocker.patch( mocked_redis = mocker.patch(
"inventory_provider.routes.opsdb.db.get_redis") "inventory_provider.routes.common.get_redis")
mocked_hget = mocked_redis.return_value.hget mocked_hget = mocked_redis.return_value.hget
rv = client.get( rv = client.get(
...@@ -119,7 +119,7 @@ def test_get_children(mocker, client): ...@@ -119,7 +119,7 @@ def test_get_children(mocker, client):
def test_get_parents(mocker, client): def test_get_parents(mocker, client):
mocked_redis = mocker.patch( mocked_redis = mocker.patch(
"inventory_provider.routes.opsdb.db.get_redis") "inventory_provider.routes.common.get_redis")
mocked_hget = mocked_redis.return_value.hget mocked_hget = mocked_redis.return_value.hget
rv = client.get( rv = client.get(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment