import ipaddress
import logging
import re
import struct

from pysnmp.hlapi import nextCmd, SnmpEngine, CommunityData, \
    UdpTransportTarget, ContextData, ObjectType, ObjectIdentity
from pysnmp.smi import builder, compiler
# from pysnmp.smi import view, rfc1902


# RFC1213-MIB::ifDescr column, walked by get_router_snmp_indexes(); the
# trailing OID component of each returned row is used as the interface index.
# NOTE: the OID constants here are stored without a leading '.', whereas OIDs
# yielded by walk() always carry one (see _canonify_oid).
RFC1213_MIB_IFDESC = '1.3.6.1.2.1.2.2.1.2'
# BGP4-V2-MIB-JUNIPER::jnxBgpM2PeerState
# Walked by get_peer_state_info(); peer addresses are decoded from the
# index portion of each returned OID.
JNX_BGP_M2_PEER_STATE = '1.3.6.1.4.1.2636.5.1.1.2.1.1.1.2'
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class SNMPWalkError(ConnectionError):
    """Raised by walk() when the SNMP engine or a response PDU reports an error."""


def _cast_snmp_value(value):
    """
    Cast things to the simplest native type.

    :param value:
    :return:
    """
    try:
        return int(value)
    except (ValueError, TypeError):
        try:
            return float(value)
        except (ValueError, TypeError):
            try:
                return str(value)
            except (ValueError, TypeError):
                pass
    return value


def _v6address_oid2str(dotted_decimal):
    hex_params = []
    for dec in re.split(r'\.', dotted_decimal):
        hex_params.append("%02x" % int(dec))
    return ":".join(hex_params)


def _canonify_oid(oid):
    """
    Our output oid's always begin with '.'.

    :param oid: a thing that stringifies to a dotted oid
    :return: a string like '.#.#.#...#'
    """
    oid = str(oid)
    return oid if oid.startswith('.') else f'.{oid}'


def walk(agent_hostname, community, base_oid):  # pragma: no cover
    """
    Walk the SNMP subtree rooted at ``base_oid`` and yield each varbind.

    https://stackoverflow.com/a/45001921
    http://snmplabs.com/pysnmp/docs/hlapi/asyncore/sync/manager/cmdgen/nextcmd.html
    http://snmplabs.com/pysnmp/faq/pass-custom-mib-to-manager.html
    https://github.com/etingof/pysnmp/blob/master/examples/v3arch/asyncore/manager/cmdgen/getnext-multiple-oids-and-resolve-with-mib.py
    http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html

    :param agent_hostname: hostname/address of the agent (queried on UDP 161)
    :param community: SNMP community string
    :param base_oid: dotted OID of the subtree to walk; with
        ``lexicographicMode=False`` the walk stops at the end of this subtree
    :return: generator of dicts ``{"oid": str, "value": ...}`` where ``oid``
        always starts with '.' and ``value`` has been passed through
        _cast_snmp_value()
    :raises SNMPWalkError: on an engine error indication, a PDU error
        indication, or a non-zero error index without an error indication
    """

    # NOTE(review): this builder is populated but never handed to the
    # SnmpEngine used below (nextCmd gets a fresh SnmpEngine()), so it does
    # not appear to influence how results are resolved.  Kept to avoid
    # changing side effects -- confirm whether it can be removed.
    mib_builder = builder.MibBuilder()
    compiler.addMibCompiler(
        mib_builder,
        sources=['http://mibs.snmplabs.com/asn1/@mib@'])
    # Pre-load MIB modules we expect to work with
    mib_builder.loadModules(
        'SNMPv2-MIB',
        'SNMP-COMMUNITY-MIB',
        'RFC1213-MIB')

    # Lazy %-args: the string is only built if DEBUG is enabled.
    logger.debug("walking %s: %s", agent_hostname, base_oid)

    # NOTE(review): lookupNames/lookupValues look like options from the older
    # oneliner API; the hlapi nextCmd documents `lookupMib` instead -- verify
    # they have any effect.  Left in place to keep the call unchanged.
    for (engine_error_indication,
         pdu_error_indication,
         error_index,
         var_binds) in nextCmd(
            SnmpEngine(),
            CommunityData(community),
            UdpTransportTarget((agent_hostname, 161)),
            ContextData(),
            ObjectType(ObjectIdentity(base_oid)),
            lexicographicMode=False,
            lookupNames=True,
            lookupValues=True):

        # cf. http://snmplabs.com/
        #       pysnmp/examples/hlapi/asyncore/sync/contents.html
        if engine_error_indication:
            raise SNMPWalkError(
                f'snmp response engine error indication: '
                f'{str(engine_error_indication)} - {agent_hostname}')
        if pdu_error_indication:
            # error_index is 1-based into var_binds; 0 means "no specific
            # varbind".  Bounds-check it so a bogus index cannot turn the
            # intended SNMPWalkError into an IndexError.
            position = int(error_index)
            offending_oid = (var_binds[position - 1][0]
                             if 0 < position <= len(var_binds) else '?')
            raise SNMPWalkError(
                'snmp response pdu error %r at %r' % (
                    pdu_error_indication, offending_oid))
        if error_index != 0:
            raise SNMPWalkError(
                'sanity failure: errorIndex != 0, '
                'but no error indication')

        for oid, val in var_binds:
            result = {
                "oid": _canonify_oid(oid),
                "value": _cast_snmp_value(val)
            }
            logger.debug(result)
            yield result


def get_router_snmp_indexes(hostname, community):
    """
    Yield the description and table index of every interface on an agent.

    Walks RFC1213-MIB::ifDescr.  The index is the trailing numeric component
    of each returned OID.

    :param hostname: SNMP agent hostname
    :param community: SNMP community string
    :return: generator of dicts ``{'name': <ifDescr value>, 'index': int}``
    :raises AssertionError: if a returned OID does not end in '.<digits>'
    """
    for entry in walk(hostname, community, RFC1213_MIB_IFDESC):
        oid = entry['oid']
        match = re.match(r'.*\.(\d+)$', oid)
        assert match, f'sanity failure parsing oid: {oid}'
        index = int(match.group(1))
        yield {'name': entry['value'], 'index': index}


def _v6bytes(int_str_list):
    assert len(int_str_list) == 16
    return struct.pack('!16B', *map(int, int_str_list))


def _v4str(int_str_list):
    assert len(int_str_list) == 4
    return '.'.join(int_str_list)


def get_peer_state_info(hostname, community):
    """
    Yield the local/remote addresses of each row of jnxBgpM2PeerState.

    The addresses are decoded from the index portion of each returned OID
    (everything after the column OID).  After one leading component is
    discarded, the index is expected to be laid out as::

        <local type>.<local addr...>.<remote type>.<remote addr...>

    Type ``1`` is followed by 4 IPv4 octets and type ``2`` by 16 IPv6
    octets.  Both ends must have the same type.

    :param hostname: SNMP agent hostname
    :param community: SNMP community string
    :return: generator of dicts with keys 'local' and 'remote' (exploded
        address strings) and 'oid' (the full row OID)
    :raises AssertionError: if an OID lies outside the walked column or its
        index does not decode as a v4/v4 or v6/v6 peering
    :raises SNMPWalkError: propagated from walk()
    """
    oid_prefix = f'.{JNX_BGP_M2_PEER_STATE}.'
    for ifc in walk(hostname, community, JNX_BGP_M2_PEER_STATE):

        assert ifc['oid'].startswith(oid_prefix), \
            f'{ifc["oid"]}: {JNX_BGP_M2_PEER_STATE}'

        rest = ifc['oid'][len(oid_prefix):]
        splits = rest.split('.')
        # NOTE(review): the meaning of this leading index component is not
        # established here (presumably a routing-instance index -- verify
        # against BGP4-V2-MIB-JUNIPER).  It is not needed for the addresses.
        splits.pop(0)
        if splits[0] == splits[5] == '1':  # v4 should peer with v4
            # ipv4: type + 4 octets, twice
            assert len(splits) == 10
            local = ipaddress.ip_address(_v4str(splits[1:5]))
            remote = ipaddress.ip_address(_v4str(splits[6:]))
        elif splits[0] == splits[17] == '2':  # v6 should peer with v6
            # ipv6: type + 16 octets, twice
            assert len(splits) == 34
            local = ipaddress.ip_address(_v6bytes(splits[1:17]))
            remote = ipaddress.ip_address(_v6bytes(splits[18:]))
        else:
            logger.error('expected v4 or v6 peering, got type %s', splits[0])
            # Raise explicitly instead of `assert False`.  Under `python -O`
            # the assert is stripped and the yield below would silently reuse
            # the previous row's local/remote (or hit UnboundLocalError).
            # AssertionError is kept so the exception type is unchanged.
            raise AssertionError(
                f'expected v4 or v6 peering, got type {splits[0]}')

        yield {
            'local': local.exploded,
            'remote': remote.exploded,
            'oid': ifc['oid']
        }


# if __name__ == '__main__':
#
#
#     # HOSTNAME = 'mx1.ams.nl.geant.net'
#     HOSTNAME = 'mx1.kau.lt.geant.net'
#     COMMUNITY = '<community-string>'  # never commit real community strings
#     import json
#
#     # for x in get_peer_state_info('mx1.kau.lt.geant.net', '<community-string>'):
#     #     print(x)
#
#     peerings = get_peer_state_info(HOSTNAME, COMMUNITY)
#     print(json.dumps(list(peerings), indent=2, sort_keys=True))
#
#     # oids = [x['oid'] for x in peerings]
#     # print(oids)
#     # data = dict()
#     # for i in range(0, len(oids), 3):
#     #     data.update(get(HOSTNAME, COMMUNITY, oids[i:i+3]))
#     #
#     # assert all([v for v in data.values()])
#     # print(json.dumps(data, indent=2))
#
#     # import json
#     # z = get_router_snmp_indexes('mx1.kau.lt.geant.net', '<community-string>')
#     # print(json.dumps(list(z), indent=2))