From 1fccf4e345fc4c617788700a55db86745f234291 Mon Sep 17 00:00:00 2001
From: Mohammad Torkashvand <mohammad.torkashvand@geant.org>
Date: Wed, 20 Sep 2023 19:10:40 +0200
Subject: [PATCH] add 2 test cases for create_iptrunk workflow

---
 gso/services/provisioning_proxy.py            |   3 +-
 requirements.txt                              |   3 +-
 test/conftest.py                              |   3 +
 test/fixtures.py                              | 124 ++++++++++
 test/imports/conftest.py                      | 125 +---------
 test/workflows/__init__.py                    | 217 ++++++++++++++++++
 test/workflows/conftest.py                    |  27 +++
 test/workflows/iptrunks/__init__.py           |   0
 .../iptrunks/test_create_iptrunks.py          | 152 ++++++++++++
 test/workflows/site/__init__.py               |   0
 test/workflows/site/test_create_site.py       |  40 ++++
 11 files changed, 567 insertions(+), 127 deletions(-)
 create mode 100644 test/fixtures.py
 create mode 100644 test/workflows/__init__.py
 create mode 100644 test/workflows/conftest.py
 create mode 100644 test/workflows/iptrunks/__init__.py
 create mode 100644 test/workflows/iptrunks/test_create_iptrunks.py
 create mode 100644 test/workflows/site/__init__.py
 create mode 100644 test/workflows/site/test_create_site.py

diff --git a/gso/services/provisioning_proxy.py b/gso/services/provisioning_proxy.py
index 7ce76ac9f..0ec3d7cc1 100644
--- a/gso/services/provisioning_proxy.py
+++ b/gso/services/provisioning_proxy.py
@@ -4,7 +4,6 @@
 """
 import json
 import logging
-from typing import NoReturn
 
 import requests
 from orchestrator import conditional, inputstep, step
@@ -252,7 +251,7 @@ def _await_pp_results(subscription: SubscriptionModel, label_text: str = DEFAULT
         confirm: Accept = Accept("INCOMPLETE")
 
         @validator("pp_run_results", allow_reuse=True, pre=True, always=True)
-        def run_results_must_be_given(cls, run_results: dict) -> dict | NoReturn:
+        def run_results_must_be_given(cls, run_results: dict) -> dict:
             if run_results is None:
                 raise ValueError("Run results may not be empty. Wait for the provisioning proxy to finish.")
             return run_results
diff --git a/requirements.txt b/requirements.txt
index 9cb2e8cf8..d1991211d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,5 @@ mypy
 ruff
 sphinx
 sphinx-rtd-theme
-typer
\ No newline at end of file
+typer
+urllib3_mock
\ No newline at end of file
diff --git a/test/conftest.py b/test/conftest.py
index ee0cac0b6..16e1623d5 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,6 +1,7 @@
 import contextlib
 import ipaddress
 import json
+import logging
 import os
 import socket
 import tempfile
@@ -22,6 +23,8 @@ from starlette.testclient import TestClient
 
 from gso.main import init_gso_app
 
+logging.getLogger("faker.factory").setLevel(logging.WARNING)
+
 
 def pytest_collection_modifyitems(config, items):
     if bool(os.environ.get("SKIP_ALL_TESTS")):
diff --git a/test/fixtures.py b/test/fixtures.py
new file mode 100644
index 000000000..2f3ba8993
--- /dev/null
+++ b/test/fixtures.py
@@ -0,0 +1,124 @@
+import ipaddress
+
+import pytest
+from orchestrator.db import db
+from orchestrator.domain import SubscriptionModel
+from orchestrator.types import SubscriptionLifecycle, UUIDstr
+
+from gso.products.product_blocks.router import RouterRole, RouterVendor
+from gso.products.product_blocks.site import SiteTier
+from gso.products.product_types.router import RouterInactive
+from gso.products.product_types.site import Site, SiteInactive
+from gso.schemas.enums import ProductType
+from gso.services import subscriptions
+
+CUSTOMER_ID: UUIDstr = "2f47f65a-0911-e511-80d0-005056956c1a"
+
+
+@pytest.fixture
+def site_subscription_factory(faker):
+    def subscription_create(
+        description=None,
+        start_date="2023-05-24T00:00:00+00:00",
+        site_name=None,
+        site_city=None,
+        site_country=None,
+        site_country_code=None,
+        site_latitude=None,
+        site_longitude=None,
+        site_bgp_community_id=None,
+        site_internal_id=None,
+        site_tier=SiteTier.TIER1,
+        site_ts_address=None,
+    ) -> UUIDstr:
+        description = description or "Site Subscription"
+        site_name = site_name or faker.name()
+        site_city = site_city or faker.city()
+        site_country = site_country or faker.country()
+        site_country_code = site_country_code or faker.country_code()
+        site_latitude = site_latitude or float(faker.latitude())
+        site_longitude = site_longitude or float(faker.longitude())
+        site_bgp_community_id = site_bgp_community_id or faker.pyint()
+        site_internal_id = site_internal_id or faker.pyint()
+        site_ts_address = site_ts_address or faker.ipv4()
+
+        product_id = subscriptions.get_product_id_by_name(ProductType.SITE)
+        site_subscription = SiteInactive.from_product_id(product_id, customer_id=CUSTOMER_ID, insync=True)
+        site_subscription.site.site_city = site_city
+        site_subscription.site.site_name = site_name
+        site_subscription.site.site_country = site_country
+        site_subscription.site.site_country_code = site_country_code
+        site_subscription.site.site_latitude = site_latitude
+        site_subscription.site.site_longitude = site_longitude
+        site_subscription.site.site_bgp_community_id = site_bgp_community_id
+        site_subscription.site.site_internal_id = site_internal_id
+        site_subscription.site.site_tier = site_tier
+        site_subscription.site.site_ts_address = site_ts_address
+
+        site_subscription = SubscriptionModel.from_other_lifecycle(site_subscription, SubscriptionLifecycle.ACTIVE)
+        site_subscription.description = description
+        site_subscription.start_date = start_date
+        site_subscription.save()
+        db.session.commit()
+
+        return str(site_subscription.subscription_id)
+
+    return subscription_create
+
+
+@pytest.fixture
+def router_subscription_factory(site_subscription_factory, faker):
+    def subscription_create(
+        description=None,
+        start_date="2023-05-24T00:00:00+00:00",
+        router_fqdn=None,
+        router_ts_port=None,
+        router_access_via_ts=None,
+        router_lo_ipv4_address=None,
+        router_lo_ipv6_address=None,
+        router_lo_iso_address=None,
+        router_si_ipv4_network=None,
+        router_ias_lt_ipv4_network=None,
+        router_ias_lt_ipv6_network=None,
+        router_vendor=RouterVendor.NOKIA,
+        router_role=RouterRole.PE,
+        router_site=None,
+        router_is_ias_connected=True,
+    ) -> UUIDstr:
+        description = description or faker.text(max_nb_chars=30)
+        router_fqdn = router_fqdn or faker.domain_name()
+        router_ts_port = router_ts_port or faker.random_int(min=1, max=49151)
+        router_access_via_ts = router_access_via_ts or faker.boolean()
+        router_lo_ipv4_address = router_lo_ipv4_address or ipaddress.IPv4Address(faker.ipv4())
+        router_lo_ipv6_address = router_lo_ipv6_address or ipaddress.IPv6Address(faker.ipv6())
+        router_lo_iso_address = router_lo_iso_address or faker.word()
+        router_si_ipv4_network = router_si_ipv4_network or faker.ipv4_network()
+        router_ias_lt_ipv4_network = router_ias_lt_ipv4_network or faker.ipv4_network()
+        router_ias_lt_ipv6_network = router_ias_lt_ipv6_network or faker.ipv6_network()
+        router_site = router_site or site_subscription_factory()
+
+        product_id = subscriptions.get_product_id_by_name(ProductType.ROUTER)
+        router_subscription = RouterInactive.from_product_id(product_id, customer_id=CUSTOMER_ID, insync=True)
+        router_subscription.router.router_fqdn = router_fqdn
+        router_subscription.router.router_ts_port = router_ts_port
+        router_subscription.router.router_access_via_ts = router_access_via_ts
+        router_subscription.router.router_lo_ipv4_address = router_lo_ipv4_address
+        router_subscription.router.router_lo_ipv6_address = router_lo_ipv6_address
+        router_subscription.router.router_lo_iso_address = router_lo_iso_address
+        router_subscription.router.router_si_ipv4_network = router_si_ipv4_network
+        router_subscription.router.router_ias_lt_ipv4_network = router_ias_lt_ipv4_network
+        router_subscription.router.router_ias_lt_ipv6_network = router_ias_lt_ipv6_network
+        router_subscription.router.router_vendor = router_vendor
+        router_subscription.router.router_role = router_role
+        router_subscription.router.router_site = Site.from_subscription(router_site).site
+        router_subscription.router.router_is_ias_connected = router_is_ias_connected
+
+        router_subscription = SubscriptionModel.from_other_lifecycle(router_subscription, SubscriptionLifecycle.ACTIVE)
+        router_subscription.description = description
+        router_subscription.start_date = start_date
+        router_subscription.save()
+        db.session.commit()
+
+        return str(router_subscription.subscription_id)
+
+    return subscription_create
diff --git a/test/imports/conftest.py b/test/imports/conftest.py
index 320b06697..c807df14a 100644
--- a/test/imports/conftest.py
+++ b/test/imports/conftest.py
@@ -1,124 +1 @@
-import ipaddress
-
-import pytest
-from orchestrator.db import db
-from orchestrator.domain import SubscriptionModel
-from orchestrator.types import SubscriptionLifecycle, UUIDstr
-
-from gso.products.product_blocks.router import RouterRole, RouterVendor
-from gso.products.product_blocks.site import SiteTier
-from gso.products.product_types.router import RouterInactive
-from gso.products.product_types.site import Site, SiteInactive
-from gso.schemas.enums import ProductType
-from gso.services import subscriptions
-
-CUSTOMER_ID: UUIDstr = "2f47f65a-0911-e511-80d0-005056956c1a"
-
-
-@pytest.fixture
-def site_subscription_factory(faker):
-    def subscription_create(
-        description=None,
-        start_date="2023-05-24T00:00:00+00:00",
-        site_name=None,
-        site_city=None,
-        site_country=None,
-        site_country_code=None,
-        site_latitude=None,
-        site_longitude=None,
-        site_bgp_community_id=None,
-        site_internal_id=None,
-        site_tier=SiteTier.TIER1,
-        site_ts_address=None,
-    ) -> UUIDstr:
-        description = description or "Site Subscription"
-        site_name = site_name or faker.name()
-        site_city = site_city or faker.city()
-        site_country = site_country or faker.country()
-        site_country_code = site_country_code or faker.country_code()
-        site_latitude = site_latitude or float(faker.latitude())
-        site_longitude = site_longitude or float(faker.longitude())
-        site_bgp_community_id = site_bgp_community_id or faker.pyint()
-        site_internal_id = site_internal_id or faker.pyint()
-        site_ts_address = site_ts_address or faker.ipv4()
-
-        product_id = subscriptions.get_product_id_by_name(ProductType.SITE)
-        site_subscription = SiteInactive.from_product_id(product_id, customer_id=CUSTOMER_ID, insync=True)
-        site_subscription.site.site_city = site_city
-        site_subscription.site.site_name = site_name
-        site_subscription.site.site_country = site_country
-        site_subscription.site.site_country_code = site_country_code
-        site_subscription.site.site_latitude = site_latitude
-        site_subscription.site.site_longitude = site_longitude
-        site_subscription.site.site_bgp_community_id = site_bgp_community_id
-        site_subscription.site.site_internal_id = site_internal_id
-        site_subscription.site.site_tier = site_tier
-        site_subscription.site.site_ts_address = site_ts_address
-
-        site_subscription = SubscriptionModel.from_other_lifecycle(site_subscription, SubscriptionLifecycle.ACTIVE)
-        site_subscription.description = description
-        site_subscription.start_date = start_date
-        site_subscription.save()
-        db.session.commit()
-
-        return str(site_subscription.subscription_id)
-
-    return subscription_create
-
-
-@pytest.fixture
-def router_subscription_factory(site_subscription_factory, faker):
-    def subscription_create(
-        description=None,
-        start_date="2023-05-24T00:00:00+00:00",
-        router_fqdn=None,
-        router_ts_port=None,
-        router_access_via_ts=None,
-        router_lo_ipv4_address=None,
-        router_lo_ipv6_address=None,
-        router_lo_iso_address=None,
-        router_si_ipv4_network=None,
-        router_ias_lt_ipv4_network=None,
-        router_ias_lt_ipv6_network=None,
-        router_vendor=RouterVendor.NOKIA,
-        router_role=RouterRole.PE,
-        router_site=None,
-        router_is_ias_connected=True,
-    ) -> UUIDstr:
-        description = description or faker.text(max_nb_chars=30)
-        router_fqdn = router_fqdn or faker.domain_name()
-        router_ts_port = router_ts_port or faker.random_int(min=1, max=49151)
-        router_access_via_ts = router_access_via_ts or faker.boolean()
-        router_lo_ipv4_address = router_lo_ipv4_address or ipaddress.IPv4Address(faker.ipv4())
-        router_lo_ipv6_address = router_lo_ipv6_address or ipaddress.IPv6Address(faker.ipv6())
-        router_lo_iso_address = router_lo_iso_address or faker.word()
-        router_si_ipv4_network = router_si_ipv4_network or faker.ipv4_network()
-        router_ias_lt_ipv4_network = router_ias_lt_ipv4_network or faker.ipv4_network()
-        router_ias_lt_ipv6_network = router_ias_lt_ipv6_network or faker.ipv6_network()
-        router_site = router_site or site_subscription_factory()
-
-        product_id = subscriptions.get_product_id_by_name(ProductType.ROUTER)
-        router_subscription = RouterInactive.from_product_id(product_id, customer_id=CUSTOMER_ID, insync=True)
-        router_subscription.router.router_fqdn = router_fqdn
-        router_subscription.router.router_ts_port = router_ts_port
-        router_subscription.router.router_access_via_ts = router_access_via_ts
-        router_subscription.router.router_lo_ipv4_address = router_lo_ipv4_address
-        router_subscription.router.router_lo_ipv6_address = router_lo_ipv6_address
-        router_subscription.router.router_lo_iso_address = router_lo_iso_address
-        router_subscription.router.router_si_ipv4_network = router_si_ipv4_network
-        router_subscription.router.router_ias_lt_ipv4_network = router_ias_lt_ipv4_network
-        router_subscription.router.router_ias_lt_ipv6_network = router_ias_lt_ipv6_network
-        router_subscription.router.router_vendor = router_vendor
-        router_subscription.router.router_role = router_role
-        router_subscription.router.router_site = Site.from_subscription(router_site).site
-        router_subscription.router.router_is_ias_connected = router_is_ias_connected
-
-        router_subscription = SubscriptionModel.from_other_lifecycle(router_subscription, SubscriptionLifecycle.ACTIVE)
-        router_subscription.description = description
-        router_subscription.start_date = start_date
-        router_subscription.save()
-        db.session.commit()
-
-        return str(router_subscription.subscription_id)
-
-    return subscription_create
+from test.fixtures import router_subscription_factory, site_subscription_factory  # noqa
diff --git a/test/workflows/__init__.py b/test/workflows/__init__.py
new file mode 100644
index 000000000..02c91e9c4
--- /dev/null
+++ b/test/workflows/__init__.py
@@ -0,0 +1,217 @@
+import difflib
+import pprint
+from copy import deepcopy
+from typing import Callable, Dict, List, Optional, Tuple, Union, cast
+from uuid import uuid4
+
+import structlog
+from orchestrator.db import ProcessTable
+from orchestrator.forms import post_process
+from orchestrator.services.processes import StateMerger, _db_create_process
+from orchestrator.types import State
+from orchestrator.utils.json import json_dumps, json_loads
+from orchestrator.workflow import Process as WFProcess
+from orchestrator.workflow import ProcessStat, Step, Success, Workflow, runwf
+from orchestrator.workflows import ALL_WORKFLOWS, LazyWorkflowInstance, get_workflow
+
+logger = structlog.get_logger(__name__)
+
+
+def _raise_exception(state):
+    if isinstance(state, Exception):
+        raise state
+    return state
+
+
+def assert_success(result):
+    assert (
+        result.on_failed(_raise_exception).on_waiting(_raise_exception).issuccess()
+    ), f"Unexpected process status. Expected Success, but was: {result}"
+
+
+def assert_waiting(result):
+    assert result.on_failed(
+        _raise_exception
+    ).iswaiting(), f"Unexpected process status. Expected Waiting, but was: {result}"
+
+
+def assert_suspended(result):
+    assert result.on_failed(
+        _raise_exception
+    ).issuspend(), f"Unexpected process status. Expected Suspend, but was: {result}"
+
+
+def assert_aborted(result):
+    assert result.on_failed(_raise_exception).isabort(), f"Unexpected process status. Expected Abort, but was: {result}"
+
+
+def assert_failed(result):
+    assert result.isfailed(), f"Unexpected process status. Expected Failed, but was: {result}"
+
+
+def assert_complete(result):
+    assert result.on_failed(
+        _raise_exception
+    ).iscomplete(), f"Unexpected process status. Expected Complete, but was: {result}"
+
+
+def assert_state(result, expected):
+    state = result.unwrap()
+    actual = {}
+    for key in expected.keys():
+        actual[key] = state[key]
+    assert expected == actual, f"Invalid state. Expected superset of: {expected}, but was: {actual}"
+
+
+def assert_state_equal(result: ProcessTable, expected: Dict, excluded_keys: Optional[List[str]] = None) -> None:
+    """Test state with certain keys excluded from both actual and expected state."""
+    if excluded_keys is None:
+        excluded_keys = ["process_id", "workflow_target", "workflow_name"]
+    state = deepcopy(extract_state(result))
+    expected_state = deepcopy(expected)
+    for key in excluded_keys:
+        if key in state:
+            del state[key]
+        if key in expected_state:
+            del expected_state[key]
+
+    assert state == expected_state, "Unexpected state:\n" + "\n".join(
+        difflib.ndiff(pprint.pformat(state).splitlines(), pprint.pformat(expected_state).splitlines())
+    )
+
+
+def assert_assignee(log, expected):
+    actual = log[-1][0].assignee
+    assert expected == actual, f"Unexpected assignee. Expected {expected}, but was: {actual}"
+
+
+def assert_step_name(log, expected):
+    actual = log[-1][0]
+    assert actual.name == expected, f"Unexpected name. Expected {expected}, but was: {actual}"
+
+
+def extract_state(result):
+    return result.unwrap()
+
+
+def extract_error(result):
+    from orchestrator.workflow import Process
+
+    assert isinstance(result, Process), f"Expected a Process, but got {repr(result)} of type {type(result)}"
+    assert not isinstance(result.s, Process), "Result contained a Process in a Process, this should not happen"
+
+    return extract_state(result).get("error")
+
+
+class WorkflowInstanceForTests(LazyWorkflowInstance):
+    """Register Test workflows.
+
+    Similar to `LazyWorkflowInstance` but does not require an import during instantiation.
+    Used for creating test workflows
+    """
+
+    package: str
+    function: str
+    is_callable: bool
+
+    def __init__(self, workflow: Workflow, name: str) -> None:
+        self.workflow = workflow
+        self.name = name
+
+    def __enter__(self):
+        ALL_WORKFLOWS[self.name] = self
+
+    def __exit__(self, _exc_type, _exc_value, _traceback):
+        del ALL_WORKFLOWS[self.name]
+
+    def instantiate(self) -> Workflow:
+        """Import and instantiate a workflow and return it.
+
+        This can be as simple as merely importing a workflow function. However, if it concerns a workflow generating
+        function, that function will be called with or without arguments as specified.
+
+        Returns: A workflow function.
+        """
+        self.workflow.name = self.name
+        return self.workflow
+
+    def __str__(self) -> str:
+        return self.name
+
+    def __repr__(self) -> str:
+        return f"WorkflowInstanceForTests('{self.workflow}','{self.name}')"
+
+
+def _store_step(step_log: List[Tuple[Step, WFProcess]]) -> Callable[[ProcessStat, Step, WFProcess], WFProcess]:
+    def __store_step(pstat: ProcessStat, step: Step, state: WFProcess) -> WFProcess:
+        try:
+            state = state.map(lambda s: json_loads(json_dumps(s)))
+        except Exception:
+            logger.exception("Step state is not valid json", state=state)
+        step_log.append((step, state))
+        return state
+
+    return __store_step
+
+
+def _sanitize_input(input_data: Union[State, List[State]]) -> List[State]:
+    # To be backwards compatible convert single dict to list
+    if not isinstance(input_data, List):
+        input_data = [input_data]
+
+    # We need a copy here and we want to mimic the actual code that returns a serialized version of the state
+    return cast(List[State], json_loads(json_dumps(input_data)))
+
+
+def run_workflow(workflow_key: str, input_data: Union[State, List[State]]) -> Tuple[WFProcess, ProcessStat, List]:
+    # ATTENTION!! This code needs to be as similar as possible to `server.services.processes.start_process`
+    # The main differences are: we use a different step log function and we don't run in
+    # a separate thread
+    user_data = _sanitize_input(input_data)
+    user = "john.doe"
+
+    step_log: List[Tuple[Step, WFProcess]] = []
+
+    pid = uuid4()
+    workflow = get_workflow(workflow_key)
+    assert workflow, "Workflow does not exist"
+    initial_state = {
+        "process_id": pid,
+        "reporter": user,
+        "workflow_name": workflow_key,
+        "workflow_target": workflow.target,
+    }
+
+    user_input = post_process(workflow.initial_input_form, initial_state, user_data)
+
+    pstat = ProcessStat(
+        pid, workflow=workflow, state=Success({**user_input, **initial_state}), log=workflow.steps, current_user=user
+    )
+
+    _db_create_process(pstat)
+
+    result = runwf(pstat, _store_step(step_log))
+
+    return result, pstat, step_log
+
+
+def resume_workflow(
+    process: ProcessStat, step_log: List[Tuple[Step, WFProcess]], input_data: State
+) -> Tuple[WFProcess, List]:
+    # ATTENTION!! This code needs to be as similar as possible to `server.services.processes.resume_process`
+    # The main differences are: we use a different step log function and we don't run in
+    # a separate thread
+    user_data = _sanitize_input(input_data)
+
+    persistent = list(filter(lambda p: not (p[1].isfailed() or p[1].issuspend() or p[1].iswaiting()), step_log))
+    nr_of_steps_done = len(persistent)
+    remaining_steps = process.workflow.steps[nr_of_steps_done:]
+
+    _, current_state = step_log[-1]
+
+    user_input = post_process(remaining_steps[0].form, current_state.unwrap(), user_data)
+    state = current_state.map(lambda state: StateMerger.merge(deepcopy(state), user_input))
+
+    updated_process = process.update(log=remaining_steps, state=state)
+    result = runwf(updated_process, _store_step(step_log))
+    return result, step_log
diff --git a/test/workflows/conftest.py b/test/workflows/conftest.py
new file mode 100644
index 000000000..6e6630890
--- /dev/null
+++ b/test/workflows/conftest.py
@@ -0,0 +1,27 @@
+import pytest
+from urllib3_mock import Responses
+
+from test.fixtures import router_subscription_factory, site_subscription_factory  # noqa
+
+
+@pytest.fixture(autouse=True)
+def responses():
+    responses_mock = Responses("requests.packages.urllib3")
+
+    def _find_request(call):
+        mock_url = responses_mock._find_match(call.request)
+        if not mock_url:
+            pytest.fail(f"Call not mocked: {call.request}")
+        return mock_url
+
+    def _to_tuple(url_mock):
+        return url_mock["url"], url_mock["method"], url_mock["match_querystring"]
+
+    with responses_mock:
+        yield responses_mock
+
+        mocked_urls = map(_to_tuple, responses_mock._urls)
+        used_urls = map(_to_tuple, map(_find_request, responses_mock.calls))
+        not_used = set(mocked_urls) - set(used_urls)
+        if not_used:
+            pytest.fail(f"Found unused responses mocks: {not_used}", pytrace=False)
diff --git a/test/workflows/iptrunks/__init__.py b/test/workflows/iptrunks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/workflows/iptrunks/test_create_iptrunks.py b/test/workflows/iptrunks/test_create_iptrunks.py
new file mode 100644
index 000000000..2a766eaf9
--- /dev/null
+++ b/test/workflows/iptrunks/test_create_iptrunks.py
@@ -0,0 +1,152 @@
+from unittest.mock import patch
+
+import pytest
+
+from gso.products import Iptrunk
+from gso.products.product_blocks import PhyPortCapacity
+from gso.products.product_blocks.iptrunk import IptrunkType
+from gso.schemas.enums import ProductType
+from gso.services.subscriptions import get_product_id_by_name
+from gso.workflows.utils import customer_selector
+from test.workflows import (
+    assert_aborted,
+    assert_complete,
+    assert_suspended,
+    extract_state,
+    resume_workflow,
+    run_workflow,
+)
+
+
+@pytest.fixture
+def input_form_wizard_data(router_subscription_factory, faker):
+    router_side_a = router_subscription_factory()
+    router_side_b = router_subscription_factory()
+
+    create_ip_trunk_step = {
+        "customer": getattr(customer_selector(), "8f0df561-ce9d-4d9c-89a8-7953d3ffc961"),
+        "geant_s_sid": faker.pystr(),
+        "iptrunk_type": IptrunkType.DARK_FIBER,
+        "iptrunk_description": faker.sentence(),
+        "iptrunk_speed": PhyPortCapacity.HUNDRED_GIGABIT_PER_SECOND,
+        "iptrunk_minimum_links": 5,
+    }
+    create_ip_trunk_side_a_step = {
+        "iptrunk_sideA_node_id": router_side_a,
+        "iptrunk_sideA_ae_iface": faker.pystr(),
+        "iptrunk_sideA_ae_geant_a_sid": faker.pystr(),
+        "iptrunk_sideA_ae_members": [faker.pystr() for _ in range(5)],
+        "iptrunk_sideA_ae_members_descriptions": [faker.sentence() for _ in range(5)],
+    }
+
+    create_ip_trunk_side_b_step = {
+        "iptrunk_sideB_node_id": router_side_b,
+        "iptrunk_sideB_ae_iface": faker.pystr(),
+        "iptrunk_sideB_ae_geant_a_sid": faker.pystr(),
+        "iptrunk_sideB_ae_members": [faker.pystr() for _ in range(5)],
+        "iptrunk_sideB_ae_members_descriptions": [faker.sentence() for _ in range(5)],
+    }
+
+    return [create_ip_trunk_step, create_ip_trunk_side_a_step, create_ip_trunk_side_b_step]
+
+
+def _user_accept_and_assert_suspended(process_stat, step_log, extra_data=None):
+    extra_data = extra_data or {}
+    result, step_log = resume_workflow(process_stat, step_log, extra_data)
+    assert_suspended(result)
+
+    return result, step_log
+
+
+@pytest.mark.workflow
+@patch("gso.workflows.iptrunk.create_iptrunk.provisioning_proxy.check_ip_trunk")
+@patch("gso.workflows.iptrunk.create_iptrunk.provisioning_proxy.provision_ip_trunk")
+@patch("gso.workflows.iptrunk.create_iptrunk.infoblox.allocate_v6_network")
+@patch("gso.workflows.iptrunk.create_iptrunk.infoblox.allocate_v4_network")
+def test_successful_iptrunk_creation_with_standard_resume_data(
+    mock_allocate_v4_network,
+    mock_allocate_v6_network,
+    mock_provision_ip_trunk,
+    mock_check_ip_trunk,
+    responses,
+    input_form_wizard_data,
+    faker,
+):
+    mock_allocate_v4_network.return_value = faker.ipv4_network()
+    mock_allocate_v6_network.return_value = faker.ipv6_network()
+    product_id = get_product_id_by_name(ProductType.IP_TRUNK)
+    initial_site_data = [{"product": product_id}, *input_form_wizard_data]
+    result, process_stat, step_log = run_workflow("create_iptrunk", initial_site_data)
+    assert_suspended(result)
+
+    standard_resume_data = {
+        "pp_run_results": {
+            "status": "ok",
+            "job_id": "random_job_id",
+            "output": "parsed_output",
+            "return_code": 0,
+        },
+        "confirm": "ACCEPTED",
+    }
+    for _ in range(5):
+        result, step_log = _user_accept_and_assert_suspended(process_stat, step_log, standard_resume_data)
+        result, step_log = _user_accept_and_assert_suspended(process_stat, step_log)
+
+    result, step_log = _user_accept_and_assert_suspended(process_stat, step_log, standard_resume_data)
+    result, step_log = resume_workflow(process_stat, step_log, {})
+    assert_complete(result)
+
+    state = extract_state(result)
+    subscription_id = state["subscription_id"]
+    subscription = Iptrunk.from_subscription(subscription_id)
+
+    assert "active" == subscription.status
+    assert subscription.description == f"IP trunk, geant_s_sid:{input_form_wizard_data[0]['geant_s_sid']}"
+
+    assert mock_provision_ip_trunk.call_count == 4
+    assert mock_check_ip_trunk.call_count == 2
+
+
+@pytest.mark.workflow
+@patch("gso.workflows.iptrunk.create_iptrunk.provisioning_proxy.check_ip_trunk")
+@patch("gso.workflows.iptrunk.create_iptrunk.provisioning_proxy.provision_ip_trunk")
+@patch("gso.workflows.iptrunk.create_iptrunk.infoblox.allocate_v6_network")
+@patch("gso.workflows.iptrunk.create_iptrunk.infoblox.allocate_v4_network")
+def test_iptrunk_creation_fails_when_lso_return_code_is_one(
+    mock_allocate_v4_network,
+    mock_allocate_v6_network,
+    mock_provision_ip_trunk,
+    mock_check_ip_trunk,
+    responses,
+    input_form_wizard_data,
+    faker,
+):
+    mock_allocate_v4_network.return_value = faker.ipv4_network()
+    mock_allocate_v6_network.return_value = faker.ipv6_network()
+    product_id = get_product_id_by_name(ProductType.IP_TRUNK)
+
+    initial_site_data = [{"product": product_id}, *input_form_wizard_data]
+    result, process_stat, step_log = run_workflow("create_iptrunk", initial_site_data)
+    assert_suspended(result)
+
+    standard_resume_data = {
+        "pp_run_results": {
+            "status": "ok",
+            "job_id": "random_job_id",
+            "output": "parsed_output",
+            "return_code": 1,
+        },
+        "confirm": "ACCEPTED",
+    }
+
+    attempts = 3
+    for _ in range(0, attempts - 1):
+        result, step_log = _user_accept_and_assert_suspended(process_stat, step_log, standard_resume_data)
+        result, step_log = _user_accept_and_assert_suspended(process_stat, step_log)
+
+    result, step_log = _user_accept_and_assert_suspended(process_stat, step_log, standard_resume_data)
+    result, step_log = resume_workflow(process_stat, step_log, {})
+    assert_aborted(result)
+
+    assert mock_provision_ip_trunk.call_count == attempts
+    assert mock_check_ip_trunk.call_count == 0
diff --git a/test/workflows/site/__init__.py b/test/workflows/site/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/workflows/site/test_create_site.py b/test/workflows/site/test_create_site.py
new file mode 100644
index 000000000..d313b7f4a
--- /dev/null
+++ b/test/workflows/site/test_create_site.py
@@ -0,0 +1,40 @@
+import pytest
+
+from gso.products.product_blocks.site import SiteTier
+from gso.products.product_types.site import Site
+from gso.schemas.enums import ProductType
+from gso.services.crm import get_customer_by_name
+from gso.services.subscriptions import get_product_id_by_name
+from test.workflows import assert_complete, extract_state, run_workflow
+
+
+@pytest.mark.workflow
+def test_create_site(responses, faker):
+    product_id = get_product_id_by_name(ProductType.SITE)
+    initial_site_data = [
+        {"product": product_id},
+        {
+            "site_name": faker.name(),
+            "site_city": faker.city(),
+            "site_country": faker.country(),
+            "site_country_code": faker.country_code(),
+            "site_latitude": "-74.0060",
+            "site_longitude": "40.7128",
+            "site_bgp_community_id": faker.pyint(),
+            "site_internal_id": faker.pyint(),
+            "site_tier": SiteTier.TIER1,
+            "site_ts_address": faker.ipv4(),
+            "customer": get_customer_by_name("GÉANT")["id"],
+        },
+    ]
+    result, process, step_log = run_workflow(workflow_key="create_site", input_data=initial_site_data)
+    assert_complete(result)
+
+    state = extract_state(result)
+    subscription_id = state["subscription_id"]
+    subscription = Site.from_subscription(subscription_id)
+    assert "active" == subscription.status
+    assert (
+        subscription.description
+        == f"Site in {initial_site_data[1]['site_city']}, {initial_site_data[1]['site_country']}"
+    )
-- 
GitLab