From 7517ac4de5b1cfd24a4ee4cd199e045e6b41dcc8 Mon Sep 17 00:00:00 2001
From: Karel van Klink <karel.vanklink@geant.org>
Date: Thu, 21 Mar 2024 16:33:18 +0100
Subject: [PATCH] Allow iBGP update on both provisioning and active trunks

Updated unit tests accordingly. iBGP update is allowed only on routers with a provisioning or active trunk, not initial or terminated
---
 gso/services/subscriptions.py                 | 17 ++++++-----
 gso/workflows/router/update_ibgp_mesh.py      |  9 ++++--
 .../workflows/router/test_update_ibgp_mesh.py | 29 ++++++++++++++-----
 3 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/gso/services/subscriptions.py b/gso/services/subscriptions.py
index d1e789aa..8d9900a3 100644
--- a/gso/services/subscriptions.py
+++ b/gso/services/subscriptions.py
@@ -129,14 +129,17 @@ def get_active_iptrunk_subscriptions(includes: list[str] | None = None) -> list[
     )
 
 
-def get_active_trunks_that_terminate_on_router(subscription_id: UUIDstr) -> list[SubscriptionTable]:
-    """Get all IP trunk subscriptions that are active, and terminate on the given ``subscription_id`` of a Router.
+def get_trunks_that_terminate_on_router(
+    subscription_id: UUIDstr, lifecycle_state: SubscriptionLifecycle
+) -> list[SubscriptionTable]:
+    """Get all IP trunk subscriptions that terminate on the given ``subscription_id`` of a Router.
 
-    Given a ``subscription_id`` of a Router subscription, this method gives a list of all active IP trunk subscriptions
-    that terminate on this Router.
+    Given a ``subscription_id`` of a Router subscription, this method gives a list of all IP trunk subscriptions that
+    terminate on this Router. The given lifecycle state dictates the state of trunk subscriptions that are counted as
+    terminating on this router.
 
-    :param subscription_id: Subscription ID of a Router
-    :type subscription_id: UUIDstr
+    :param UUIDstr subscription_id: Subscription ID of a Router
+    :param SubscriptionLifecycle lifecycle_state: Required lifecycle state of the IP trunk
 
     :return: A list of IP trunk subscriptions
     :rtype: list[SubscriptionTable]
@@ -146,7 +149,7 @@ def get_active_trunks_that_terminate_on_router(subscription_id: UUIDstr) -> list
         .join(ProductTable)
         .filter(
             ProductTable.product_type == ProductType.IP_TRUNK,
-            SubscriptionTable.status == SubscriptionLifecycle.ACTIVE,
+            SubscriptionTable.status == lifecycle_state,
         )
         .all()
     )
diff --git a/gso/workflows/router/update_ibgp_mesh.py b/gso/workflows/router/update_ibgp_mesh.py
index 80f63d21..a3b20576 100644
--- a/gso/workflows/router/update_ibgp_mesh.py
+++ b/gso/workflows/router/update_ibgp_mesh.py
@@ -6,7 +6,7 @@ from orchestrator.config.assignee import Assignee
 from orchestrator.forms import FormPage
 from orchestrator.forms.validators import Label
 from orchestrator.targets import Target
-from orchestrator.types import FormGenerator, State, UUIDstr
+from orchestrator.types import FormGenerator, State, SubscriptionLifecycle, UUIDstr
 from orchestrator.workflow import StepList, done, init, inputstep, step, workflow
 from orchestrator.workflows.steps import resync, store_process_subscription, unsync
 from orchestrator.workflows.utils import wrap_modify_initial_input_form
@@ -16,7 +16,7 @@ from gso.products.product_blocks.router import RouterRole
 from gso.products.product_types.router import Router
 from gso.services import librenms_client, lso_client, subscriptions
 from gso.services.lso_client import lso_interaction
-from gso.services.subscriptions import get_active_trunks_that_terminate_on_router
+from gso.services.subscriptions import get_trunks_that_terminate_on_router
 from gso.utils.helpers import SNMPVersion
 
 
@@ -36,7 +36,10 @@ def initial_input_form_generator(subscription_id: UUIDstr) -> FormGenerator:
 
         @root_validator(allow_reuse=True)
         def router_has_a_trunk(cls, values: dict[str, Any]) -> dict[str, Any]:
-            if len(get_active_trunks_that_terminate_on_router(subscription_id)) == 0:
+            terminating_trunks = get_trunks_that_terminate_on_router(
+                subscription_id, SubscriptionLifecycle.PROVISIONING
+            ) + get_trunks_that_terminate_on_router(subscription_id, SubscriptionLifecycle.ACTIVE)
+            if len(terminating_trunks) == 0:
                 msg = "Selected router does not terminate any active IP trunks."
                 raise ValueError(msg)
 
diff --git a/test/workflows/router/test_update_ibgp_mesh.py b/test/workflows/router/test_update_ibgp_mesh.py
index b2f6756b..41dd4371 100644
--- a/test/workflows/router/test_update_ibgp_mesh.py
+++ b/test/workflows/router/test_update_ibgp_mesh.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch
 
 import pytest
+from orchestrator.types import SubscriptionLifecycle
 from orchestrator.workflow import StepStatus
 from pydantic_forms.exceptions import FormValidationError
 
@@ -16,23 +17,22 @@ from test.workflows import (
 )
 
 
-@pytest.fixture()
-def ibgp_mesh_input_form_data(iptrunk_subscription_factory, faker):
-    ip_trunk = Iptrunk.from_subscription(iptrunk_subscription_factory())
-
-    return {"subscription_id": ip_trunk.iptrunk.iptrunk_sides[0].iptrunk_side_node.owner_subscription_id}
-
-
+@pytest.mark.parametrize("trunk_status", [SubscriptionLifecycle.PROVISIONING, SubscriptionLifecycle.ACTIVE])
 @pytest.mark.workflow()
 @patch("gso.workflows.router.update_ibgp_mesh.lso_client.execute_playbook")
 @patch("gso.workflows.router.update_ibgp_mesh.librenms_client.LibreNMSClient.add_device")
 def test_update_ibgp_mesh_success(
     mock_librenms_add_device,
     mock_execute_playbook,
-    ibgp_mesh_input_form_data,
+    trunk_status,
+    iptrunk_subscription_factory,
     data_config_filename,
     faker,
 ):
+    ip_trunk = Iptrunk.from_subscription(iptrunk_subscription_factory(status=trunk_status))
+    ibgp_mesh_input_form_data = {
+        "subscription_id": ip_trunk.iptrunk.iptrunk_sides[0].iptrunk_side_node.owner_subscription_id
+    }
     result, process_stat, step_log = run_workflow(
         "update_ibgp_mesh", [ibgp_mesh_input_form_data, {"tt_number": faker.tt_number()}]
     )
@@ -53,6 +53,19 @@ def test_update_ibgp_mesh_success(
     assert state["subscription"]["router"]["router_access_via_ts"] is False
 
 
+@pytest.mark.parametrize("trunk_status", [SubscriptionLifecycle.INITIAL, SubscriptionLifecycle.TERMINATED])
+@pytest.mark.workflow()
+def test_update_ibgp_mesh_failure(iptrunk_subscription_factory, data_config_filename, trunk_status):
+    ip_trunk = Iptrunk.from_subscription(iptrunk_subscription_factory(status=trunk_status))
+    ibgp_mesh_input_form_data = {
+        "subscription_id": ip_trunk.iptrunk.iptrunk_sides[0].iptrunk_side_node.owner_subscription_id
+    }
+
+    exception_message = "Selected router does not terminate any active IP trunks."
+    with pytest.raises(FormValidationError, match=exception_message):
+        run_workflow("update_ibgp_mesh", [ibgp_mesh_input_form_data, {}])
+
+
 @pytest.mark.workflow()
 def test_update_ibgp_mesh_isolated_router(nokia_router_subscription_factory, data_config_filename):
     router_id = nokia_router_subscription_factory(router_role=RouterRole.P)
-- 
GitLab