From 25c0d990a5f80fc209fd0eddcfd96100c821870f Mon Sep 17 00:00:00 2001
From: Adeel Ahmad <adeel.ahmad@geant.org>
Date: Fri, 21 Mar 2025 10:18:13 +0000
Subject: [PATCH] Add unit tests for API authentication flow

---
 inventory_provider/auth.py              | 13 ++---
 inventory_provider/routes/classifier.py |  3 +-
 test/conftest.py                        | 11 ++++
 test/test_auth.py                       | 75 +++++++++++++++++++++++++
 test/test_flask_config.py               | 11 ++++
 5 files changed, 103 insertions(+), 10 deletions(-)
 create mode 100644 test/test_auth.py

diff --git a/inventory_provider/auth.py b/inventory_provider/auth.py
index ba5ef9b6..1dd8ff81 100644
--- a/inventory_provider/auth.py
+++ b/inventory_provider/auth.py
@@ -1,7 +1,8 @@
+from functools import wraps
 from flask import current_app, g, jsonify
 from flask_httpauth import HTTPTokenAuth
-from functools import wraps
-from config import ANONYMOUS_SERVICE_NAME
+
+from inventory_provider.config import ANONYMOUS_SERVICE_NAME
 
 auth = HTTPTokenAuth(scheme="ApiKey")
 
@@ -28,17 +29,11 @@ def authorize(*, allowed_clients):
         @wraps(f)
         def wrapped(*args, **kwargs):
             client = g.get("auth_client")
-
-            if not client:
-                return jsonify({"error": "Unauthorized"}), 403
-
             if client not in allowed_clients:
                 # Anonymous clients are allowed to access any resource without providing an API key
                 # TODO: Only for testing, should be removed in Production
                 if client != ANONYMOUS_SERVICE_NAME:
                     return jsonify({"error": "Forbidden"}), 403
-
             return f(*args, **kwargs)
-
         return wrapped
-    return decorator
\ No newline at end of file
+    return decorator
diff --git a/inventory_provider/routes/classifier.py b/inventory_provider/routes/classifier.py
index 8b65079b..018df7f5 100644
--- a/inventory_provider/routes/classifier.py
+++ b/inventory_provider/routes/classifier.py
@@ -67,7 +67,8 @@ from redis import Redis
 from inventory_provider.routes import common
 from inventory_provider.routes.common import _ignore_cache_or_retrieve, cache_result
 
-from inventory_provider.config import authorize, DASHBOARD_SERVICE_NAME
+from inventory_provider.auth import authorize
+from inventory_provider.config import DASHBOARD_SERVICE_NAME
 
 routes = Blueprint("inventory-data-classifier-support-routes", __name__)
 
diff --git a/test/conftest.py b/test/conftest.py
index 1c276b4c..35412682 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -70,6 +70,17 @@ def data_config_filename():
 
     with tempfile.NamedTemporaryFile() as f:
         config = {
+            "api-keys": {
+                "brian": {
+                    "api-key": "brian_key"
+                },
+                "dashboard": {
+                    "api-key": "dashboard_key"
+                },
+                "reporting": {
+                    "api-key": "reporting_key"
+                },
+            },
             "ssh": {
                 "username": "uSeR-NaMe",
                 "private-key": "private-key-filename",
diff --git a/test/test_auth.py b/test/test_auth.py
new file mode 100644
index 00000000..226ab7a5
--- /dev/null
+++ b/test/test_auth.py
@@ -0,0 +1,75 @@
+import jsonschema
+
+from inventory_provider.routes.classifier_schema import (
+    ROUTER_INFO_ALL_ROUTERS_RESPONSE_SCHEMA,
+)
+
+DEFAULT_REQUEST_HEADERS_NO_KEY = {
+    "Content-type": "application/json",
+    "Accept": ["application/json"],
+}
+
+DEFAULT_REQUEST_HEADERS_BRIAN_KEY = {
+    "Content-type": "application/json",
+    "Accept": ["application/json"],
+    "Authorization": "ApiKey brian_key",
+}
+
+DEFAULT_REQUEST_HEADERS_REPORTING_KEY = {
+    "Content-type": "application/json",
+    "Accept": ["application/json"],
+    "Authorization": "ApiKey reporting_key",
+}
+
+DEFAULT_REQUEST_HEADERS_DASHBOARD_KEY = {
+    "Content-type": "application/json",
+    "Accept": ["application/json"],
+    "Authorization": "ApiKey dashboard_key",
+}
+
+DEFAULT_REQUEST_HEADERS_BAD_KEY = {
+    "Content-type": "application/json",
+    "Accept": ["application/json"],
+    "Authorization": "ApiKey badapikey",
+}
+
+def test_classifier_router_no_key(client):
+    rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_NO_KEY)
+    assert rv.status_code == 200
+    assert rv.is_json
+    result = rv.json
+
+    jsonschema.validate(result, ROUTER_INFO_ALL_ROUTERS_RESPONSE_SCHEMA)
+    assert len(result) > 0
+
+def test_classifier_router_dashboard_key(client):
+    rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_DASHBOARD_KEY)
+    assert rv.status_code == 200
+    assert rv.is_json
+    result = rv.json
+
+    jsonschema.validate(result, ROUTER_INFO_ALL_ROUTERS_RESPONSE_SCHEMA)
+    assert len(result) > 0
+
+def test_classifier_router_brian_key(client):
+    rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_BRIAN_KEY)
+    assert rv.status_code == 403
+    assert rv.is_json
+    result = rv.json
+
+    assert result["error"] == "Forbidden"
+
+def test_classifier_router_reporting_key(client):
+    rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_REPORTING_KEY)
+    assert rv.status_code == 403
+    assert rv.is_json
+    result = rv.json
+
+    assert result["error"] == "Forbidden"
+
+def test_classifier_router_bad_key(client):
+    rv = client.get("/classifier/router-info", headers=DEFAULT_REQUEST_HEADERS_BAD_KEY)
+    assert rv.status_code == 401
+    result = rv.text
+
+    assert result == "Unauthorized Access"
diff --git a/test/test_flask_config.py b/test/test_flask_config.py
index 7ba83ec5..48f7e29c 100644
--- a/test/test_flask_config.py
+++ b/test/test_flask_config.py
@@ -7,6 +7,17 @@ from inventory_provider.config import CONFIG_SCHEMA
 @pytest.fixture
 def config():
     return {
+        "api-keys": {
+            "brian": {
+                "api-key": "brian_key"
+            },
+            "dashboard": {
+                "api-key": "dashboard_key"
+            },
+            "reporting": {
+                "api-key": "reporting_key"
+            },
+        },
         'redis': {
             'hostname': 'localhost',
             'port': 6379,
-- 
GitLab