diff --git a/inventory_provider/auth.py b/inventory_provider/auth.py index 1dd8ff81ef56049673223f08c3e1f327376ee8a8..ad7fe03396493fbd78996a2609fa58ceeeea1255 100644 --- a/inventory_provider/auth.py +++ b/inventory_provider/auth.py @@ -2,17 +2,19 @@ from functools import wraps from flask import current_app, g, jsonify from flask_httpauth import HTTPTokenAuth -from inventory_provider.config import ANONYMOUS_SERVICE_NAME +from inventory_provider.config import ServiceName +from typing import Callable, List, Optional auth = HTTPTokenAuth(scheme="ApiKey") + @auth.verify_token -def verify_api_key(api_key): +def verify_api_key(api_key: Optional[str]) -> Optional[ServiceName]: config = current_app.config["INVENTORY_PROVIDER_CONFIG"] # This is to enable anonymous access for testing. if not api_key: - g.auth_client = ANONYMOUS_SERVICE_NAME - return ANONYMOUS_SERVICE_NAME + g.auth_client = ServiceName.ANONYMOUS + return ServiceName.ANONYMOUS for client, details in config['api-keys'].items(): if details.get('api-key') == api_key: @@ -20,19 +22,20 @@ def verify_api_key(api_key): return client return None -def authorize(*, allowed_clients): + +def authorize(*, allowed_clients: List[ServiceName]) -> Callable: """Decorator to restrict route access to specific clients.""" if not isinstance(allowed_clients, list): raise TypeError("allowed_clients must be a list of allowed service names") - def decorator(f): + def decorator(f: Callable) -> Callable: @wraps(f) def wrapped(*args, **kwargs): client = g.get("auth_client") if client not in allowed_clients: # Anonymous clients are allowed to access any resource without providing an API key # TODO: Only for testing, should be removed in Production - if client != ANONYMOUS_SERVICE_NAME: + if client != ServiceName.ANONYMOUS: return jsonify({"error": "Forbidden"}), 403 return f(*args, **kwargs) return wrapped diff --git a/inventory_provider/config.py b/inventory_provider/config.py index c8497a8bc03c1e2cc225a772fe450582bed98467..6e2dfb7fa365c1d0436aec8815de045dce6a8d6a 100644 --- a/inventory_provider/config.py +++ b/inventory_provider/config.py @@ -1,10 +1,14 @@ import json import jsonschema +from enum import StrEnum + + +class ServiceName(StrEnum): + DASHBOARD = 'dashboard' + BRIAN = 'brian' + REPORTING = 'reporting' + ANONYMOUS = 'anonymous' -DASHBOARD_SERVICE_NAME = 'dashboard' -BRIAN_SERVICE_NAME = 'brian' -REPORTING_SERVICE_NAME = 'reporting' -ANONYMOUS_SERVICE_NAME = 'anonymous' CONFIG_SCHEMA = { '$schema': 'https://json-schema.org/draft-07/schema#', @@ -26,9 +30,9 @@ CONFIG_SCHEMA = { "api-keys-credentials": { "type": "object", 'properties': { - DASHBOARD_SERVICE_NAME: {'$ref': '#/definitions/api-key'}, - BRIAN_SERVICE_NAME: {'$ref': '#/definitions/api-key'}, - REPORTING_SERVICE_NAME: {'$ref': '#/definitions/api-key'} + ServiceName.DASHBOARD: {'$ref': '#/definitions/api-key'}, + ServiceName.BRIAN: {'$ref': '#/definitions/api-key'}, + ServiceName.REPORTING: {'$ref': '#/definitions/api-key'} }, "additionalProperties": False }, diff --git a/inventory_provider/routes/classifier.py b/inventory_provider/routes/classifier.py index 018df7f53f12ef9372bf56e92ae0af0b26ca5288..41f38716bd66c90765408fe4a164b0e6bf1d8b5b 100644 --- a/inventory_provider/routes/classifier.py +++ b/inventory_provider/routes/classifier.py @@ -68,7 +68,7 @@ from inventory_provider.routes import common from inventory_provider.routes.common import _ignore_cache_or_retrieve, cache_result from inventory_provider.auth import authorize -from inventory_provider.config import DASHBOARD_SERVICE_NAME +from inventory_provider.config import ServiceName routes = Blueprint("inventory-data-classifier-support-routes", __name__) @@ -334,7 +334,7 @@ def get_link_info_response_body( @routes.route("/juniper-link-info/<source_equipment>/<path:interface>", methods=['GET']) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def handle_link_info_request(source_equipment: str, interface: str) -> Response: """ Handler for /classifier/juniper-link-info that @@ -376,7 +376,7 @@ def handle_link_info_request(source_equipment: str, interface: str) -> Response: @routes.route("/epipe-sap-info/<source_equipment>/<service_id>/<vpn_id>", methods=['GET']) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def handle_epipe_sap_info_request(source_equipment: str, service_id: str, vpn_id: str) -> Response: r = common.get_current_redis() @@ -605,7 +605,7 @@ def find_interfaces(address): @routes.route("/peer-info/<address_str>", methods=['GET']) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def peer_info(address_str: str) -> Response: """ Handler for /classifier/peer-info that returns bgp peering metadata. @@ -701,7 +701,7 @@ def peer_info(address_str: str) -> Response: @routes.route( "/mtc-interface-info/<node>/<interface>", methods=['GET']) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def get_mtc_interface_info(node, interface): """ Handler for /classifier/mtc-interface-info that @@ -743,7 +743,7 @@ def get_mtc_interface_info(node, interface): "<source_equipment>/<interface>/<circuit_id>", methods=['GET']) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def get_trap_metadata(source_equipment: str, interface: str, circuit_id: str) \ -> Response: """ @@ -837,7 +837,7 @@ def get_trap_metadata(source_equipment: str, interface: str, circuit_id: str) \ @routes.route("/infinera-fiberlink-info/<ne_name_str>/<object_name_str>", methods=['GET']) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def get_fiberlink_trap_metadata(ne_name_str: str, object_name_str: str) \ -> Response: """ @@ -974,7 +974,7 @@ def get_fiberlink_trap_metadata(ne_name_str: str, object_name_str: str) \ @routes.route("/tnms-fibre-info/<path:enms_pc_name>", methods=['GET']) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def get_tnms_fibre_trap_metadata(enms_pc_name: str) -> Response: """ Handler for /classifier/infinera-fiberlink-info that @@ -1077,7 +1077,7 @@ def get_tnms_fibre_trap_metadata(enms_pc_name: str) -> Response: @routes.route('/coriant-port-info/<equipment_name>/<path:entity_string>', methods=['GET']) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def get_coriant_port_info(equipment_name: str, entity_string: str) -> Response: """ Handler for /classifier/coriant-info that @@ -1099,7 +1099,7 @@ def get_coriant_port_info(equipment_name: str, entity_string: str) -> Response: @routes.route('/coriant-tp-info/<equipment_name>/<path:entity_string>', methods=['GET']) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def get_coriant_tp_info(equipment_name: str, entity_string: str) -> Response: """ Handler for /classifier/coriant-info that @@ -1210,7 +1210,7 @@ def _get_coriant_info( @routes.route("/router-info", methods=["GET"]) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def get_all_routers() -> Response: redis = common.get_current_redis() result = cache_result( @@ -1233,7 +1233,7 @@ def _get_router_list(redis): @routes.route("/router-info/<equipment_name>", methods=["GET"]) @common.require_accepts_json -@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME]) +@authorize(allowed_clients=[ServiceName.DASHBOARD]) def get_router_info(equipment_name: str) -> Response: redis = common.get_current_redis() ims_equipment_name = get_ims_equipment_name_or_none(equipment_name, redis) diff --git a/test/test_auth.py b/test/test_auth.py index 226ab7a5009b2df1dd187ef14bafca01c276c7d2..f12e2d6d3f63d5c846ce50ad42a3400cab5b79ff 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -4,36 +4,12 @@ from inventory_provider.routes.classifier_schema import ( ROUTER_INFO_ALL_ROUTERS_RESPONSE_SCHEMA, ) -DEFAULT_REQUEST_HEADERS_NO_KEY = { - "Content-type": "application/json", - "Accept": ["application/json"], -} - -DEFAULT_REQUEST_HEADERS_BRIAN_KEY = { - "Content-type": "application/json", - "Accept": ["application/json"], - "Authorization": "ApiKey brian_key", -} - -DEFAULT_REQUEST_HEADERS_REPORTING_KEY = { - "Content-type": "application/json", - "Accept": ["application/json"], - "Authorization": "ApiKey reporting_key", -} - -DEFAULT_REQUEST_HEADERS_DASHBOARD_KEY = { - "Content-type": "application/json", - "Accept": ["application/json"], - "Authorization": "ApiKey dashboard_key", -} - -DEFAULT_REQUEST_HEADERS_BAD_KEY = { - "Content-type": "application/json", - "Accept": ["application/json"], - "Authorization": "ApiKey badapikey", -} def test_classifier_router_no_key(client): + DEFAULT_REQUEST_HEADERS_NO_KEY = { + "Content-type": "application/json", + "Accept": ["application/json"], + } rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_NO_KEY) assert rv.status_code == 200 assert rv.is_json @@ -42,7 +18,13 @@ def test_classifier_router_no_key(client): jsonschema.validate(result, ROUTER_INFO_ALL_ROUTERS_RESPONSE_SCHEMA) assert len(result) > 0 + def test_classifier_router_dashboard_key(client): + DEFAULT_REQUEST_HEADERS_DASHBOARD_KEY = { + "Content-type": "application/json", + "Accept": ["application/json"], + "Authorization": "ApiKey dashboard_key", + } rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_DASHBOARD_KEY) assert rv.status_code == 200 assert rv.is_json @@ -51,7 +33,13 @@ def test_classifier_router_dashboard_key(client): jsonschema.validate(result, ROUTER_INFO_ALL_ROUTERS_RESPONSE_SCHEMA) assert len(result) > 0 + def test_classifier_router_brian_key(client): + DEFAULT_REQUEST_HEADERS_BRIAN_KEY = { + "Content-type": "application/json", + "Accept": ["application/json"], + "Authorization": "ApiKey brian_key", + } rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_BRIAN_KEY) assert rv.status_code == 403 assert rv.is_json @@ -59,7 +47,13 @@ def test_classifier_router_brian_key(client): assert result["error"] == "Forbidden" + def test_classifier_router_reporting_key(client): + DEFAULT_REQUEST_HEADERS_REPORTING_KEY = { + "Content-type": "application/json", + "Accept": ["application/json"], + "Authorization": "ApiKey reporting_key", + } rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_REPORTING_KEY) assert rv.status_code == 403 assert rv.is_json @@ -67,7 +61,13 @@ def test_classifier_router_reporting_key(client): assert result["error"] == "Forbidden" + def test_classifier_router_bad_key(client): + DEFAULT_REQUEST_HEADERS_BAD_KEY = { + "Content-type": "application/json", + "Accept": ["application/json"], + "Authorization": "ApiKey badapikey", + } rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_BAD_KEY) assert rv.status_code == 401 result = rv.text