import contextlib
import json
import logging
import re
from multiprocessing import Pool, Process, Queue

import click
import jsonschema
import mysql.connector
import paramiko
from pysnmp.hlapi import nextCmd, SnmpEngine, CommunityData, \
        UdpTransportTarget, ContextData, ObjectType, ObjectIdentity


# Names of the loggers used by the different parts of this script; their
# levels are configured independently in the ``__main__`` block.
SNMP_LOGGER_NAME = "snmp-logger"
THREADING_LOGGER_NAME = "threading-logger"
JUNIPER_LOGGER_NAME = "juniper-logger"
DATABASE_LOGGER_NAME = "database-logger"

# JSON schema for the --config file, checked by _validate_config().
# "ssh" is required because get_router_details() reads params["ssh"] for
# every router: without it an incomplete config would pass validation and
# only fail later, inside a worker process.
CONFIG_SCHEMA = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "type": "object",
    "properties": {
        "alarms-db": {
            "type": "object",
            "properties": {
                "hostname": {"type": "string"},
                "dbname": {"type": "string"},
                "username": {"type": "string"},
                "password": {"type": "string"}
            },
            "required": ["hostname", "dbname", "username", "password"],
            "additionalProperties": False
        },
        "oid_list.conf": {"type": "string"},
        "routers_community.conf": {"type": "string"},
        "ssh": {
            "type": "object",
            "properties": {
                "private-key": {"type": "string"},
                "known-hosts": {"type": "string"}
            },
            "required": ["private-key", "known-hosts"],
            "additionalProperties": False
        }
    },
    "required": [
        "alarms-db", "oid_list.conf", "routers_community.conf", "ssh"],
    "additionalProperties": False
}


def walk(agent_hostname, community, base_oid):
    """
    Walk the SNMP subtree rooted at ``base_oid`` on a single agent.

    Sends GETNEXT requests to UDP port 161 of ``agent_hostname`` using
    community-string authentication and stops at the end of the
    ``base_oid`` subtree (``lexicographicMode=False``).

    References:
    https://stackoverflow.com/a/45001921
    http://snmplabs.com/pysnmp/docs/hlapi/asyncore/sync/manager/cmdgen/nextcmd.html
    http://snmplabs.com/pysnmp/faq/pass-custom-mib-to-manager.html
    https://github.com/etingof/pysnmp/blob/master/examples/v3arch/asyncore/manager/cmdgen/getnext-multiple-oids-and-resolve-with-mib.py
    http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html

    :param agent_hostname: hostname or address of the SNMP agent
    :param community: SNMP community string
    :param base_oid: OID string at the root of the subtree to walk
    :return: generator of ``{"oid": "." + <oid>, "value": <pretty-printed value>}``
    :raises RuntimeError: if the local engine or the agent reports an error
    """
    snmp_logger = logging.getLogger(SNMP_LOGGER_NAME)
    snmp_logger.debug("walking %s: %s", agent_hostname, base_oid)

    for (error_indication,
         error_status,
         error_index,
         var_binds) in nextCmd(
            SnmpEngine(),
            CommunityData(community),
            UdpTransportTarget((agent_hostname, 161)),
            ContextData(),
            ObjectType(ObjectIdentity(base_oid)),
            lexicographicMode=False,
            lookupNames=True,
            lookupValues=True):
        # Explicit checks rather than asserts, which are stripped under -O.
        if error_indication:
            raise RuntimeError(
                "snmp walk of %s on %s failed: %s"
                % (base_oid, agent_hostname, error_indication))
        if error_status:
            raise RuntimeError(
                "snmp walk of %s on %s failed: %s at var-bind index %s"
                % (base_oid, agent_hostname,
                   error_status.prettyPrint(), error_index))
        for oid, val in var_binds:
            yield {"oid": "." + str(oid), "value": val.prettyPrint()}


def _validate_config(ctx, param, value):
    """
    Click callback for ``--config``: parse, validate and return the config.

    :param ctx: click context of the command being invoked
    :param param: the click parameter being processed
    :param value: open file-like object supplied by ``click.File``
    :return: the parsed configuration (a dict conforming to CONFIG_SCHEMA)
    :raises click.BadParameter: if the file is not valid JSON or does not
        conform to CONFIG_SCHEMA, so click reports a usage error instead
        of a raw traceback
    """
    try:
        config = json.loads(value.read())
    except ValueError as e:  # json.JSONDecodeError is a ValueError
        raise click.BadParameter(
            "invalid JSON: %s" % e, ctx=ctx, param=param) from e
    try:
        jsonschema.validate(config, CONFIG_SCHEMA)
    except jsonschema.ValidationError as e:
        raise click.BadParameter(
            "invalid configuration: %s" % e.message,
            ctx=ctx, param=param) from e
    return config


def load_oids(config_file):
    """
    Parse ``name=oid`` lines into a dict.

    Lines without an ``=`` are ignored.  The name is everything before the
    first ``=`` and the value everything after it.  Surrounding whitespace
    (including the line terminator) is stripped from both, so the OIDs can
    be used directly in SNMP requests and in regular expressions.

    :param config_file: file-like object (or any iterable of lines)
    :return: dict mapping name -> OID string
    """
    result = {}
    for line in config_file:
        m = re.match(r'^([^=]+)=(.*)$', line)
        if m:
            result[m.group(1).strip()] = m.group(2).strip()
    return result


def load_routers(config_file):
    """
    Parse router definitions of the form ``<hostname>=<community>,<address>``.

    Accepted hostnames are three dot-separated lowercase-alphanumeric labels
    (the second 3-4 characters long, the third exactly 2) followed by
    ``geant[N].net`` or ``eumedconnect[N].net``.  Blank lines are skipped
    silently.  Any other non-matching line is logged as a warning and
    skipped.  Whitespace around the community and the address is stripped.

    :param config_file: file-like object (or any iterable of lines)
    :return: generator of dicts with ``"hostname"``, ``"community"`` and
        ``"address"`` keys
    """
    line_re = re.compile(
        r'^([a-z\d]+\.[a-z\d]{3,4}\.[a-z\d]{2}\.(geant|eumedconnect)\d*\.net)'
        r'\s*=([^,]+)\s*,(.*)\s*$')
    for line in config_file:
        if not line.strip():
            continue
        m = line_re.match(line)
        if not m:
            logging.warning("malformed config file line: '%s'", line.strip())
            continue
        yield {
            "hostname": m.group(1),
            # [^,]+ and (.*) are greedy, so they capture surrounding blanks
            "community": m.group(3).strip(),
            "address": m.group(4).strip(),
        }


@contextlib.contextmanager
def connection(alarmsdb):
    """
    Context manager yielding a MySQL connection, closed on exit.

    :param alarmsdb: dict with ``hostname``, ``username``, ``password`` and
        ``dbname`` keys (the shape of the ``alarms-db`` section of
        CONFIG_SCHEMA)
    :return: context manager yielding the open connection
    """
    db = mysql.connector.connect(
        host=alarmsdb["hostname"],
        user=alarmsdb["username"],
        passwd=alarmsdb["password"],
        db=alarmsdb["dbname"],
    )
    try:
        yield db
    finally:
        if db:
            db.close()


@contextlib.contextmanager
def cursor(cnx):
    """
    Context manager yielding a cursor from ``cnx``, closed on exit.

    :param cnx: connection object exposing a ``cursor()`` method
    :return: context manager yielding the cursor
    """
    handle = cnx.cursor()
    try:
        yield handle
    finally:
        if handle:
            handle.close()


def _db_test(db, router):
    """
    Debug helper: log the ``absid`` of each ``routers`` row for this router.

    :param db: open database connection whose cursors accept ``%s``
        query placeholders
    :param router: dict with at least a ``"hostname"`` key
    """
    log = logging.getLogger(DATABASE_LOGGER_NAME)
    with cursor(db) as c:
        log.debug("_db_test: %r" % router)
        c.execute(
            "SELECT absid FROM routers WHERE hostname = %s",
            (router['hostname'],))
        for row in c:
            absid, = row
            log.debug("absid: %r" % absid)


def _v6address_oid2str(dotted_decimal):
    hex_params = []
    for dec in re.split(r'\.', dotted_decimal):
        hex_params.append("%02x" % int(dec))
    return ":".join(hex_params)


def _oid_suffix_map(rows):
    """
    Map the last numeric component of each walked row's OID to its value.

    :param rows: dicts with ``"oid"`` and ``"value"`` keys (as yielded by walk)
    :return: dict of last OID component (string) -> row value
    :raises ValueError: if an OID does not end in ``.<digits>``
    """
    suffix_map = {}
    for row in rows:
        m = re.match(r'.*\.(\d+)$', row["oid"])
        if not m:
            raise ValueError("sanity failure parsing oid: " + row["oid"])
        suffix_map[m.group(1)] = row["value"]
    return suffix_map


def get_router_interfaces(router):
    """
    Yield the IPv4 and IPv6 interface addresses of a router, read via SNMP.

    OID names are read from ``oid_list.conf`` in the current working
    directory and every listed OID is walked on the router.  The file must
    define at least ``v4Address``, ``v4Mask``, ``v4InterfaceOID``,
    ``v4InterfaceName``, ``v6AddressAndMask`` and ``v6InterfaceName``.

    :param router: dict with at least ``"hostname"`` and ``"community"``
    :return: generator yielding first the IPv4 entries
        (``v4Address``/``v4Mask``/``v4InterfaceName``), then the IPv6 entries
        (``v6Address``/``v6Mask``/``v6InterfaceName``)
    :raises ValueError: if a walked OID does not have the expected shape
    :raises KeyError: if an address refers to an interface with no name
    """
    with open("oid_list.conf") as f:
        oid_map = load_oids(f)

    details = {
        name: list(walk(router["hostname"], router["community"], oid))
        for name, oid in oid_map.items()
    }

    v4_names = _oid_suffix_map(details["v4InterfaceName"])
    # The three IPv4 columns are walked separately and paired up by position.
    # NOTE(review): zip() silently truncates if the walks differ in length.
    for address, mask, interface_oid in zip(
            details["v4Address"],
            details["v4Mask"],
            details["v4InterfaceOID"]):
        yield {
            "v4Address": address["value"],
            "v4Mask": mask["value"],
            "v4InterfaceName": v4_names[interface_oid["value"]]
        }

    v6_names = _oid_suffix_map(details["v6InterfaceName"])
    # OIDs under the v6AddressAndMask subtree are expected to look like
    # <base>.<interface key>.<dotted-decimal address>
    v6_pattern = re.compile(
        r'^' + re.escape(oid_map["v6AddressAndMask"]) + r'\.(\d+)\.(.+)$')
    for row in details["v6AddressAndMask"]:
        m = v6_pattern.match(row["oid"])
        if not m:
            raise ValueError("sanity failure parsing oid: " + row["oid"])
        yield {
            "v6Address": _v6address_oid2str(m.group(2)),
            "v6Mask": row["value"],
            "v6InterfaceName": v6_names[m.group(1)]
        }


@contextlib.contextmanager
def ssh_connection(router, ssh_params):
    """
    Context manager yielding an SSH client connected to ``router``.

    Authenticates as user ``Monit0r`` with a DSS private key.  Relative key
    and known-hosts paths are resolved against the directory containing
    this module (absolute paths are used as-is by ``os.path.join``).  Host
    keys are loaded from the known-hosts file; no missing-host-key policy
    is set here.

    :param router: dict with at least ``"hostname"``; it is not modified
    :param ssh_params: dict with ``"private-key"`` and ``"known-hosts"`` paths
    :return: context manager yielding a connected ``paramiko.SSHClient``
    """
    import os
    base_dir = os.path.dirname(__file__)
    key_filename = os.path.join(base_dir, ssh_params["private-key"])
    known_hosts = os.path.join(base_dir, ssh_params["known-hosts"])
    key = paramiko.DSSKey.from_private_key_file(key_filename)

    with paramiko.SSHClient() as ssh:
        ssh.load_host_keys(known_hosts)
        ssh.connect(
            hostname=router["hostname"],
            username="Monit0r",
            pkey=key)
        yield ssh


def _dups_to_list(pairs):
    """
    ``object_pairs_hook`` for ``json.loads`` that keeps duplicate keys.

    A key occurring once maps to its value.  A key occurring more than once
    maps to the list of all its values, in order; the default decoder would
    keep only the last one.

    :param pairs: list of ``(key, value)`` pairs of one JSON object
    :return: dict built from ``pairs``
    """
    grouped = {}
    for k, v in pairs:
        grouped.setdefault(k, []).append(v)
    return {k: v[0] if len(v) == 1 else v for k, v in grouped.items()}


def exec_router_commands_json(router, ssh_params, commands):
    """
    Run CLI commands on a router over SSH and yield their decoded JSON output.

    Each command is executed as ``<command> | display json`` on its own
    channel, after an initial ``set cli screen-length 0``.

    :param router: dict with at least ``"hostname"``
    :param ssh_params: SSH settings passed to ssh_connection()
    :param commands: iterable of CLI command strings
    :return: generator yielding one decoded object per command, in order;
        ``{}`` for a command that produced no output
    :raises RuntimeError: if a command exits with a non-zero status
    """
    juniper_logger = logging.getLogger(JUNIPER_LOGGER_NAME)
    with ssh_connection(router, ssh_params) as ssh:

        def _run(command):
            # Drain stdout *before* waiting for the exit status: paramiko's
            # recv_exit_status() can block forever when the remote output
            # exceeds the channel window and nothing is reading it.
            _, stdout, _ = ssh.exec_command(command)
            output = stdout.read()
            status = stdout.channel.recv_exit_status()
            if status != 0:
                raise RuntimeError(
                    "command %r on %s exited with status %d"
                    % (command, router["hostname"], status))
            return output

        _run("set cli screen-length 0")

        for c in commands:
            command = c + " | display json"
            juniper_logger.debug("command: '%s'", command)
            output = _run(command)
            if output:
                juniper_logger.debug(
                    "%r output: [%d] %r", router, len(output), output[:20])
                yield json.loads(output, object_pairs_hook=_dups_to_list)
            else:
                juniper_logger.debug("%r output empty", router)
                yield {}


def get_router_interfaces_q(router, q):
    """
    Process target: put the router's SNMP interface list on ``q``.

    If collection raises, an empty list is put on ``q`` and the exception
    is re-raised.  This releases the parent, which blocks on ``q.get()``,
    instead of leaving it waiting forever.

    :param router: dict with at least ``"hostname"`` and ``"community"``
    :param q: multiprocessing queue that receives the list
    """
    threading_logger = logging.getLogger(THREADING_LOGGER_NAME)
    threading_logger.debug("[ENTER>>] get_router_interfaces_q: %r", router)
    try:
        interfaces = list(get_router_interfaces(router))
    except Exception:
        q.put([])
        raise
    q.put(interfaces)
    threading_logger.debug("[<<EXIT]  get_router_interfaces_q: %r", router)


def exec_router_commands_json_q(router, ssh_params, commands, q):
    """
    Process target: put the decoded output of ``commands`` on ``q``.

    On success, a list with one decoded object per command is put on ``q``.
    If running the commands raises, an empty list is put instead and the
    exception is re-raised.  This releases the parent, which blocks on
    ``q.get()``, instead of leaving it waiting forever.

    :param router: dict with at least ``"hostname"``
    :param ssh_params: SSH settings passed to exec_router_commands_json()
    :param commands: list of CLI command strings
    :param q: multiprocessing queue that receives the list
    """
    threading_logger = logging.getLogger(THREADING_LOGGER_NAME)
    threading_logger.debug("[ENTER>>] exec_router_commands_q: %r", router)
    try:
        outputs = list(exec_router_commands_json(router, ssh_params, commands))
    except Exception:
        q.put([])
        raise
    q.put(outputs)
    threading_logger.debug("[<<EXIT] exec_router_commands_q: %r", router)


def get_router_details(router, params, q):
    """
    Process target: gather CLI and SNMP details for one router onto ``q``.

    The CLI commands (over SSH) and the SNMP interface walk run in two child
    processes in parallel.  On success a single dict is put on ``q`` with
    keys ``"bgp"``, ``"vrr"`` and ``"interfaces"`` (decoded command output)
    and ``"snmp-interfaces"`` (the SNMP interface list).

    If anything raises, three things happen: an empty dict is put on ``q``,
    any child still running is terminated, and the exception is re-raised.
    This releases the parent waiting on ``q.get()`` instead of blocking it
    forever.

    :param router: dict with at least ``"hostname"`` and ``"community"``
    :param params: validated configuration; its ``"ssh"`` entry is used
    :param q: multiprocessing queue that receives the result dict
    """
    threading_logger = logging.getLogger(THREADING_LOGGER_NAME)
    threading_logger.debug("[ENTER>>]get_router_details: %r", router)

    # (result key, command) pairs: keeping each key next to its command
    # stops the two drifting out of step when commands are added or removed.
    # Alternative (unused) command variants:
    #   'show configuration routing-instances IAS protocols bgp | display set | match neighbor | match description | match "GEANT-IX | GEANT-IX-"'
    #   'show configuration routing-instances IAS protocols bgp | display set | match neighbor | match description | match "GEANT-IXv6 | GEANT-IXv6-"'
    #   'show configuration logical-systems VRR protocols bgp group VPN-RR-INTERNAL | match "neigh|desc"'
    #   'show configuration logical-systems VRR protocols bgp group VPN-RR | match "neigh|desc"'
    keyed_commands = [
        ("bgp", 'show configuration routing-instances IAS protocols bgp'),
        ("vrr", 'show configuration logical-systems VRR protocols bgp'),
        ("interfaces", 'show interfaces descriptions'),
    ]
    commands = [command for _, command in keyed_commands]

    children = []
    try:
        snmpifc_proc_queue = Queue()
        snmpifc_proc = Process(
            target=get_router_interfaces_q,
            args=(router, snmpifc_proc_queue))
        snmpifc_proc.start()
        children.append(snmpifc_proc)

        commands_proc_queue = Queue()
        commands_proc = Process(
            target=exec_router_commands_json_q,
            args=(router, params["ssh"], commands, commands_proc_queue))
        commands_proc.start()
        children.append(commands_proc)

        # get() before join(): joining a child that still has queued data
        # to flush can deadlock.
        threading_logger.debug("waiting for commands result: %r", router)
        command_output = commands_proc_queue.get()
        if len(command_output) != len(keyed_commands):
            raise RuntimeError(
                "expected %d command outputs for %r, got %d"
                % (len(keyed_commands), router, len(command_output)))
        result = {
            key: output
            for (key, _), output in zip(keyed_commands, command_output)
        }
        commands_proc.join()
        threading_logger.debug(
            "... got commands result & joined: %r", router)

        threading_logger.debug("waiting for snmp ifc results: %r", router)
        result["snmp-interfaces"] = snmpifc_proc_queue.get()
        snmpifc_proc.join()
        threading_logger.debug(
            "... got snmp ifc result & joined: %r", router)
    except Exception:
        q.put({})
        for child in children:
            if child.is_alive():
                child.terminate()
        raise

    q.put(result)

    threading_logger.debug("[<<EXIT]get_router_details: %r", router)


def load_network_details(config):
    """
    Collect details for every router listed in ``routers_community.conf``.

    The router list is read from ``routers_community.conf`` in the current
    working directory.  One get_router_details() process is started per
    router, all before any result is awaited.  Results are then collected
    in file order.

    NOTE(review): the config's "routers_community.conf" entry is not used
    here; the path is hard-coded -- confirm whether that is intended.

    :param config: validated configuration dict, passed to every worker
    :return: dict mapping router hostname -> that router's details dict
    """
    threading_logger = logging.getLogger(THREADING_LOGGER_NAME)

    with open("routers_community.conf") as f:
        routers = list(load_routers(f))

    workers = []
    for router in routers:
        result_queue = Queue()
        worker = Process(
            target=get_router_details,
            args=(router, config, result_queue))
        worker.start()
        workers.append((router, worker, result_queue))

    network = {}
    for router, worker, result_queue in workers:
        threading_logger.debug(
            "waiting for get_router_details result: %r" % router)
        # read the queue before join(): joining a process that still has
        # queued data to flush can deadlock
        network[router["hostname"]] = result_queue.get()
        worker.join()
        threading_logger.debug(
            "got result and joined get_router_details proc: %r" % router)

    return network


@click.command()
@click.option(
    "--config",
    type=click.File(),
    help="Configuration filename",
    # A filename rather than an already-open file: click converts defaults
    # through the option's type, so the file is opened only when the command
    # runs.  Opening it in the decorator ran at import time, failed whenever
    # config.json was absent (even for --help) and leaked the handle.
    default="config.json",
    callback=_validate_config)
def cli(config):
    """Collect router details and write them as JSON to /tmp/router-info.json."""
    network_details = load_network_details(config)
    filename = "/tmp/router-info.json"
    logging.debug("writing output to: " + filename)
    with open(filename, "w") as f:
        f.write(json.dumps(network_details))


if __name__ == "__main__":
    # Everything at WARNING by default; each component's logger is then
    # given its own level.
    logging.basicConfig(level=logging.WARNING)
    component_levels = (
        (SNMP_LOGGER_NAME, logging.DEBUG),
        (THREADING_LOGGER_NAME, logging.INFO),
        (JUNIPER_LOGGER_NAME, logging.DEBUG),
        (DATABASE_LOGGER_NAME, logging.DEBUG),
    )
    for logger_name, level in component_levels:
        logging.getLogger(logger_name).setLevel(level)
    cli()