import difflib
import pprint
from copy import deepcopy
from itertools import chain, repeat
from typing import Callable, cast
from uuid import uuid4

import structlog

from orchestrator.db import ProcessTable
from orchestrator.services.processes import StateMerger, _db_create_process
from orchestrator.types import FormGenerator, InputForm, State
from orchestrator.utils.json import json_dumps, json_loads
from orchestrator.workflow import Process as WFProcess
from orchestrator.workflow import ProcessStat, Step, Success, Workflow, runwf
from orchestrator.workflows import ALL_WORKFLOWS, LazyWorkflowInstance, get_workflow
from pydantic_forms.core import post_form

logger = structlog.get_logger(__name__)


def _raise_exception(state):
    """Re-raise *state* if it is an exception, otherwise return it unchanged.

    Used as the callback for ``on_failed`` / ``on_waiting`` so that an
    unexpected process status surfaces as a real exception in the test output.
    """
    if isinstance(state, Exception):
        raise state
    return state


def assert_success(result):
    """Assert that the workflow result is Success; raise embedded errors otherwise."""
    assert (
        result.on_failed(_raise_exception).on_waiting(_raise_exception).issuccess()
    ), f"Unexpected process status. Expected Success, but was: {result}"


def assert_waiting(result):
    """Assert that the workflow result is Waiting; raise embedded errors otherwise."""
    assert result.on_failed(
        _raise_exception
    ).iswaiting(), f"Unexpected process status. Expected Waiting, but was: {result}"


def assert_suspended(result):
    """Assert that the workflow result is Suspend; raise embedded errors otherwise."""
    assert result.on_failed(
        _raise_exception
    ).issuspend(), f"Unexpected process status. Expected Suspend, but was: {result}"


def assert_aborted(result):
    """Assert that the workflow result is Abort; raise embedded errors otherwise."""
    assert result.on_failed(_raise_exception).isabort(), f"Unexpected process status. Expected Abort, but was: {result}"


def assert_failed(result):
    """Assert that the workflow result is Failed (failure is the expected outcome here)."""
    assert result.isfailed(), f"Unexpected process status. Expected Failed, but was: {result}"


def assert_complete(result):
    """Assert that the workflow result is Complete; raise embedded errors otherwise."""
    assert result.on_failed(
        _raise_exception
    ).iscomplete(), f"Unexpected process status. Expected Complete, but was: {result}"


def assert_state(result, expected):
    """Assert that the unwrapped state contains *expected* as a subset.

    Only the keys present in ``expected`` are compared; extra state keys are ignored.
    A missing key raises ``KeyError``, which is a useful failure signal on its own.
    """
    state = result.unwrap()
    # Project the actual state down to the keys we care about (dict comprehension
    # instead of the former manual loop).
    actual = {key: state[key] for key in expected}
    assert expected == actual, f"Invalid state. Expected superset of: {expected}, but was: {actual}"


def assert_state_equal(result: ProcessTable, expected: dict, excluded_keys: list[str] | None = None) -> None:
    """Test state with certain keys excluded from both actual and expected state.

    Args:
        result: Process (table row or workflow result) whose state is unwrapped via `extract_state`.
        expected: The full expected state.
        excluded_keys: Keys ignored on both sides; defaults to the run-specific
            identifiers ("process_id", "workflow_target", "workflow_name").

    Raises:
        AssertionError: with an ndiff of the pretty-printed states on mismatch.
    """
    if excluded_keys is None:
        excluded_keys = ["process_id", "workflow_target", "workflow_name"]
    state = deepcopy(extract_state(result))
    expected_state = deepcopy(expected)
    for key in excluded_keys:
        # pop with default instead of the former membership-check-then-del.
        state.pop(key, None)
        expected_state.pop(key, None)
    assert state == expected_state, "Unexpected state:\n" + "\n".join(
        difflib.ndiff(pprint.pformat(state).splitlines(), pprint.pformat(expected_state).splitlines())
    )


def assert_assignee(log, expected):
    """Assert the assignee of the most recent step in the step log."""
    actual = log[-1][0].assignee
    assert expected == actual, f"Unexpected assignee. Expected {expected}, but was: {actual}"


def assert_step_name(log, expected):
    """Assert the name of the most recent step in the step log."""
    actual = log[-1][0]
    assert actual.name == expected, f"Unexpected name. Expected {expected}, but was: {actual}"


def extract_state(result):
    """Return the raw state dict from a workflow result."""
    return result.unwrap()


def extract_error(result):
    """Return the "error" value from a workflow result's state (or None).

    Also sanity-checks that *result* is a plain Process and not accidentally
    nested (a Process wrapped in a Process should never occur).
    """
    # WFProcess is the module-level alias for orchestrator.workflow.Process;
    # the former local re-import was redundant.
    assert isinstance(result, WFProcess), f"Expected a Process, but got {repr(result)} of type {type(result)}"
    assert not isinstance(result.s, WFProcess), "Result contained a Process in a Process, this should not happen"
    return extract_state(result).get("error")


class WorkflowInstanceForTests(LazyWorkflowInstance):
    """Register Test workflows.

    Similar to `LazyWorkflowInstance` but does not require an import during instantiate
    Used for creating test workflows
    """

    package: str
    function: str
    is_callable: bool

    def __init__(self, workflow: Workflow, name: str) -> None:
        # NOTE(review): deliberately does not call super().__init__() — this test
        # double carries the workflow object directly instead of an import path.
        self.workflow = workflow
        self.name = name

    def __enter__(self):
        # Register under the test name for the duration of the with-block.
        ALL_WORKFLOWS[self.name] = self
        # Returning self (previously None) lets callers write `with ... as wf:`;
        # backward-compatible for callers that ignore the value.
        return self

    def __exit__(self, _exc_type, _exc_value, _traceback):
        del ALL_WORKFLOWS[self.name]

    def instantiate(self) -> Workflow:
        """Import and instantiate a workflow and return it.

        This can be as simple as merely importing a workflow function. However,
        if it concerns a workflow generating function, that function will be
        called with or without arguments as specified.

        Returns:
            A workflow function.
        """
        self.workflow.name = self.name
        return self.workflow

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"WorkflowInstanceForTests('{self.workflow}','{self.name}')"


def _store_step(step_log: list[tuple[Step, WFProcess]]) -> Callable[[ProcessStat, Step, WFProcess], WFProcess]:
    """Return a step-log callback that appends each (step, state) pair to *step_log*.

    The state is round-tripped through JSON first to mimic production
    serialization; a non-serializable state is logged but not fatal.
    """

    def __store_step(pstat: ProcessStat, step: Step, state: WFProcess) -> WFProcess:
        try:
            state = state.map(lambda s: json_loads(json_dumps(s)))
        except Exception:
            logger.exception("Step state is not valid json", state=state)
        step_log.append((step, state))
        return state

    return __store_step


def _sanitize_input(input_data: State | list[State]) -> list[State]:
    """Normalize user input to a list of states, JSON round-tripped.

    To be backwards compatible a single dict is converted to a one-element list.
    The JSON round trip gives us a copy and mimics the actual code that returns
    a serialized version of the state.
    """
    if not isinstance(input_data, list):
        input_data = [input_data]
    return cast(list[State], json_loads(json_dumps(input_data)))


def run_workflow(workflow_key: str, input_data: State | list[State]) -> tuple[WFProcess, ProcessStat, list]:
    """Start the workflow registered under *workflow_key* and run it to its first stop.

    # ATTENTION!! This code needs to be as similar as possible to `server.services.processes.start_process`
    # The main differences are: we use a different step log function, and we don't run in
    # a separate thread

    Returns:
        (result, pstat, step_log) — final WFProcess state, the ProcessStat, and
        the list of (Step, WFProcess) pairs recorded while running.
    """
    user_data = _sanitize_input(input_data)
    user = "john.doe"

    step_log: list[tuple[Step, WFProcess]] = []

    process_id = uuid4()
    workflow = get_workflow(workflow_key)
    assert workflow, "Workflow does not exist"
    initial_state = {
        "process_id": process_id,
        "reporter": user,
        "workflow_name": workflow_key,
        "workflow_target": workflow.target,
    }

    user_input = post_form(workflow.initial_input_form, initial_state, user_data)

    pstat = ProcessStat(
        process_id,
        workflow=workflow,
        state=Success({**user_input, **initial_state}),
        log=workflow.steps,
        current_user=user,
    )

    _db_create_process(pstat)

    result = runwf(pstat, _store_step(step_log))

    return result, pstat, step_log


def resume_workflow(
    process: ProcessStat, step_log: list[tuple[Step, WFProcess]], input_data: State
) -> tuple[WFProcess, list]:
    """Resume a suspended/waiting workflow with new user input.

    # ATTENTION!! This code needs to be as similar as possible to `server.services.processes.resume_process`
    # The main differences are: we use a different step log function, and we don't run in a separate thread

    Returns:
        (result, step_log) — final WFProcess state and the (mutated) step log.
    """
    user_data = _sanitize_input(input_data)

    # Steps that actually completed; failed/suspended/waiting entries are re-run.
    persistent = list(filter(lambda p: not (p[1].isfailed() or p[1].issuspend() or p[1].iswaiting()), step_log))
    nr_of_steps_done = len(persistent)
    remaining_steps = process.workflow.steps[nr_of_steps_done:]

    # Make sure we get the last state from the suspend step (since we removed it before)
    if step_log and step_log[-1][1].issuspend():
        _, current_state = step_log[-1]
    elif persistent:
        _, current_state = persistent[-1]
    else:
        current_state = Success({})

    user_input = post_form(remaining_steps[0].form, current_state.unwrap(), user_data)

    # Lambda parameter renamed (was `state`, shadowing the result variable below).
    state = current_state.map(lambda s: StateMerger.merge(deepcopy(s), user_input))

    updated_process = process.update(log=remaining_steps, state=state)

    result = runwf(updated_process, _store_step(step_log))

    return result, step_log
def run_form_generator(
    form_generator: FormGenerator, extra_inputs: list[State] | None = None
) -> tuple[list[dict], State]:
    """Drive a form generator and collect every generated form plus its final result.

    Warning! This does not run the actual pydantic validation on purpose. However,
    you should make sure that anything in extra_inputs matched the values and types
    as if the pydantic validation has been run.

    Args:
    ----
        form_generator (FormGenerator): The form generator that will be run.
        extra_inputs (list[State] | None): list of user input dicts for each page in the generator.
            If no input is given for a page, an empty dict is used.
            The default value from the form is used as the default value for a field.

    Returns:
    -------
        tuple[list[dict], State]: A list of generated forms and the result state for the whole generator.

    Example:
    -------
    Given the following form generator:

    >>> from pydantic_forms.core import FormPage
    >>> def form_generator(state):
    ...     class TestForm(FormPage):
    ...         field: str = "foo"
    ...     user_input = yield TestForm
    ...     return {**user_input.dict(), "bar": 42}

    You can run this without extra_inputs

    >>> forms, result = run_form_generator(form_generator({"state_field": 1}))
    >>> forms
    [{'title': 'unknown', 'type': 'object', 'properties': {'field': {'title': 'Field', 'default': 'foo', 'type': 'string'}}, 'additionalProperties': False}]
    >>> result
    {'field': 'foo', 'bar': 42}

    Or with extra_inputs:

    >>> forms, result = run_form_generator(form_generator({'state_field': 1}), [{'field':'baz'}])
    >>> forms
    [{'title': 'unknown', 'type': 'object', 'properties': {'field': {'title': 'Field', 'default': 'foo', 'type': 'string'}}, 'additionalProperties': False}]
    >>> result
    {'field': 'baz', 'bar': 42}
    """
    forms: list[dict] = []
    # Sentinel result; replaced by the generator's return value on StopIteration.
    result: State = {"s": 3}

    # One input dict per page; pages beyond the supplied inputs get an empty dict.
    page_inputs = chain(extra_inputs or [], repeat(cast(State, {})))

    try:
        page = cast(InputForm, next(form_generator))
        forms.append(page.schema())
        for page_input in page_inputs:
            # Start from each field's declared default, then overlay the user's input.
            defaults = {name: field.default for name, field in page.__fields__.items()}
            submitted = page.construct(**{**defaults, **page_input})
            # Feed the (unvalidated) model back in; the generator yields the next page.
            page = form_generator.send(submitted)
            forms.append(page.schema())
    except StopIteration as stop:
        # The generator's `return` value arrives as StopIteration.value.
        result = stop.value

    return forms, result