Skip to content
Snippets Groups Projects
Unverified Commit 10f5ba5c authored by Adeel Ahmad's avatar Adeel Ahmad
Browse files

Add type hinting and Enum definition for service names

parent 25c0d990
No related branches found
No related tags found
1 merge request!50Dboard3 1142/token auth
......@@ -2,17 +2,19 @@ from functools import wraps
from flask import current_app, g, jsonify
from flask_httpauth import HTTPTokenAuth
from inventory_provider.config import ANONYMOUS_SERVICE_NAME
from inventory_provider.config import ServiceName
from typing import Callable, List, Optional
auth = HTTPTokenAuth(scheme="ApiKey")
@auth.verify_token
def verify_api_key(api_key):
def verify_api_key(api_key: Optional[str]) -> Optional[ServiceName]:
config = current_app.config["INVENTORY_PROVIDER_CONFIG"]
# This is to enable anonymous access for testing.
if not api_key:
g.auth_client = ANONYMOUS_SERVICE_NAME
return ANONYMOUS_SERVICE_NAME
g.auth_client = ServiceName.ANONYMOUS
return ServiceName.ANONYMOUS
for client, details in config['api-keys'].items():
if details.get('api-key') == api_key:
......@@ -20,19 +22,20 @@ def verify_api_key(api_key):
return client
return None
def authorize(*, allowed_clients):
def authorize(*, allowed_clients: List[ServiceName]) -> Callable:
"""Decorator to restrict route access to specific clients."""
if not isinstance(allowed_clients, list):
raise TypeError("allowed_clients must be a list of allowed service names")
def decorator(f):
def decorator(f: Callable) -> Callable:
@wraps(f)
def wrapped(*args, **kwargs):
client = g.get("auth_client")
if client not in allowed_clients:
# Anonymous clients are allowed to access any resource without providing an API key
# TODO: Only for testing, should be removed in Production
if client != ANONYMOUS_SERVICE_NAME:
if client != ServiceName.ANONYMOUS:
return jsonify({"error": "Forbidden"}), 403
return f(*args, **kwargs)
return wrapped
......
import json
import jsonschema
from enum import StrEnum
class ServiceName(StrEnum):
DASHBOARD = 'dashboard'
BRIAN = 'brian'
REPORTING = 'reporting'
ANONYMOUS = 'anonymous'
DASHBOARD_SERVICE_NAME = 'dashboard'
BRIAN_SERVICE_NAME = 'brian'
REPORTING_SERVICE_NAME = 'reporting'
ANONYMOUS_SERVICE_NAME = 'anonymous'
CONFIG_SCHEMA = {
'$schema': 'https://json-schema.org/draft-07/schema#',
......@@ -26,9 +30,9 @@ CONFIG_SCHEMA = {
"api-keys-credentials": {
"type": "object",
'properties': {
DASHBOARD_SERVICE_NAME: {'$ref': '#/definitions/api-key'},
BRIAN_SERVICE_NAME: {'$ref': '#/definitions/api-key'},
REPORTING_SERVICE_NAME: {'$ref': '#/definitions/api-key'}
ServiceName.DASHBOARD: {'$ref': '#/definitions/api-key'},
ServiceName.BRIAN: {'$ref': '#/definitions/api-key'},
ServiceName.REPORTING: {'$ref': '#/definitions/api-key'}
},
"additionalProperties": False
},
......
......@@ -68,7 +68,7 @@ from inventory_provider.routes import common
from inventory_provider.routes.common import _ignore_cache_or_retrieve, cache_result
from inventory_provider.auth import authorize
from inventory_provider.config import DASHBOARD_SERVICE_NAME
from inventory_provider.config import ServiceName
routes = Blueprint("inventory-data-classifier-support-routes", __name__)
......@@ -334,7 +334,7 @@ def get_link_info_response_body(
@routes.route("/juniper-link-info/<source_equipment>/<path:interface>",
methods=['GET'])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def handle_link_info_request(source_equipment: str, interface: str) -> Response:
"""
Handler for /classifier/juniper-link-info that
......@@ -376,7 +376,7 @@ def handle_link_info_request(source_equipment: str, interface: str) -> Response:
@routes.route("/epipe-sap-info/<source_equipment>/<service_id>/<vpn_id>", methods=['GET'])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def handle_epipe_sap_info_request(source_equipment: str, service_id: str, vpn_id: str) -> Response:
r = common.get_current_redis()
......@@ -605,7 +605,7 @@ def find_interfaces(address):
@routes.route("/peer-info/<address_str>", methods=['GET'])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def peer_info(address_str: str) -> Response:
"""
Handler for /classifier/peer-info that returns bgp peering metadata.
......@@ -701,7 +701,7 @@ def peer_info(address_str: str) -> Response:
@routes.route(
"/mtc-interface-info/<node>/<interface>", methods=['GET'])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def get_mtc_interface_info(node, interface):
"""
Handler for /classifier/mtc-interface-info that
......@@ -743,7 +743,7 @@ def get_mtc_interface_info(node, interface):
"<source_equipment>/<interface>/<circuit_id>",
methods=['GET'])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def get_trap_metadata(source_equipment: str, interface: str, circuit_id: str) \
-> Response:
"""
......@@ -837,7 +837,7 @@ def get_trap_metadata(source_equipment: str, interface: str, circuit_id: str) \
@routes.route("/infinera-fiberlink-info/<ne_name_str>/<object_name_str>",
methods=['GET'])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def get_fiberlink_trap_metadata(ne_name_str: str, object_name_str: str) \
-> Response:
"""
......@@ -974,7 +974,7 @@ def get_fiberlink_trap_metadata(ne_name_str: str, object_name_str: str) \
@routes.route("/tnms-fibre-info/<path:enms_pc_name>", methods=['GET'])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def get_tnms_fibre_trap_metadata(enms_pc_name: str) -> Response:
"""
Handler for /classifier/infinera-fiberlink-info that
......@@ -1077,7 +1077,7 @@ def get_tnms_fibre_trap_metadata(enms_pc_name: str) -> Response:
@routes.route('/coriant-port-info/<equipment_name>/<path:entity_string>',
methods=['GET'])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def get_coriant_port_info(equipment_name: str, entity_string: str) -> Response:
"""
Handler for /classifier/coriant-info that
......@@ -1099,7 +1099,7 @@ def get_coriant_port_info(equipment_name: str, entity_string: str) -> Response:
@routes.route('/coriant-tp-info/<equipment_name>/<path:entity_string>',
methods=['GET'])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def get_coriant_tp_info(equipment_name: str, entity_string: str) -> Response:
"""
Handler for /classifier/coriant-info that
......@@ -1210,7 +1210,7 @@ def _get_coriant_info(
@routes.route("/router-info", methods=["GET"])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def get_all_routers() -> Response:
redis = common.get_current_redis()
result = cache_result(
......@@ -1233,7 +1233,7 @@ def _get_router_list(redis):
@routes.route("/router-info/<equipment_name>", methods=["GET"])
@common.require_accepts_json
@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
@authorize(allowed_clients=[ServiceName.DASHBOARD])
def get_router_info(equipment_name: str) -> Response:
redis = common.get_current_redis()
ims_equipment_name = get_ims_equipment_name_or_none(equipment_name, redis)
......
......@@ -4,36 +4,12 @@ from inventory_provider.routes.classifier_schema import (
ROUTER_INFO_ALL_ROUTERS_RESPONSE_SCHEMA,
)
DEFAULT_REQUEST_HEADERS_NO_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
}
DEFAULT_REQUEST_HEADERS_BRIAN_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
"Authorization": "ApiKey brian_key",
}
DEFAULT_REQUEST_HEADERS_REPORTING_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
"Authorization": "ApiKey reporting_key",
}
DEFAULT_REQUEST_HEADERS_DASHBOARD_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
"Authorization": "ApiKey dashboard_key",
}
DEFAULT_REQUEST_HEADERS_BAD_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
"Authorization": "ApiKey badapikey",
}
def test_classifier_router_no_key(client):
DEFAULT_REQUEST_HEADERS_NO_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
}
rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_NO_KEY)
assert rv.status_code == 200
assert rv.is_json
......@@ -42,7 +18,13 @@ def test_classifier_router_no_key(client):
jsonschema.validate(result, ROUTER_INFO_ALL_ROUTERS_RESPONSE_SCHEMA)
assert len(result) > 0
def test_classifier_router_dashboard_key(client):
DEFAULT_REQUEST_HEADERS_DASHBOARD_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
"Authorization": "ApiKey dashboard_key",
}
rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_DASHBOARD_KEY)
assert rv.status_code == 200
assert rv.is_json
......@@ -51,7 +33,13 @@ def test_classifier_router_dashboard_key(client):
jsonschema.validate(result, ROUTER_INFO_ALL_ROUTERS_RESPONSE_SCHEMA)
assert len(result) > 0
def test_classifier_router_brian_key(client):
DEFAULT_REQUEST_HEADERS_BRIAN_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
"Authorization": "ApiKey brian_key",
}
rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_BRIAN_KEY)
assert rv.status_code == 403
assert rv.is_json
......@@ -59,7 +47,13 @@ def test_classifier_router_brian_key(client):
assert result["error"] == "Forbidden"
def test_classifier_router_reporting_key(client):
DEFAULT_REQUEST_HEADERS_REPORTING_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
"Authorization": "ApiKey reporting_key",
}
rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_REPORTING_KEY)
assert rv.status_code == 403
assert rv.is_json
......@@ -67,7 +61,13 @@ def test_classifier_router_reporting_key(client):
assert result["error"] == "Forbidden"
def test_classifier_router_bad_key(client):
DEFAULT_REQUEST_HEADERS_BAD_KEY = {
"Content-type": "application/json",
"Accept": ["application/json"],
"Authorization": "ApiKey badapikey",
}
rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_BAD_KEY)
assert rv.status_code == 401
result = rv.text
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment