diff --git a/gso/api/v1/__init__.py b/gso/api/v1/__init__.py index 89fd2c8eb6d1455a1f80cf3be6ddd552918ad775..6553f1f83b6a31d91aee49224d72242c937820c8 100644 --- a/gso/api/v1/__init__.py +++ b/gso/api/v1/__init__.py @@ -1,7 +1,9 @@ from fastapi import APIRouter from gso.api.v1.imports import router as imports_router +from gso.api.v1.subscriptions import router as subscriptions_router router = APIRouter() router.include_router(imports_router) +router.include_router(subscriptions_router) diff --git a/gso/api/v1/imports.py b/gso/api/v1/imports.py index 51f58c5d62705969f054a2d680aab9e806ff2171..5a7ef6c6ef791747fda0e82ee3300e4c45df4edf 100644 --- a/gso/api/v1/imports.py +++ b/gso/api/v1/imports.py @@ -77,7 +77,8 @@ class IptrunkImportModel(BaseModel): @classmethod def _get_active_routers(cls) -> set[str]: return { - str(router_id) for router_id in subscriptions.get_active_router_subscriptions(fields=["subscription_id"]) + str(router["subscription_id"]) + for router in subscriptions.get_active_router_subscriptions(includes=["subscription_id"]) } @validator("customer") diff --git a/gso/api/v1/subscriptions.py b/gso/api/v1/subscriptions.py new file mode 100644 index 0000000000000000000000000000000000000000..65eae878f18cec1c9aa5e45693fc60f029b06d68 --- /dev/null +++ b/gso/api/v1/subscriptions.py @@ -0,0 +1,24 @@ +from typing import Any + +from fastapi import Depends, status +from fastapi.routing import APIRouter +from orchestrator.domain import SubscriptionModel +from orchestrator.schemas import SubscriptionDomainModelSchema +from orchestrator.security import opa_security_default +from orchestrator.services.subscriptions import build_extended_domain_model + +from gso.services.subscriptions import get_active_router_subscriptions + +router = APIRouter(prefix="/subscriptions", tags=["Subscriptions"], dependencies=[Depends(opa_security_default)]) + + +@router.get("/routers", status_code=status.HTTP_200_OK, response_model=list[SubscriptionDomainModelSchema]) +def subscription_routers() -> list[dict[str, Any]]: + """Retrieve all active routers subscriptions.""" + subscriptions = [] + for r in get_active_router_subscriptions(): + subscription = SubscriptionModel.from_subscription(r["subscription_id"]) + extended_model = build_extended_domain_model(subscription) + subscriptions.append(extended_model) + + return subscriptions diff --git a/gso/services/subscriptions.py b/gso/services/subscriptions.py index f1c6b075a6fe211217ab3ad407a1ef371df47d9c..c6c3ffaf89ad6c503cb67b7a65c282f7ae6df1ec 100644 --- a/gso/services/subscriptions.py +++ b/gso/services/subscriptions.py @@ -1,6 +1,6 @@ +from typing import Any from uuid import UUID -from asyncio_redis import Subscription from orchestrator.db import ( ProductTable, ResourceTypeTable, @@ -8,57 +8,71 @@ from orchestrator.db import ( SubscriptionInstanceValueTable, SubscriptionTable, ) +from orchestrator.graphql.schemas.subscription import Subscription from orchestrator.types import SubscriptionLifecycle from gso.products import ProductType +SubscriptionType = dict[str, Any] -def get_active_subscriptions(product_type: str, fields: list[str]) -> list[Subscription]: + +def get_active_subscriptions( + product_type: str, + includes: list[str] | None = None, + excludes: list[str] | None = None, +) -> list[SubscriptionType]: """Retrieve active subscriptions for a specific product type. :param product_type: The type of the product for which to retrieve subscriptions. :type product_type: str - :param fields: List of fields to be included in the returned Subscription objects. - :type fields: list[str] + :param includes: List of fields to be included in the returned Subscription objects. + :type includes: list[str] + :param excludes: List of fields to be excluded from the returned Subscription objects. + :type excludes: list[str] :return: A list of Subscription objects that match the query. :rtype: list[Subscription] """ - dynamic_fields = [getattr(SubscriptionTable, field) for field in fields] + if not includes: + includes = [col.name for col in SubscriptionTable.__table__.columns] - return ( - SubscriptionTable.query.join(ProductTable) - .filter( - ProductTable.product_type == product_type, - SubscriptionTable.status == SubscriptionLifecycle.ACTIVE, - ) - .with_entities(*dynamic_fields) - .all() + if excludes: + includes = [field for field in includes if field not in excludes] + + dynamic_fields = [getattr(SubscriptionTable, field) for field in includes] + + query = SubscriptionTable.query.join(ProductTable).filter( + ProductTable.product_type == product_type, + SubscriptionTable.status == SubscriptionLifecycle.ACTIVE, ) + results = query.with_entities(*dynamic_fields).all() + + return [dict(zip(includes, result)) for result in results] + -def get_active_site_subscriptions(fields: list[str]) -> list[Subscription]: +def get_active_site_subscriptions(includes: list[str] | None = None) -> list[SubscriptionType]: """Retrieve active subscriptions specifically for sites. - :param fields: The fields to be included in the returned Subscription objects. - :type fields: list[str] + :param includes: The fields to be included in the returned Subscription objects. + :type includes: list[str] :return: A list of Subscription objects for sites. :rtype: list[Subscription] """ - return get_active_subscriptions(ProductType.SITE, fields) + return get_active_subscriptions(product_type=ProductType.SITE, includes=includes) -def get_active_router_subscriptions(fields: list[str]) -> list[Subscription]: +def get_active_router_subscriptions(includes: list[str] | None = None) -> list[SubscriptionType]: """Retrieve active subscriptions specifically for routers. - :param fields: The fields to be included in the returned Subscription objects. - :type fields: list[str] + :param includes: The fields to be included in the returned Subscription objects. + :type includes: list[str] :return: A list of Subscription objects for routers. :rtype: list[Subscription] """ - return get_active_subscriptions(product_type=ProductType.ROUTER, fields=fields) + return get_active_subscriptions(product_type=ProductType.ROUTER, includes=includes) def get_product_id_by_name(product_name: ProductType) -> UUID: diff --git a/gso/workflows/iptrunk/create_iptrunk.py b/gso/workflows/iptrunk/create_iptrunk.py index f022181b5060632807b5c67aa9548aba9876505e..4f428fcaecc3a5c827136f38e9ff8923c9e1452c 100644 --- a/gso/workflows/iptrunk/create_iptrunk.py +++ b/gso/workflows/iptrunk/create_iptrunk.py @@ -33,10 +33,9 @@ def initial_input_form_generator(product_name: str) -> FormGenerator: # * interface names must be validated routers = {} - for router_id, router_description in subscriptions.get_active_router_subscriptions( - fields=["subscription_id", "description"] - ): - routers[str(router_id)] = router_description + + for router in subscriptions.get_active_router_subscriptions(includes=["subscription_id", "description"]): + routers[str(router["subscription_id"])] = router["description"] class CreateIptrunkForm(FormPage): class Config: diff --git a/gso/workflows/router/create_router.py b/gso/workflows/router/create_router.py index 79311e470df340386502cb94f7e13d7b67ddf901..0902f20d8a6290f8e981a20056d6d367ef0a9f13 100644 --- a/gso/workflows/router/create_router.py +++ b/gso/workflows/router/create_router.py @@ -23,10 +23,8 @@ from gso.utils.helpers import iso_from_ipv4 def _site_selector() -> Choice: site_subscriptions = {} - for site_id, site_description in subscriptions.get_active_site_subscriptions( - fields=["subscription_id", "description"] - ): - site_subscriptions[str(site_id)] = site_description + for site in subscriptions.get_active_site_subscriptions(includes=["subscription_id", "description"]): + site_subscriptions[str(site["subscription_id"])] = site["description"] # noinspection PyTypeChecker return Choice("Select a site", zip(site_subscriptions.keys(), site_subscriptions.items())) # type: ignore[arg-type] diff --git a/gso/workflows/tasks/import_iptrunk.py b/gso/workflows/tasks/import_iptrunk.py index 84c12aa6667962eb1a40184384fad34d0a85ba04..9c36d62c06c0b197a1df3a0fe290aaae8c771c75 100644 --- a/gso/workflows/tasks/import_iptrunk.py +++ b/gso/workflows/tasks/import_iptrunk.py @@ -20,10 +20,9 @@ from gso.workflows.iptrunk.create_iptrunk import initialize_subscription def _generate_routers() -> dict[str, str]: """Generate a dictionary of router IDs and descriptions.""" routers = {} - for router_id, router_description in subscriptions.get_active_router_subscriptions( - fields=["subscription_id", "description"] - ): - routers[str(router_id)] = router_description + for subscription in subscriptions.get_active_router_subscriptions(includes=["subscription_id", "description"]): + routers[str(subscription["subscription_id"])] = subscription["description"] + return routers diff --git a/test/fixtures.py b/test/fixtures.py index ec2b2bd56179f2fa6e54dfe69bfa78d21a408c08..601463de6f392f13b83b2efa7b74b9443255a636 100644 --- a/test/fixtures.py +++ b/test/fixtures.py @@ -86,6 +86,7 @@ def router_subscription_factory(site_subscription_factory, faker): router_role=RouterRole.PE, router_site=None, router_is_ias_connected=True, + status: SubscriptionLifecycle | None = None, ) -> UUIDstr: description = description or faker.text(max_nb_chars=30) router_fqdn = router_fqdn or faker.domain_name(levels=4) @@ -118,6 +119,10 @@ def router_subscription_factory(site_subscription_factory, faker): router_subscription = SubscriptionModel.from_other_lifecycle(router_subscription, SubscriptionLifecycle.ACTIVE) router_subscription.description = description router_subscription.start_date = start_date + + if status: + router_subscription.status = status + router_subscription.save() db.session.commit() diff --git a/test/imports/test_imports.py b/test/imports/test_imports.py index 3cdfa3ed1d5f87abd72521198220a58687b7960d..34b39be69cb72eb57b3ee892ffadacb9561f6e9d 100644 --- a/test/imports/test_imports.py +++ b/test/imports/test_imports.py @@ -48,13 +48,23 @@ def mock_routers(iptrunk_data): with patch("gso.services.subscriptions.get_active_router_subscriptions") as mock_get_active_router_subscriptions: def _active_router_subscriptions(*args, **kwargs): - if kwargs["fields"] == ["subscription_id", "description"]: + if kwargs["includes"] == ["subscription_id", "description"]: return [ - (iptrunk_data["side_a_node_id"], "side_a_node_id description"), - (iptrunk_data["side_b_node_id"], "side_b_node_id description"), - (str(uuid4()), "random description"), + { + "subscription_id": iptrunk_data["side_a_node_id"], + "description": "iptrunk_sideA_node_id description", + }, + { + "subscription_id": iptrunk_data["side_b_node_id"], + "description": "iptrunk_sideB_node_id description", + }, + {"subscription_id": str(uuid4()), "description": "random description"}, ] - return [iptrunk_data["side_a_node_id"], iptrunk_data["side_b_node_id"], str(uuid4())] + return [ + {"subscription_id": iptrunk_data["side_a_node_id"]}, + {"subscription_id": iptrunk_data["side_b_node_id"]}, + {"subscription_id": str(uuid4())}, + ] mock_get_active_router_subscriptions.side_effect = _active_router_subscriptions yield mock_get_active_router_subscriptions @@ -200,6 +210,9 @@ def test_import_iptrunk_invalid_customer(mock_start_process, test_client, mock_r @patch("gso.api.v1.imports._start_process") def test_import_iptrunk_invalid_router_id_side_a_and_b(mock_start_process, test_client, iptrunk_data): + iptrunk_data["side_a_node_id"] = "NOT FOUND" + iptrunk_data["side_b_node_id"] = "NOT FOUND" + mock_start_process.return_value = "123e4567-e89b-12d3-a456-426655440000" response = test_client.post(IPTRUNK_IMPORT_API_URL, json=iptrunk_data) diff --git a/test/subscriptions/__init__.py b/test/subscriptions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/subscriptions/conftest.py b/test/subscriptions/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..425a0e627a4592241e2c3f81cce910255dd34a5e --- /dev/null +++ b/test/subscriptions/conftest.py @@ -0,0 +1 @@ +from test.fixtures import router_subscription_factory, site_subscription_factory # noqa diff --git a/test/subscriptions/test_subscriptions.py b/test/subscriptions/test_subscriptions.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9980996c0486a65759ff820883f403399e684a --- /dev/null +++ b/test/subscriptions/test_subscriptions.py @@ -0,0 +1,16 @@ +from orchestrator.types import SubscriptionLifecycle + +ROUTER_SUBSCRIPTION_ENDPOINT = "/api/v1/subscriptions/routers" + + +def test_router_subscriptions_endpoint(test_client, router_subscription_factory): + router_subscription_factory() + router_subscription_factory() + router_subscription_factory() + router_subscription_factory(status=SubscriptionLifecycle.TERMINATED) + router_subscription_factory(status=SubscriptionLifecycle.INITIAL) + + response = test_client.get(ROUTER_SUBSCRIPTION_ENDPOINT) + + assert response.status_code == 200 + assert len(response.json()) == 3