import difflib
import pprint
from copy import deepcopy
from itertools import chain, repeat
from typing import Callable, cast
from uuid import uuid4

import structlog
from orchestrator.db import ProcessTable
from orchestrator.services.processes import StateMerger, _db_create_process
from orchestrator.types import FormGenerator, InputForm, State
from orchestrator.utils.json import json_dumps, json_loads
from orchestrator.workflow import Process as WFProcess
from orchestrator.workflow import ProcessStat, Step, Success, Workflow, runwf
from orchestrator.workflows import ALL_WORKFLOWS, LazyWorkflowInstance, get_workflow
from pydantic_forms.core import post_form

# Module-level structlog logger; `_store_step` uses it to report step states that fail the JSON round trip.
logger = structlog.get_logger(__name__)


def _raise_exception(state):
    """Raise *state* when it is an exception instance; otherwise return it untouched.

    Used as the callback for ``on_failed``/``on_waiting`` in the status assertions of this module.
    """
    if not isinstance(state, Exception):
        return state
    raise state


def assert_success(result):
    """Assert that *result* is a Success process.

    ``_raise_exception`` is handed to ``on_failed`` and ``on_waiting`` so that (presumably) an exception held in the
    state of a Failed or Waiting result is raised directly instead of being reported as a status mismatch.
    """
    checked = result.on_failed(_raise_exception).on_waiting(_raise_exception)
    assert checked.issuccess(), f"Unexpected process status. Expected Success, but was: {result}"


def assert_waiting(result):
    """Assert that *result* is a Waiting process.

    A Failed result is first passed through ``on_failed(_raise_exception)``, so (presumably) an exception stored in
    its state surfaces directly instead of a generic status mismatch.
    """
    surfaced = result.on_failed(_raise_exception)
    assert surfaced.iswaiting(), f"Unexpected process status. Expected Waiting, but was: {result}"


def assert_suspended(result):
    """Assert that *result* is a Suspend process.

    A Failed result is first passed through ``on_failed(_raise_exception)``, so (presumably) an exception stored in
    its state surfaces directly instead of a generic status mismatch.
    """
    surfaced = result.on_failed(_raise_exception)
    assert surfaced.issuspend(), f"Unexpected process status. Expected Suspend, but was: {result}"


def assert_aborted(result):
    """Assert that *result* is an Abort process.

    A Failed result is first passed through ``on_failed(_raise_exception)``, so (presumably) an exception stored in
    its state surfaces directly instead of a generic status mismatch.
    """
    surfaced = result.on_failed(_raise_exception)
    assert surfaced.isabort(), f"Unexpected process status. Expected Abort, but was: {result}"


def assert_failed(result):
    """Assert that *result* is a Failed process.

    Unlike the other status assertions, no ``on_failed`` hook is applied: a failure is the expected outcome here.
    """
    is_failed = result.isfailed()
    assert is_failed, f"Unexpected process status. Expected Failed, but was: {result}"


def assert_complete(result):
    """Assert that *result* is a Complete process.

    A Failed result is first passed through ``on_failed(_raise_exception)``, so (presumably) an exception stored in
    its state surfaces directly instead of a generic status mismatch.
    """
    surfaced = result.on_failed(_raise_exception)
    assert surfaced.iscomplete(), f"Unexpected process status. Expected Complete, but was: {result}"


def assert_state(result, expected):
    """Assert that the state of *result* contains every key/value pair of *expected*.

    Keys present in the actual state but not in *expected* are ignored: the state must be a superset.

    Raises:
        AssertionError: if a key of *expected* is absent from the state, or if any of its values differ.
    """
    state = result.unwrap()
    # A missing key means the state is not a superset: report it as an assertion failure instead of letting
    # `state[key]` escape as a bare KeyError without context.
    missing = [key for key in expected if key not in state]
    assert not missing, f"Invalid state. Expected superset of: {expected}, but keys {missing} are missing from state"
    actual = {key: state[key] for key in expected}
    assert expected == actual, f"Invalid state. Expected superset of: {expected}, but was: {actual}"


def assert_state_equal(result: WFProcess, expected: dict, excluded_keys: list[str] | None = None) -> None:
    """Test state with certain keys excluded from both actual and expected state.

    Args:
        result: The process returned by ``run_workflow``/``resume_workflow``; its state is read with
            ``extract_state`` (i.e. ``result.unwrap()``).
        expected: The expected state.
        excluded_keys: Keys dropped from both the actual and the expected state before comparing. Defaults to
            ``["process_id", "workflow_target", "workflow_name"]``, the keys ``run_workflow`` puts into the
            initial state itself.

    Raises:
        AssertionError: when the states differ. The message holds an ``ndiff`` of the pretty-printed states, where
            ``-`` lines come from the actual state and ``+`` lines from the expected state.
    """
    if excluded_keys is None:
        excluded_keys = ["process_id", "workflow_target", "workflow_name"]
    ignored = set(excluded_keys)
    # Build filtered dicts instead of deep-copying and deleting: neither the caller's `expected` nor the process
    # state is mutated, and the values do not need to be deep-copyable.
    state = {key: value for key, value in extract_state(result).items() if key not in ignored}
    expected_state = {key: value for key, value in expected.items() if key not in ignored}

    assert state == expected_state, "Unexpected state:\n" + "\n".join(
        difflib.ndiff(pprint.pformat(state).splitlines(), pprint.pformat(expected_state).splitlines())
    )


def assert_assignee(log, expected):
    """Assert that the most recent step in *log* is assigned to *expected*.

    *log* is a step log as built by ``run_workflow``: a list of ``(step, state)`` tuples.
    """
    last_step = log[-1][0]
    current_assignee = last_step.assignee
    assert expected == current_assignee, f"Unexpected assignee. Expected {expected}, but was: {current_assignee}"


def assert_step_name(log, expected):
    """Assert that the most recent step in *log* has the name *expected*.

    The failure message shows the whole step object, not just its name.
    """
    latest_entry = log[-1]
    latest_step = latest_entry[0]
    assert latest_step.name == expected, f"Unexpected name. Expected {expected}, but was: {latest_step}"


def extract_state(result):
    """Return the state carried by a workflow process result (``result.unwrap()``)."""
    unwrapped = result.unwrap()
    return unwrapped


def extract_error(result):
    """Return the ``error`` entry of a process result's state, or ``None`` when the state has no such key.

    Raises:
        AssertionError: if *result* is not a process, or if it wraps another process as its state.
    """
    # `WFProcess` is `orchestrator.workflow.Process`, already imported at module level; no local re-import needed.
    assert isinstance(result, WFProcess), f"Expected a Process, but got {result!r} of type {type(result)}"
    assert not isinstance(result.s, WFProcess), "Result contained a Process in a Process, this should not happen"

    return extract_state(result).get("error")


class WorkflowInstanceForTests(LazyWorkflowInstance):
    """Register Test workflows.

    Similar to `LazyWorkflowInstance` but does not require an import during instantiate
    Used for creating test workflows. Use it as a context manager; the workflow is registered in ``ALL_WORKFLOWS``
    under ``name`` for the duration of the ``with`` block::

        with WorkflowInstanceForTests(my_workflow, "my_workflow_name"):
            result, process, step_log = run_workflow("my_workflow_name", input_data)

    On exit the registration is undone. If another workflow was registered under the same name before entering,
    it is put back instead of being dropped. This also makes nested blocks with the same name unwind correctly.
    """

    # Declared but never set here; presumably kept to mirror the base-class attribute annotations — confirm.
    package: str
    function: str
    is_callable: bool

    def __init__(self, workflow: Workflow, name: str) -> None:
        # The base-class initializer is deliberately not called (presumably because there is nothing to import).
        self.workflow = workflow
        self.name = name
        # Entry registered under `name` before `__enter__`; restored by `__exit__`.
        self._previous: LazyWorkflowInstance | None = None

    def __enter__(self) -> "WorkflowInstanceForTests":
        self._previous = ALL_WORKFLOWS.get(self.name)
        ALL_WORKFLOWS[self.name] = self
        return self

    def __exit__(self, _exc_type, _exc_value, _traceback) -> None:
        previous, self._previous = self._previous, None
        if previous is None:
            # `pop` with a default so an entry already removed inside the block does not raise KeyError.
            ALL_WORKFLOWS.pop(self.name, None)
        else:
            ALL_WORKFLOWS[self.name] = previous

    def instantiate(self) -> Workflow:
        """Return the wrapped workflow, named after this registration.

        Nothing is imported: the workflow object passed to ``__init__`` gets its ``name`` set to ``self.name`` and
        is returned as-is.

        Returns: A workflow function.
        """
        self.workflow.name = self.name
        return self.workflow

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"WorkflowInstanceForTests('{self.workflow}','{self.name}')"


def _store_step(step_log: list[tuple[Step, WFProcess]]) -> Callable[[ProcessStat, Step, WFProcess], WFProcess]:
    """Create a step-log callback for ``runwf`` that appends every executed step to *step_log*.

    The callback round-trips the step's state through ``json_dumps``/``json_loads`` and records the result. If that
    fails, the error is logged and the original state is recorded (and returned) instead.
    """

    def _record_step(pstat: ProcessStat, step: Step, state: WFProcess) -> WFProcess:
        recorded = state
        try:
            recorded = state.map(lambda payload: json_loads(json_dumps(payload)))
        except Exception:
            # Best effort: keep the unconverted state, but make the serialization problem visible in the logs.
            logger.exception("Step state is not valid json", state=state)
        step_log.append((step, recorded))
        return recorded

    return _record_step


def _sanitize_input(input_data: State | list[State]) -> list[State]:
    """Normalise user input to a list of page dicts, copied through a JSON round trip.

    A single dict is wrapped in a list for backwards compatibility. Serialising and parsing it back yields an
    independent copy and mimics the serialized state the real code works with.
    """
    pages = input_data if isinstance(input_data, list) else [input_data]
    return cast(list[State], json_loads(json_dumps(pages)))


def run_workflow(
    workflow_key: str, input_data: State | list[State]
) -> tuple[WFProcess, ProcessStat, list[tuple[Step, WFProcess]]]:
    """Start a registered workflow and run it synchronously, recording every executed step.

    Args:
        workflow_key: Name under which the workflow is registered; resolved with ``get_workflow``.
        input_data: User input for the workflow's initial input form: one dict, or a list with one dict per page.

    Returns:
        ``(result, pstat, step_log)``: the process returned by ``runwf``, the ``ProcessStat`` the run was started
        with, and the ``(step, state)`` tuples recorded for each executed step. ``pstat`` and ``step_log`` are the
        arguments ``resume_workflow`` expects to continue the process.

    Raises:
        AssertionError: if no workflow is registered under ``workflow_key``.
    """
    # ATTENTION!! This code needs to be as similar as possible to `server.services.processes.start_process`
    # The main differences are: we use a different step log function, and we don't run in
    # a separate thread
    # NOTE(review): the module path above looks outdated; this file imports from `orchestrator.services.processes`,
    # which is presumably where `start_process` lives now — confirm.
    user_data = _sanitize_input(input_data)
    # Fixed test user, recorded as the process reporter and as the current user.
    user = "john.doe"

    step_log: list[tuple[Step, WFProcess]] = []

    process_id = uuid4()
    workflow = get_workflow(workflow_key)
    assert workflow, "Workflow does not exist"
    # Fields every process starts with. In the merged state below these win over form input with the same key.
    initial_state = {
        "process_id": process_id,
        "reporter": user,
        "workflow_name": workflow_key,
        "workflow_target": workflow.target,
    }

    # Fill the workflow's initial input form with `user_data`; `initial_state` is passed as the form's state.
    user_input = post_form(workflow.initial_input_form, initial_state, user_data)

    pstat = ProcessStat(
        process_id,
        workflow=workflow,
        state=Success({**user_input, **initial_state}),
        # Steps still to be executed: all of them for a freshly started process.
        log=workflow.steps,
        current_user=user,
    )

    # Create the process record (presumably requires a database session in the test setup — confirm).
    _db_create_process(pstat)

    # Run synchronously; the `_store_step` callback appends each executed (step, state) to `step_log`.
    result = runwf(pstat, _store_step(step_log))

    return result, pstat, step_log


def resume_workflow(
    process: ProcessStat, step_log: list[tuple[Step, WFProcess]], input_data: State | list[State]
) -> tuple[WFProcess, list[tuple[Step, WFProcess]]]:
    """Resume a process started with ``run_workflow``, submitting ``input_data`` to the next step's form.

    Args:
        process: The ``ProcessStat`` returned by ``run_workflow``.
        step_log: The step log returned by ``run_workflow`` (or by a previous ``resume_workflow``). It is extended
            in place with the steps executed during this call.
        input_data: User input for the form of the first remaining step: one dict, or a list with one dict per page.

    Returns:
        ``(result, step_log)``: the process returned by ``runwf`` and the same, now extended, ``step_log`` list.
    """
    # ATTENTION!! This code needs to be as similar as possible to `server.services.processes.resume_process`
    # The main differences are: we use a different step log function, and we don't run in a separate thread
    user_data = _sanitize_input(input_data)

    # Entries whose state is not failed, suspended or waiting count as completed steps.
    persistent = list(filter(lambda p: not (p[1].isfailed() or p[1].issuspend() or p[1].iswaiting()), step_log))
    nr_of_steps_done = len(persistent)
    # Assumes the completed entries line up, in order, with `process.workflow.steps` — true when the log came from
    # run_workflow/resume_workflow for this process; confirm for other callers.
    remaining_steps = process.workflow.steps[nr_of_steps_done:]

    # Make sure we get the last state from the suspend step (since we removed it before)
    if step_log and step_log[-1][1].issuspend():
        _, current_state = step_log[-1]
    elif persistent:
        _, current_state = persistent[-1]
    else:
        current_state = Success({})

    # The first remaining step is the one whose form receives the user input; this indexing presumably fails
    # (IndexError) if no steps remain.
    user_input = post_form(remaining_steps[0].form, current_state.unwrap(), user_data)
    # Merge the submitted input into a copy of the current state so the logged state object is not mutated.
    state = current_state.map(lambda state: StateMerger.merge(deepcopy(state), user_input))

    # Continue with only the remaining steps, starting from the merged state.
    updated_process = process.update(log=remaining_steps, state=state)
    result = runwf(updated_process, _store_step(step_log))
    return result, step_log


def run_form_generator(
    form_generator: FormGenerator, extra_inputs: list[State] | None = None
) -> tuple[list[dict], State]:
    """Run a form generator to get the resulting forms and result.

    Warning! This does not run the actual pydantic validation on purpose. However, you should
    make sure that anything in extra_inputs matches the values and types as if the pydantic validation had
    been run.

    Args:
    ----
    form_generator (FormGenerator): The form generator that will be run.
    extra_inputs (list[State] | None): list of user input dicts for each page in the generator.
                                         If no input is given for a page, an empty dict is used.
                                         The default value from the form is used as the default value for a field.

    Returns:
    -------
        tuple[list[dict], State]: A list of generated forms and the result state for the whole generator.

    Example:
    -------
        Given the following form generator:

        >>> from pydantic_forms.core import FormPage
        >>> def form_generator(state):
        ...     class TestForm(FormPage):
        ...         field: str = "foo"
        ...     user_input = yield TestForm
        ...     return {**user_input.dict(), "bar": 42}

        You can run this without extra_inputs
        >>> forms, result = run_form_generator(form_generator({"state_field": 1}))
        >>> forms
        [{'title': 'unknown', 'type': 'object', 'properties': {
            'field': {'title': 'Field', 'default': 'foo', 'type': 'string'}}, 'additionalProperties': False}]
        >>> result
        {'field': 'foo', 'bar': 42}


        Or with extra_inputs:
        >>> forms, result = run_form_generator(form_generator({'state_field': 1}), [{'field':'baz'}])
        >>> forms
        [{'title': 'unknown', 'type': 'object', 'properties': {
            'field': {'title': 'Field', 'default': 'foo', 'type': 'string'}}, 'additionalProperties': False}]
        >>> result
        {'field': 'baz', 'bar': 42}

    """
    forms: list[dict] = []
    # Always overwritten: the page loop below only ends when the generator returns (StopIteration) or raises, so
    # this initial value is never returned. It only gives `result` a well-typed initial binding.
    result: State = {}
    if extra_inputs is None:
        extra_inputs = []

    try:
        form = cast(InputForm, next(form_generator))
        forms.append(form.schema())
        # Once the explicit inputs are used up, keep submitting empty dicts so every remaining page falls back to
        # its field defaults.
        for extra_input in chain(extra_inputs, repeat(cast(State, {}))):
            # Start from each field's default, then overlay the caller's input for this page.
            user_input_data = {field_name: field.default for field_name, field in form.__fields__.items()}
            user_input_data.update(extra_input)
            # `construct` builds the form model without validation (see the warning above).
            user_input = form.construct(**user_input_data)
            form = form_generator.send(user_input)
            forms.append(form.schema())
    except StopIteration as stop:
        # The generator's `return` value travels on StopIteration; it is the result state of the whole form.
        result = stop.value

    return forms, result