import contextlib
import json
import logging
import re

import paramiko

from inventory_provider.constants import JUNIPER_LOGGER_NAME


def neighbors(
        parsed_json_output,
        routing_instances=("IAS",),
        group_expression=r'^GEANT-IX(v6)?[\s-].*$'):
    """
    Yield the bgp neighbor elements found in a parsed configuration.

    Walks configuration -> routing-instances -> instance -> protocols
    -> bgp -> group -> neighbor and yields every neighbor element
    unchanged.

    :param parsed_json_output: parsed json configuration (a dict with a
        top-level "configuration" list)
    :param routing_instances: container of routing instance names to
        search; instances with any other name are skipped
    :param group_expression: regular expression applied with re.match
        to each bgp group name, groups that don't match are skipped.
        A falsy value (None, "") disables group filtering.
    :return: generator of neighbor elements
    """
    # NOTE: the default for routing_instances is a tuple rather than a
    # list so that no mutable object is shared between calls; it is only
    # ever used for membership tests.
    for config in parsed_json_output["configuration"]:
        for ri in config["routing-instances"]:
            for inst in ri["instance"]:
                if inst["name"]["data"] not in routing_instances:
                    continue
                for prot in inst["protocols"]:
                    for bgp in prot.get("bgp", []):
                        for g in bgp.get("group", []):
                            if group_expression and not re.match(
                                    group_expression,
                                    g["name"]["data"]):
                                continue
                            # tolerate a group with no "neighbor"
                            # element, in the same way missing
                            # "bgp"/"group" elements are tolerated above
                            for n in g.get("neighbor", []):
                                yield n


def interfaces(parsed_json_output):
    """
    Generate a summary dict for each logical interface in the output.

    :param parsed_json_output: dict containing an
        "interface-information" list, each element of which holds
        a "logical-interface" list of interface lists
    :return: generator of dicts with the keys "name",
        "status" (formatted as admin-status/oper-status),
        "description" and "type" (always "logical")
    """
    def _first_data(element, key):
        # every attribute is wrapped as [{"data": value}, ...]
        return element[key][0]["data"]

    for info in parsed_json_output["interface-information"]:
        for logical_group in info["logical-interface"]:
            for entry in logical_group:
                name = _first_data(entry, "name")
                admin = _first_data(entry, "admin-status")
                oper = _first_data(entry, "oper-status")
                description = _first_data(entry, "description")
                yield {
                    "name": name,
                    "status": "%s/%s" % (admin, oper),
                    "description": description,
                    "type": "logical"
                }


@contextlib.contextmanager
def ssh_connection(hostname, ssh_params):
    """
    Context manager yielding a connected paramiko SSHClient.

    The client is closed when the context exits.

    :param hostname: host to connect to
    :param ssh_params: dict with the keys "private-key" and
        "known-hosts"; both are file names joined onto the directory
        containing this module
    :return: yields the connected paramiko.SSHClient
    """
    # TODO: remove this relative path logic!!
    import os

    module_dir = os.path.dirname(__file__)
    private_key_path = os.path.join(module_dir, ssh_params["private-key"])
    known_hosts_path = os.path.join(module_dir, ssh_params["known-hosts"])

    # the private key is loaded as a DSS key
    private_key = paramiko.DSSKey.from_private_key_file(private_key_path)

    with paramiko.SSHClient() as client:
        client.load_host_keys(known_hosts_path)
        # the login user name is fixed
        client.connect(
            hostname=hostname,
            username="Monit0r",
            pkey=private_key)
        yield client


def ssh_exec_commands(hostname, ssh_params, commands):
    """
    Run commands on a host over ssh and yield the output of each one.

    A single connection is opened (see ssh_connection).  Before the
    requested commands, 'set cli screen-length 0' is executed and its
    output discarded.  This is a generator: each command is only
    executed when its result is requested.

    :param hostname: host to connect to
    :param ssh_params: ssh settings, passed through to ssh_connection
    :param commands: iterable of command strings
    :return: generator of the utf-8 decoded stdout of each command,
        in the same order as commands
    :raises AssertionError: if a command exits with a non-zero status
    """
    juniper_logger = logging.getLogger(JUNIPER_LOGGER_NAME)

    def _run(ssh, command):
        _, stdout, _ = ssh.exec_command(command)
        # Read all of stdout *before* asking for the exit status:
        # paramiko documents that recv_exit_status() can hang if the
        # remote output exceeds the channel window and hasn't been read.
        raw_output = stdout.read()
        exit_status = stdout.channel.recv_exit_status()
        # Explicit raise instead of an assert statement so the check
        # isn't stripped when running with -O.  AssertionError is kept
        # as the exception type for compatibility with existing callers.
        if exit_status != 0:
            raise AssertionError(
                "command '%s' on %s exited with status %d"
                % (command, hostname, exit_status))
        return raw_output.decode("utf-8")

    with ssh_connection(hostname, ssh_params) as ssh:

        _run(ssh, "set cli screen-length 0")

        for c in commands:
            juniper_logger.debug("command: '%s'", c)
            # TODO: error handling
            yield _run(ssh, c)


def _loads(s, **args):
    """
    Parse json text that contains raw (unescaped) backslashes.

    Every backslash in the text is doubled before parsing, so that it
    is read as a literal backslash character.

    :param s: json text
    :param args: keyword arguments passed through to json.loads
    :return: the parsed object
    """
    escaped = s.replace("\\", "\\\\")
    return json.loads(escaped, **args)


# CLI command sent as the first entry of every fetch_* command list below;
# as its name says, it is meant to turn off paging of the command output.
_DISABLE_PAGING_COMMAND = r'set cli screen-length 0'


def fetch_bgp_config(hostname, ssh_params, **args):
    """
    Fetch the bgp neighbors configured in the IAS routing instance.

    :param hostname: host to connect to
    :param ssh_params: ssh settings, passed through to ssh_exec_commands
    :param args: keyword arguments passed through to neighbors()
    :return: list of neighbor elements (see neighbors()); an empty list
        if the configuration command produced no output
    """
    commands = [
        _DISABLE_PAGING_COMMAND,
        ('show configuration routing-instances'
         ' IAS protocols bgp | display json')
    ]

    output = list(ssh_exec_commands(hostname, ssh_params, commands))
    assert len(output) == len(commands)

    # output[0] belongs to the paging command, output[1] is the config
    bgp_config_text = output[1]
    if not bgp_config_text:
        # return the same type as the non-empty case (a list, not a dict)
        return []
    return list(neighbors(_loads(bgp_config_text), **args))


def fetch_vrr_config(hostname, ssh_params):
    """
    Fetch the bgp configuration of the VRR logical system.

    :param hostname: host to connect to
    :param ssh_params: ssh settings, passed through to ssh_exec_commands
    :return: the parsed json configuration, or an empty dict if the
        configuration command produced no output
    """
    show_vrr_bgp = (
        'show configuration logical-systems '
        'VRR protocols bgp | display json')
    commands = [_DISABLE_PAGING_COMMAND, show_vrr_bgp]

    output = list(ssh_exec_commands(hostname, ssh_params, commands))
    assert len(output) == len(commands)

    # output[0] belongs to the paging command, output[1] is the config
    config_text = output[1]
    if not config_text:
        return {}
    return _loads(config_text)


def fetch_interfaces(hostname, ssh_params):
    """
    Fetch the interface descriptions of a host as parsed json.

    Json objects in the output may repeat a key: a key that occurs once
    keeps its single value, a key that occurs several times is mapped
    to the list of all its values, in order of appearance.

    :param hostname: host to connect to
    :param ssh_params: ssh settings, passed through to ssh_exec_commands
    :return: the parsed json output, or an empty dict if the command
        produced no output
    """

    def _dups_to_list(pairs):
        # object_pairs_hook: group the values of each key, keeping the
        # order in which the keys first appeared
        grouped = {}
        for key, value in pairs:
            if key in grouped:
                grouped[key].append(value)
            else:
                grouped[key] = [value]
        # unwrap the keys that occurred only once
        return {
            key: values[0] if len(values) == 1 else values
            for key, values in grouped.items()
        }

    show_descriptions = 'show interfaces descriptions | display json'
    commands = [_DISABLE_PAGING_COMMAND, show_descriptions]

    output = list(ssh_exec_commands(hostname, ssh_params, commands))
    assert len(output) == len(commands)

    # output[0] belongs to the paging command, output[1] is the data
    interfaces_text = output[1]
    if not interfaces_text:
        return {}
    return _loads(interfaces_text, object_pairs_hook=_dups_to_list)