import contextlib
import json
import logging
import re

import paramiko

from inventory_provider.constants import JUNIPER_LOGGER_NAME


def neighbors(
        parsed_json_output,
        routing_instances=("IAS",),
        group_expression=r'^GEANT-IX(v6)?[\s-].*$'):
    """
    Yield the bgp neighbor elements found in a parsed configuration.

    The input is walked as
    configuration -> routing-instances -> instance -> protocols
    -> bgp -> group -> neighbor.  Missing "bgp", "group" or
    "neighbor" elements are treated as empty.

    :param parsed_json_output: dict with a "configuration" list,
        as produced by parsing the router's json output
    :param routing_instances: container of routing instance names;
        instances whose name is not in it are skipped
    :param group_expression: regular expression applied with
        re.match to each group name, non-matching groups are skipped;
        a falsy value disables the group filter
    :return: generator of neighbor elements, unchanged from the input
    """
    # compile once instead of handing the raw expression to re.match
    # for every group
    group_pattern = \
        re.compile(group_expression) if group_expression else None

    for config in parsed_json_output["configuration"]:
        for ri in config["routing-instances"]:
            for inst in ri["instance"]:
                if inst["name"]["data"] not in routing_instances:
                    continue
                for prot in inst["protocols"]:
                    for bgp in prot.get("bgp", []):
                        for g in bgp.get("group", []):
                            if group_pattern and not \
                                    group_pattern.match(g["name"]["data"]):
                                continue
                            # a group may have no neighbors configured
                            for n in g.get("neighbor", []):
                                yield n


def interfaces(parsed_json_output):
    """
    Yield a summary dict for each logical interface in the input.

    Each yielded dict has the keys "name", "status" (formatted as
    admin-status/oper-status), "description" and "type" (always
    "logical").

    :param parsed_json_output: dict with an "interface-information"
        list; every element's "logical-interface" value is iterated
        as a list of lists of interface elements
    :return: generator of dicts
    """

    def _first_data(element, field):
        # each field is a list whose first item carries the value
        return element[field][0]["data"]

    for info in parsed_json_output["interface-information"]:
        for group in info["logical-interface"]:
            for logical in group:
                name = _first_data(logical, "name")
                status = "%s/%s" % (
                    _first_data(logical, "admin-status"),
                    _first_data(logical, "oper-status"))
                description = _first_data(logical, "description")
                yield {
                    "name": name,
                    "status": status,
                    "description": description,
                    "type": "logical"
                }


@contextlib.contextmanager
def ssh_connection(hostname, ssh_params):
    """
    Context manager yielding a connected paramiko SSHClient.

    The private key and known-hosts file names are taken from
    ssh_params and joined onto the directory containing this module.
    The client is closed when the context exits.

    :param hostname: host to connect to
    :param ssh_params: dict with "private-key" and "known-hosts"
        file names
    :return: yields the connected paramiko.SSHClient
    """
    import os

    def _module_path(param_name):
        # os.path.join of this module's directory and the configured name
        return os.path.join(
            os.path.dirname(__file__), ssh_params[param_name])

    private_key_path = _module_path("private-key")
    known_hosts_path = _module_path("known-hosts")
    private_key = paramiko.DSSKey.from_private_key_file(private_key_path)

    client = paramiko.SSHClient()
    with client:
        client.load_host_keys(known_hosts_path)
        client.connect(
            hostname=hostname,
            username="Monit0r",
            pkey=private_key)
        yield client


def ssh_exec_commands(hostname, ssh_params, commands):
    """
    Run commands on a host over ssh and yield the output of each.

    A connection is opened with ssh_connection, then
    "set cli screen-length 0" is executed, followed by each of the
    given commands in order.  This is a generator: nothing is executed
    until it is iterated, and the connection stays open until the
    generator is exhausted or closed.

    :param hostname: host to connect to
    :param ssh_params: connection parameters, passed to ssh_connection
    :param commands: iterable of command strings
    :return: generator of the utf-8 decoded stdout of each command
    :raises AssertionError: if a command exits with a non-zero status
    """
    juniper_logger = logging.getLogger(JUNIPER_LOGGER_NAME)

    def _run(ssh, command):
        _, stdout, _ = ssh.exec_command(command)
        # Drain stdout BEFORE waiting for the exit status: paramiko
        # documents that recv_exit_status can hang if it is called
        # before a large output has been received.
        output = stdout.read()
        exit_status = stdout.channel.recv_exit_status()
        if exit_status != 0:
            # Raised explicitly (not via an assert statement) so the
            # check is still performed when python runs with -O;
            # the exception type is the one the assert used to raise.
            # TODO: error handling
            raise AssertionError(
                "command '%s' exited with status %d"
                % (command, exit_status))
        return output

    with ssh_connection(hostname, ssh_params) as ssh:

        _run(ssh, "set cli screen-length 0")

        for c in commands:
            juniper_logger.debug("command: '%s'", c)
            yield _run(ssh, c).decode("utf-8")


def shell_commands():
    """
    Yield a description of each cli command and how to parse its output.

    Every yielded item is a dict with the keys:
      "command": the command string
      "key": a name for the result (None for the first command)
      "parser": a callable taking the command's output text

    :return: generator of dicts
    """

    def _dups_to_list(pairs):
        """
        object_pairs_hook for json decoding: a key that occurs
        once keeps its value, a repeated key gets the list of all
        its values
        :param pairs: iterable of (key, value) pairs
        :return: dict
        """
        grouped = {}
        for name, value in pairs:
            if name in grouped:
                grouped[name].append(value)
            else:
                grouped[name] = [value]
        return {
            name: values[0] if len(values) == 1 else values
            for name, values in grouped.items()
        }

    def _loads(s, **args):
        """
        the json text contains raw backslashes, so every backslash
        is doubled before decoding
        :param s: json text
        :param args: keyword arguments passed on to json.loads
        :return: the decoded object
        """
        escaped = s.replace("\\", "\\\\")
        return json.loads(escaped, **args)

    def _ignore_output(_):
        return None

    def _parse_bgp_output(txt, **args):
        if not txt:
            return {}
        return list(neighbors(_loads(txt), **args))

    def _parse_vrr_output(txt):
        if not txt:
            return {}
        return _loads(txt)

    def _parse_interfaces_output(txt):
        if not txt:
            return {}
        parsed = _loads(txt, object_pairs_hook=_dups_to_list)
        return list(interfaces(parsed))

    yield {
        "command": "set cli screen-length 0",
        "key": None,
        "parser": _ignore_output
    }

    yield {
        "command": 'show configuration routing-instances IAS protocols bgp | display json',
        "key": "bgp",
        "parser": _parse_bgp_output
    }

    yield {
        "command": 'show configuration logical-systems VRR protocols bgp | display json',
        "key": "vrr",
        "parser": _parse_vrr_output
    }

    yield {
        "command": 'show interfaces descriptions | display json',
        "key": "interfaces",
        "parser": _parse_interfaces_output
    }