From 0f66457b8f4fc970e131992863064f5be194a76f Mon Sep 17 00:00:00 2001
From: Adeel Ahmad <adeel.ahmad@geant.org>
Date: Thu, 20 Mar 2025 10:11:24 +0000
Subject: [PATCH] Update config schema for API keys and restrict decorator
 input to be a list

---
 inventory_provider/__init__.py          |  1 +
 inventory_provider/auth.py              | 12 +++++-----
 inventory_provider/config.py            | 30 ++++++++++++++-----------
 inventory_provider/routes/classifier.py | 13 +++++++++++
 inventory_provider/routes/msr.py        |  2 --
 inventory_provider/routes/testing.py    |  2 --
 6 files changed, 38 insertions(+), 22 deletions(-)

diff --git a/inventory_provider/__init__.py b/inventory_provider/__init__.py
index 915b79a2..3155d6cc 100644
--- a/inventory_provider/__init__.py
+++ b/inventory_provider/__init__.py
@@ -53,6 +53,7 @@ def create_app(setup_logging=True):
     @app.before_request
     @auth.login_required
     def secure_before_request():
+        # This method is a boilerplate required by the library to enable authentication
         pass
 
     # IMS based routes
diff --git a/inventory_provider/auth.py b/inventory_provider/auth.py
index 7ee9651e..ba5ef9b6 100644
--- a/inventory_provider/auth.py
+++ b/inventory_provider/auth.py
@@ -1,6 +1,7 @@
 from flask import current_app, g, jsonify
 from flask_httpauth import HTTPTokenAuth
 from functools import wraps
+from config import ANONYMOUS_SERVICE_NAME
 
 auth = HTTPTokenAuth(scheme="ApiKey")
 
@@ -9,8 +10,8 @@ def verify_api_key(api_key):
     config = current_app.config["INVENTORY_PROVIDER_CONFIG"]
     # This is to enable anonymous access for testing.
     if not api_key:
-        g.auth_client = "anonymous"
-        return "anonymous"
+        g.auth_client = ANONYMOUS_SERVICE_NAME
+        return ANONYMOUS_SERVICE_NAME
 
     for client, details in config['api-keys'].items():
         if details.get('api-key') == api_key:
@@ -20,8 +21,9 @@ def verify_api_key(api_key):
 
 def authorize(*, allowed_clients):
     """Decorator to restrict route access to specific clients."""
-    if not isinstance(allowed_clients, (list, tuple)):
-        allowed_clients = [allowed_clients] # Convert single client to list
+    if not isinstance(allowed_clients, list):
+        raise TypeError("allowed_clients must be a list of allowed service names")
+
     def decorator(f):
         @wraps(f)
         def wrapped(*args, **kwargs):
@@ -33,7 +35,7 @@ def authorize(*, allowed_clients):
             if client not in allowed_clients:
                 # Anonymous clients are allowed to access any resource without providing an API key
                 # TODO: Only for testing, should be removed in Production
-                if client != "anonymous":
+                if client != ANONYMOUS_SERVICE_NAME:
                     return jsonify({"error": "Forbidden"}), 403
 
             return f(*args, **kwargs)
diff --git a/inventory_provider/config.py b/inventory_provider/config.py
index 0df04129..c8497a8b 100644
--- a/inventory_provider/config.py
+++ b/inventory_provider/config.py
@@ -1,6 +1,11 @@
 import json
 import jsonschema
 
+DASHBOARD_SERVICE_NAME = 'dashboard'
+BRIAN_SERVICE_NAME = 'brian'
+REPORTING_SERVICE_NAME = 'reporting'
+ANONYMOUS_SERVICE_NAME = 'anonymous'
+
 CONFIG_SCHEMA = {
     '$schema': 'https://json-schema.org/draft-07/schema#',
 
@@ -10,21 +15,20 @@ CONFIG_SCHEMA = {
             'maximum': 60,  # sanity
             'exclusiveMinimum': 0
         },
+        'api-key': {
+            "type": "object",
+            "properties": {
+                "api-key": {"type": "string"}
+            },
+            "required": ["api-key"],
+            "additionalProperties": False
+        },
         "api-keys-credentials": {
             "type": "object",
-            "patternProperties": {
-                "^[a-zA-Z0-9-_]+$": {
-                    "type": "object",
-                    "properties": {
-                        "api-key": {
-                            "type": "string",
-                            # "minLength": 32,
-                            # "description": "API key (Base64, UUID, or Hexadecimal format)"
-                        }
-                    },
-                    "required": ["api-key"],
-                    "additionalProperties": False
-                }
+            'properties': {
+                DASHBOARD_SERVICE_NAME: {'$ref': '#/definitions/api-key'},
+                BRIAN_SERVICE_NAME: {'$ref': '#/definitions/api-key'},
+                REPORTING_SERVICE_NAME: {'$ref': '#/definitions/api-key'}
             },
             "additionalProperties": False
         },
diff --git a/inventory_provider/routes/classifier.py b/inventory_provider/routes/classifier.py
index 47c4e888..8b65079b 100644
--- a/inventory_provider/routes/classifier.py
+++ b/inventory_provider/routes/classifier.py
@@ -67,6 +67,8 @@ from redis import Redis
 from inventory_provider.routes import common
 from inventory_provider.routes.common import _ignore_cache_or_retrieve, cache_result
 
+from inventory_provider.config import authorize, DASHBOARD_SERVICE_NAME
+
 routes = Blueprint("inventory-data-classifier-support-routes", __name__)
 
 logger = logging.getLogger(__name__)
@@ -331,6 +333,7 @@ def get_link_info_response_body(
 @routes.route("/juniper-link-info/<source_equipment>/<path:interface>",
               methods=['GET'])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def handle_link_info_request(source_equipment: str, interface: str) -> Response:
     """
     Handler for /classifier/juniper-link-info that
@@ -372,6 +375,7 @@ def handle_link_info_request(source_equipment: str, interface: str) -> Response:
 
 @routes.route("/epipe-sap-info/<source_equipment>/<service_id>/<vpn_id>", methods=['GET'])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def handle_epipe_sap_info_request(source_equipment: str, service_id: str, vpn_id: str) -> Response:
 
     r = common.get_current_redis()
@@ -600,6 +604,7 @@ def find_interfaces(address):
 
 @routes.route("/peer-info/<address_str>", methods=['GET'])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def peer_info(address_str: str) -> Response:
     """
     Handler for /classifier/peer-info that returns bgp peering metadata.
@@ -695,6 +700,7 @@ def peer_info(address_str: str) -> Response:
 @routes.route(
     "/mtc-interface-info/<node>/<interface>", methods=['GET'])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def get_mtc_interface_info(node, interface):
     """
     Handler for /classifier/mtc-interface-info that
@@ -736,6 +742,7 @@ def get_mtc_interface_info(node, interface):
               "<source_equipment>/<interface>/<circuit_id>",
               methods=['GET'])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def get_trap_metadata(source_equipment: str, interface: str, circuit_id: str) \
         -> Response:
     """
@@ -829,6 +836,7 @@ def get_trap_metadata(source_equipment: str, interface: str, circuit_id: str) \
 @routes.route("/infinera-fiberlink-info/<ne_name_str>/<object_name_str>",
               methods=['GET'])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def get_fiberlink_trap_metadata(ne_name_str: str, object_name_str: str) \
         -> Response:
     """
@@ -965,6 +973,7 @@ def get_fiberlink_trap_metadata(ne_name_str: str, object_name_str: str) \
 
 @routes.route("/tnms-fibre-info/<path:enms_pc_name>", methods=['GET'])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def get_tnms_fibre_trap_metadata(enms_pc_name: str) -> Response:
     """
     Handler for /classifier/infinera-fiberlink-info that
@@ -1067,6 +1076,7 @@ def get_tnms_fibre_trap_metadata(enms_pc_name: str) -> Response:
 @routes.route('/coriant-port-info/<equipment_name>/<path:entity_string>',
               methods=['GET'])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def get_coriant_port_info(equipment_name: str, entity_string: str) -> Response:
     """
     Handler for /classifier/coriant-info that
@@ -1088,6 +1098,7 @@ def get_coriant_port_info(equipment_name: str, entity_string: str) -> Response:
 @routes.route('/coriant-tp-info/<equipment_name>/<path:entity_string>',
               methods=['GET'])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def get_coriant_tp_info(equipment_name: str, entity_string: str) -> Response:
     """
     Handler for /classifier/coriant-info that
@@ -1198,6 +1209,7 @@ def _get_coriant_info(
 
 @routes.route("/router-info", methods=["GET"])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def get_all_routers() -> Response:
     redis = common.get_current_redis()
     result = cache_result(
@@ -1220,6 +1232,7 @@ def _get_router_list(redis):
 
 @routes.route("/router-info/<equipment_name>", methods=["GET"])
 @common.require_accepts_json
+@authorize(allowed_clients=[DASHBOARD_SERVICE_NAME])
 def get_router_info(equipment_name: str) -> Response:
     redis = common.get_current_redis()
     ims_equipment_name = get_ims_equipment_name_or_none(equipment_name, redis)
diff --git a/inventory_provider/routes/msr.py b/inventory_provider/routes/msr.py
index fe40bca8..ac2aa9b9 100644
--- a/inventory_provider/routes/msr.py
+++ b/inventory_provider/routes/msr.py
@@ -118,7 +118,6 @@ from inventory_provider.routes.common import _ignore_cache_or_retrieve, \
     ims_equipment_to_hostname
 from inventory_provider.routes.poller import get_services
 from inventory_provider.tasks import common as tasks_common
-from inventory_provider.auth import authorize
 
 routes = Blueprint('msr-query-routes', __name__)
 logger = logging.getLogger(__name__)
@@ -1448,7 +1447,6 @@ def _asn_peers(asn, group, instance):
 @routes.route('/asn-peers', methods=['GET'], defaults={'asn': None})
 @routes.route('/asn-peers/<int:asn>', methods=['GET'])
 @common.require_accepts_json
-@authorize(allowed_clients="reporting")
 def asn_peers_get(asn):
     """
     cf. doc for _asn_peers
diff --git a/inventory_provider/routes/testing.py b/inventory_provider/routes/testing.py
index c43dccb4..5fcb7cce 100644
--- a/inventory_provider/routes/testing.py
+++ b/inventory_provider/routes/testing.py
@@ -12,7 +12,6 @@ from inventory_provider import juniper
 from inventory_provider.routes import common
 from inventory_provider.tasks import worker
 from inventory_provider.tasks import common as worker_common
-from inventory_provider.auth import authorize
 
 routes = Blueprint("inventory-data-testing-support-routes", __name__)
 
@@ -111,7 +110,6 @@ def routers_from_config_dir():
 
 
 @routes.route("latchdb", methods=['GET'])
-@authorize(allowed_clients=("brian", "dashboard"))
 def latch_db():
     config = current_app.config["INVENTORY_PROVIDER_CONFIG"]
     worker_common.latch_db(config)
-- 
GitLab