Skip to content
Snippets Groups Projects
__init__.py 10.87 KiB
import difflib
import pprint
from collections.abc import Callable
from copy import deepcopy
from typing import cast
from uuid import uuid4

import structlog
from orchestrator.db import ProcessTable, WorkflowTable, db
from orchestrator.services.processes import StateMerger, _db_create_process
from orchestrator.types import State
from orchestrator.utils.json import json_dumps, json_loads
from orchestrator.workflow import Process, ProcessStat, Step, Success, Workflow, runwf
from orchestrator.workflow import Process as WFProcess
from orchestrator.workflows import ALL_WORKFLOWS, LazyWorkflowInstance, get_workflow
from pydantic_forms.core import post_form

from test import LSO_RESULT_FAILURE, LSO_RESULT_SUCCESS, USER_CONFIRM_EMPTY_FORM

# Module-level structlog logger; ``_store_step`` uses it to report step state that is not valid JSON.
logger = structlog.get_logger(__name__)


def _raise_exception(state):
    if isinstance(state, Exception):
        raise state
    return state


def assert_success(result):
    """Assert that ``result`` has the Success status.

    ``on_failed`` and ``on_waiting`` are applied with ``_raise_exception`` first, presumably so that a
    Failed/Waiting result carrying an exception re-raises it instead of failing with the generic message.
    """
    checked = result.on_failed(_raise_exception)
    checked = checked.on_waiting(_raise_exception)
    assert checked.issuccess(), f"Unexpected process status. Expected Success, but was: {result}"


def assert_waiting(result):
    """Assert that ``result`` has the Waiting status.

    ``on_failed(_raise_exception)`` is applied first, presumably so that a Failed result carrying an
    exception re-raises it instead of failing with the generic message.
    """
    checked = result.on_failed(_raise_exception)
    assert checked.iswaiting(), f"Unexpected process status. Expected Waiting, but was: {result}"


def assert_suspended(result):
    """Assert that ``result`` has the Suspend status (e.g. paused at an input step).

    ``on_failed(_raise_exception)`` is applied first, presumably so that a Failed result carrying an
    exception re-raises it instead of failing with the generic message.
    """
    checked = result.on_failed(_raise_exception)
    assert checked.issuspend(), f"Unexpected process status. Expected Suspend, but was: {result}"


def assert_awaiting_callback(result):
    """Assert that ``result`` has the Awaiting Callback status.

    ``on_failed(_raise_exception)`` is applied first, presumably so that a Failed result carrying an
    exception re-raises it instead of failing with the generic message.
    """
    checked = result.on_failed(_raise_exception)
    assert checked.isawaitingcallback(), f"Unexpected process status. Expected Awaiting Callback, but was: {result}"


def assert_aborted(result):
    """Assert that ``result`` has the Abort status.

    ``on_failed(_raise_exception)`` is applied first, presumably so that a Failed result carrying an
    exception re-raises it instead of failing with the generic message.
    """
    checked = result.on_failed(_raise_exception)
    assert checked.isabort(), f"Unexpected process status. Expected Abort, but was: {result}"


def assert_failed(result):
    """Assert that ``result`` has the Failed status.

    Unlike the other status assertions, no ``on_failed`` hook is applied: failure is the expected outcome.
    """
    is_failed = result.isfailed()
    assert is_failed, f"Unexpected process status. Expected Failed, but was: {result}"


def assert_complete(result):
    """Assert that ``result`` has the Complete status.

    ``on_failed(_raise_exception)`` is applied first, presumably so that a Failed result carrying an
    exception re-raises it instead of failing with the generic message.
    """
    checked = result.on_failed(_raise_exception)
    assert checked.iscomplete(), f"Unexpected process status. Expected Complete, but was: {result}"


def assert_state(result, expected):
    """Assert that the state of ``result`` contains every key/value pair in ``expected``.

    Keys present in the state but absent from ``expected`` are ignored, i.e. the state only has to be
    a superset of ``expected``.

    Args:
        result: A process result whose ``unwrap()`` returns the state mapping.
        expected: The key/value pairs that must be present in that state.

    Raises:
        AssertionError: If a key of ``expected`` is missing from the state or its value differs.
    """
    state = result.unwrap()
    missing = [key for key in expected if key not in state]
    # Report absent keys explicitly instead of letting ``state[key]`` escape as a bare KeyError.
    assert not missing, f"Invalid state. Missing keys: {missing}. Available keys: {list(state)}"
    actual = {key: state[key] for key in expected}
    assert expected == actual, f"Invalid state. Expected superset of: {expected}, but was: {actual}"


def assert_state_equal(result: Process, expected: dict, excluded_keys: list[str] | None = None) -> None:
    """Test state with certain keys excluded from both actual and expected state.

    Args:
        result: A process result; its state is read with ``extract_state`` (``result.unwrap()``).
        expected: The expected state.
        excluded_keys: Keys dropped from both the actual and the expected state before comparing.
            Defaults to ``["process_id", "workflow_target", "workflow_name"]``.

    Raises:
        AssertionError: If the filtered states differ. The message contains an ``ndiff`` of the
            pretty-printed states, where ``-`` lines come from the actual state and ``+`` lines
            from ``expected``.
    """
    if excluded_keys is None:
        excluded_keys = ["process_id", "workflow_target", "workflow_name"]
    excluded = set(excluded_keys)
    # Build filtered copies rather than deep-copying and deleting; neither input is mutated.
    state = {key: value for key, value in extract_state(result).items() if key not in excluded}
    expected_state = {key: value for key, value in expected.items() if key not in excluded}

    assert state == expected_state, "Unexpected state (- actual, + expected):\n" + "\n".join(
        difflib.ndiff(
            pprint.pformat(state).splitlines(),
            pprint.pformat(expected_state).splitlines(),
        ),
    )


def assert_assignee(log, expected):
    """Assert that the step of the most recent log entry has the expected assignee.

    ``log`` is a step log of ``(step, process)`` entries, e.g. as returned by ``run_workflow``.
    """
    latest_entry = log[-1]
    actual = latest_entry[0].assignee
    assert expected == actual, f"Unexpected assignee. Expected {expected}, but was: {actual}"


def assert_step_name(log, expected):
    """Assert that the step of the most recent log entry is named ``expected``.

    Args:
        log: A step log of ``(step, process)`` entries, e.g. as returned by ``run_workflow``.
        expected: The expected step name.
    """
    actual = log[-1][0].name
    # Report the step's name (not the whole Step object), consistent with ``assert_assignee``.
    assert actual == expected, f"Unexpected name. Expected {expected}, but was: {actual}"


def extract_state(result):
    """Return the state held by ``result``, i.e. the value of ``result.unwrap()``."""
    state = result.unwrap()
    return state


def extract_error(result):
    """Return the ``"error"`` entry of the state of ``result``, or ``None`` when it has none.

    Raises:
        AssertionError: If ``result`` is not a ``Process``, or its ``s`` attribute is itself a ``Process``.
    """
    assert isinstance(result, Process), f"Expected a Process, but got {result!r} of type {type(result)}"
    inner = result.s
    assert not isinstance(inner, Process), "Result contained a Process in a Process, this should not happen"

    state = extract_state(result)
    return state.get("error")


def store_workflow(wf: Workflow, name: str | None = None) -> WorkflowTable:
    """Insert a ``WorkflowTable`` row for ``wf``, commit, and return the row.

    Args:
        wf: The workflow whose ``target`` and ``description`` are stored.
        name: Name for the row; ``wf.name`` is used when this is falsy.
    """
    row_name = name if name else wf.name
    row = WorkflowTable(name=row_name, target=wf.target, description=wf.description)
    session = db.session
    session.add(row)
    session.commit()
    return row


def delete_workflow(wf: WorkflowTable) -> None:
    """Delete the given ``WorkflowTable`` row and commit the session."""
    session = db.session
    session.delete(wf)
    session.commit()


class WorkflowInstanceForTests(LazyWorkflowInstance):
    """Register Test workflows.

    Similar to `LazyWorkflowInstance` but does not require an import during instantiate.
    Used for creating test workflows: on entering a ``with`` block the instance is added to
    ``ALL_WORKFLOWS`` and a matching ``WorkflowTable`` row is stored; on exit both are removed.

    Usage::

        with WorkflowInstanceForTests(my_workflow, "my_test_workflow") as workflow_row:
            ...
    """

    # Presumably populated by ``LazyWorkflowInstance.__init__``; declared here for type checkers — confirm.
    package: str
    function: str
    is_callable: bool

    def __init__(self, workflow: Workflow, name: str) -> None:
        super().__init__("orchestrator.test", name)
        self.workflow = workflow
        self.name = name

    def __enter__(self):
        """Register the workflow in ``ALL_WORKFLOWS``, store it in the database and return the stored row."""
        ALL_WORKFLOWS[self.name] = self
        try:
            self.workflow_instance = store_workflow(self.workflow, name=self.name)
        except Exception:
            # ``__exit__`` is not called when ``__enter__`` raises, so undo the registration here;
            # otherwise this test workflow would leak into ``ALL_WORKFLOWS`` for later tests.
            ALL_WORKFLOWS.pop(self.name, None)
            raise
        return self.workflow_instance

    def __exit__(self, _exc_type, _exc_value, _traceback):
        """Unregister the workflow and delete its database row. Exceptions are not suppressed."""
        ALL_WORKFLOWS.pop(self.name, None)
        delete_workflow(self.workflow_instance)
        del self.workflow_instance

    def instantiate(self) -> Workflow:
        """Return the wrapped workflow, renamed to the name it was registered under.

        No import is performed: the ``Workflow`` given to ``__init__`` is returned after its
        ``name`` attribute has been set to ``self.name`` (the workflow object is modified in place).

        Returns
        -------
            The wrapped workflow.

        """
        self.workflow.name = self.name
        return self.workflow

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"WorkflowInstanceForTests('{self.workflow}','{self.name}')"


def _store_step(step_log: list[tuple[Step, Process]]) -> Callable[[ProcessStat, Step, Process], Process]:
    """Build a step-logging callback for ``runwf`` that records each step into ``step_log``.

    The callback JSON round-trips the process state, presumably to mimic the state being serialised.
    It then strips the ``__step_name_override`` / ``__remove_keys`` bookkeeping keys (and every key
    listed in ``__remove_keys``). Finally it either appends ``(step, process)`` to ``step_log`` or,
    when ``__replace_last_state`` was set, overwrites the last entry instead.
    """

    def __store_step(pstat: ProcessStat, step: Step, process: Process) -> Process:
        try:
            process = process.map(lambda s: json_loads(json_dumps(s)))
        except Exception:
            # Best effort: keep the un-serialised state and report the problem in the log.
            logger.exception("Step state is not valid json", process=process)

        state = process.unwrap()
        state.pop("__step_name_override", None)
        keys_to_drop = state.get("__remove_keys", [])
        for key in (*keys_to_drop, "__remove_keys"):
            state.pop(key, None)

        replace_last = state.pop("__replace_last_state", None)
        entry = (step, process)
        if replace_last:
            step_log[-1] = entry
        else:
            step_log.append(entry)
        return process

    return __store_step


def _sanitize_input(input_data: State | list[State]) -> list[State]:
    """Normalise form input to a JSON round-tripped list of states.

    A single state is wrapped in a list for backwards compatibility. The serialise/deserialise
    round trip makes a copy and mimics the serialised version of the state the actual code returns.
    """
    items = input_data if isinstance(input_data, list) else [input_data]
    round_tripped = json_loads(json_dumps(items))
    return cast(list[State], round_tripped)


def run_workflow(workflow_key: str, input_data: State | list[State]) -> tuple[WFProcess, ProcessStat, list]:
    """Start the workflow registered under ``workflow_key`` and run it in the current thread.

    Args:
        workflow_key: Key looked up with ``get_workflow``; also stored as ``workflow_name`` in the state.
        input_data: Input for the workflow's initial form, either a single state or a list of states.
            It is normalised by ``_sanitize_input`` before being passed to ``post_form``.

    Returns:
        A tuple ``(result, pstat, step_log)``: the process returned by ``runwf``, the ``ProcessStat`` the
        run started from (this is what ``resume_workflow`` takes), and the ``(step, process)`` entries
        recorded for each executed step.

    Raises:
        AssertionError: If ``get_workflow`` returns a falsy value for ``workflow_key``.
    """
    # ATTENTION!! This code needs to be as similar as possible to `server.services.processes.start_process`
    # The main differences are: we use a different step log function, and we don't run in
    # a separate thread
    user_data = _sanitize_input(input_data)
    # Fixed test user, used as both the reporter and the current user of the process.
    user = "john.doe"

    step_log: list[tuple[Step, WFProcess]] = []

    process_id = uuid4()
    workflow = get_workflow(workflow_key)
    assert workflow, "Workflow does not exist"
    initial_state = {
        "process_id": process_id,
        "reporter": user,
        "workflow_name": workflow_key,
        "workflow_target": workflow.target,
    }

    # Evaluate the initial input form against ``initial_state`` with ``user_data`` as the submitted input.
    user_input = post_form(workflow.initial_input_form, initial_state, user_data)

    pstat = ProcessStat(
        process_id,
        workflow=workflow,
        # ``initial_state`` is unpacked last, so its keys win over identically named form fields.
        state=Success({**user_input, **initial_state}),
        log=workflow.steps,
        current_user=user,
    )

    # Presumably creates the process record in the database, as the real start_process does — confirm.
    _db_create_process(pstat)

    # Each executed step is recorded into ``step_log`` by the ``_store_step`` callback.
    result = runwf(pstat, _store_step(step_log))

    return result, pstat, step_log


def resume_workflow(
    process: ProcessStat,
    step_log: list[tuple[Step, Process]],
    input_data: State | list[State],
) -> tuple[Process, list]:
    """Resume a stopped process with ``input_data`` and run it in the current thread.

    Args:
        process: The ``ProcessStat`` returned by ``run_workflow``.
        step_log: The step log recorded so far; new steps are recorded into this same list.
        input_data: When the last logged entry is awaiting a callback, the callback payload, which is
            merged into the state as-is. Otherwise, the input submitted to the form of the next step.

    Returns:
        A tuple ``(result, step_log)``: the process returned by ``runwf`` and the (same) step log list.
    """
    # ATTENTION!! This code needs to be as similar as possible to ``server.services.processes.resume_process``
    # The main differences are: we use a different step log function, and we don't run in a separate thread
    # Entries that are failed, suspended, waiting or awaiting a callback do not count as completed steps.
    persistent = list(
        filter(
            lambda p: not (p[1].isfailed() or p[1].issuspend() or p[1].iswaiting() or p[1].isawaitingcallback()),
            step_log,
        ),
    )
    nr_of_steps_done = len(persistent)
    # Continue with the first workflow step that has not completed yet.
    remaining_steps = process.workflow.steps[nr_of_steps_done:]

    # Resume from the suspended / awaiting-callback entry if the process stopped there, otherwise from the
    # last completed step; with no completed steps at all, start from an empty state.
    if step_log and step_log[-1][1].issuspend():  # noqa: SIM114
        _, current_state = step_log[-1]
    elif step_log and step_log[-1][1].isawaitingcallback():
        _, current_state = step_log[-1]
    elif persistent:
        _, current_state = persistent[-1]
    else:
        current_state = Success({})

    if step_log and step_log[-1][1].isawaitingcallback():
        # Data is given as input by the external system, not a form.
        user_input = input_data
    else:
        # NOTE(review): unlike ``run_workflow``, ``input_data`` is not passed through ``_sanitize_input`` here,
        # so a single state dict reaches ``post_form`` unwrapped — confirm callers always pass a list.
        # NOTE(review): raises IndexError if ``remaining_steps`` is empty (all steps already completed).
        user_input = post_form(remaining_steps[0].form, current_state.unwrap(), input_data)
    # Merge into a deep copy, presumably so the state stored in ``step_log`` is left untouched.
    state = current_state.map(lambda state: StateMerger.merge(deepcopy(state), user_input))

    updated_process = process.update(log=remaining_steps, state=state)
    result = runwf(updated_process, _store_step(step_log))
    return result, step_log


def user_accept_and_assert_suspended(process_stat, step_log, extra_data=None):
    """Resume the process with ``extra_data`` (``{}`` when falsy) and assert it is suspended again.

    Returns:
        The ``(result, step_log)`` pair produced by ``resume_workflow``.
    """
    form_input = extra_data if extra_data else {}
    outcome = resume_workflow(process_stat, step_log, form_input)
    result, step_log = outcome
    assert_suspended(result)

    return result, step_log


def assert_lso_success(result: Process, process_stat: ProcessStat, step_log: list):
    """Assert a successful LSO execution in a workflow.

    Checks that ``result`` is awaiting a callback, then resumes the process with ``LSO_RESULT_SUCCESS``
    as the callback payload.

    Returns:
        The ``(result, step_log)`` pair produced by ``resume_workflow``.
    """
    assert_awaiting_callback(result)
    resumed = resume_workflow(process_stat, step_log, input_data=LSO_RESULT_SUCCESS)
    return resumed


def assert_lso_interaction_success(result: Process, process_stat: ProcessStat, step_log: list):
    """Assert a successful LSO interaction in a workflow.

    First, the workflow is awaiting callback. It is resumed by a successful result from LSO, after which the user
    submits the confirmation input step. Two assertions are made: the workflow is awaiting callback at first, and
    suspended when waiting for the user to confirm the results received.

    Returns:
        The ``(result, step_log)`` pair from resuming the process with ``USER_CONFIRM_EMPTY_FORM``.
    """
    # ``assert_lso_success`` performs the awaiting-callback assertion itself before resuming with the LSO result,
    # so it is not repeated here.
    result, step_log = assert_lso_success(result, process_stat, step_log)
    assert_suspended(result)

    return resume_workflow(process_stat, step_log, input_data=USER_CONFIRM_EMPTY_FORM)


def assert_lso_interaction_failure(result: Process, process_stat: ProcessStat, step_log: list):
    """Assert a failed LSO interaction in a workflow.

    ``result`` must be awaiting a callback. The process is then resumed with ``LSO_RESULT_FAILURE`` as the
    callback payload, and the resulting process must be in the failed state. That failed result is returned
    so the caller can inspect it further.

    Returns:
        The ``(result, step_log)`` pair after resuming with the failure payload.
    """
    assert_awaiting_callback(result)
    failed_result, updated_log = resume_workflow(process_stat, step_log, input_data=LSO_RESULT_FAILURE)
    assert_failed(failed_result)

    return failed_result, updated_log