import logging
import re
import ipaddress

from jnpr.junos import Device
from jnpr.junos import exception as EzErrors
from lxml import etree
import netifaces
import requests

# XML schema (XSD) used by validate_netconf_config to sanity-check the
# overall structure of a router's netconf configuration document.
# Only the set of allowed top-level <configuration> children and the
# basic shape of interfaces/interface are constrained; most sections use
# 'generic-sequence' (xs:any, processContents="lax"), so their contents are
# accepted as-is. Interface 'unit' elements are checked separately against
# UNIT_SCHEMA.
CONFIG_SCHEMA = """<?xml version="1.1" encoding="UTF-8" ?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">

  <xs:complexType name="generic-sequence">
    <xs:sequence>
        <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded"/>
    </xs:sequence>
    <xs:anyAttribute processContents="skip" />
  </xs:complexType>


  <!-- NOTE: 'unit' content isn't validated -->
  <xs:complexType name="juniper-interface">
    <xs:sequence>
        <xs:choice minOccurs="1" maxOccurs="unbounded">
          <xs:element name="name" minOccurs="1" maxOccurs="1" type="xs:string" />
          <xs:element name="description" minOccurs="0" maxOccurs="1">
              <xs:complexType>
                <xs:simpleContent>
                  <xs:extension base="xs:string">
                    <xs:attribute name="inactive" type="xs:string" />
                  </xs:extension>
                </xs:simpleContent>
              </xs:complexType>
          </xs:element>
          <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded" />
        </xs:choice>
    </xs:sequence>
    <xs:attribute name="inactive" type="xs:string" />
  </xs:complexType>

  <xs:element name="configuration">
    <xs:complexType>
      <xs:sequence>
        <xs:choice minOccurs="1" maxOccurs="unbounded">
          <xs:element name="transfer-on-commit" minOccurs="0" type="xs:string" />
          <xs:element name="archive-sites" minOccurs="0" type="generic-sequence" />
          <xs:element name="version" minOccurs="0" type="xs:string" />
          <xs:element name="groups" minOccurs="0" type="generic-sequence" />
          <xs:element name="apply-groups" minOccurs="0" type="xs:string" />
          <xs:element name="system" minOccurs="0" type="generic-sequence" />
          <xs:element name="logical-systems" minOccurs="0" type="generic-sequence" />
          <xs:element name="chassis" minOccurs="0" type="generic-sequence" />
          <xs:element name="services" minOccurs="0" type="generic-sequence" />
          <xs:element name="interfaces" minOccurs="0">
            <xs:complexType>
              <xs:sequence>
                <xs:choice minOccurs="1" maxOccurs="unbounded">
                    <xs:element name="apply-groups" minOccurs="0" type="xs:string" />
                    <xs:element name="interface-range" minOccurs="0" type="generic-sequence" />
                    <xs:element name="interface" minOccurs="1" maxOccurs="unbounded" type="juniper-interface" />
                </xs:choice>
              </xs:sequence>
            </xs:complexType>
          </xs:element>
          <xs:element name="snmp" minOccurs="0" type="generic-sequence" />
          <xs:element name="forwarding-options" minOccurs="0" type="generic-sequence" />
          <xs:element name="routing-options" minOccurs="0" type="generic-sequence" />
          <xs:element name="protocols" minOccurs="0" type="generic-sequence" />
          <xs:element name="policy-options" minOccurs="0" type="generic-sequence" />
          <xs:element name="class-of-service" minOccurs="0" type="generic-sequence" />
          <xs:element name="firewall" minOccurs="0" type="generic-sequence" />
          <xs:element name="routing-instances" minOccurs="0" type="generic-sequence" />
          <xs:element name="bridge-domains" minOccurs="0" type="generic-sequence" />
          <xs:element name="virtual-chassis" minOccurs="0" type="generic-sequence" />
          <xs:element name="vlans" minOccurs="0" type="generic-sequence" />
          <xs:element name="comment" minOccurs="0" type="xs:string" />
        </xs:choice>
      </xs:sequence>
      <xs:attribute name="changed-seconds" type="xs:string" />
      <xs:attribute name="changed-localtime" type="xs:string" />
    </xs:complexType>
  </xs:element>

</xs:schema>
"""  # noqa: E501

# XML schema (XSD) used by validate_netconf_config to check each
# interfaces/interface/unit element on its own: a unit must start with an
# integer <name> (xs:int), optionally followed by a <description>; any
# further children are accepted (xs:any, processContents="lax").
UNIT_SCHEMA = """<?xml version="1.1" encoding="UTF-8" ?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">

  <xs:complexType name="generic-sequence">
    <xs:sequence>
        <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded"/>
    </xs:sequence>
    <xs:anyAttribute processContents="skip" />
  </xs:complexType>

  <xs:element name="unit">
    <xs:complexType>
      <xs:sequence>
        <xs:element name="name" minOccurs="1" maxOccurs="1" type="xs:int" />
        <xs:element name="description" minOccurs="0" maxOccurs="1" type="xs:string" />
        <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded"/>
      </xs:sequence>
      <xs:attribute name="inactive" type="xs:string" />
    </xs:complexType>
  </xs:element>

</xs:schema>
"""  # noqa: E501

# JSON schema (draft-07) describing a list of bgp peering dicts.
# The allowed properties match the keys set by _system_bgp_peers /
# all_bgp_peers in this module. It isn't referenced by the code here;
# presumably callers use it to validate all_bgp_peers output (TODO confirm).
PEERING_LIST_SCHEMA = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "definitions": {
        "peering": {
            "type": "object",
            "properties": {
                "group": {"type": "string"},
                "description": {"type": "string"},
                "address": {"type": "string"},
                "remote-asn": {"type": "integer"},
                "local-asn": {"type": "integer"},
                "instance": {"type": "string"},
                "logical-system": {"type": "string"},
            },
            # many peerings are internal and may not configure asn's
            # explicitly, so only 'group' and 'address' are mandatory
            "required": ["group", "address"],
            "additionalProperties": False
        }
    },
    "type": "array",
    "items": {"$ref": "#/definitions/peering"}
}


class NetconfHandlingError(Exception):
    """
    raised in this module when netconf config data can't be handled,
    e.g. when it fails schema validation in validate_netconf_config
    """


def _rpc(hostname, ssh):
    """
    open a netconf session to a junos device and return its rpc handle

    NOTE(review): the Device opened here is never closed, and callers only
    receive ``dev.rpc`` - confirm whether sessions are expected to leak
    until process exit or whether they should be closed explicitly.

    :param hostname: router hostname
    :param ssh: dict with 'username' and 'private-key' (passed as the
        ssh private key file) entries
    :return: the ``rpc`` attribute of the opened Device
    :raises: ConnectionError if the session can't be opened
    """
    dev = Device(
        host=hostname,
        user=ssh['username'],
        ssh_private_key_file=ssh['private-key'])
    try:
        dev.open()
    except (EzErrors.ConnectError, EzErrors.RpcError) as e:
        # chain the junos-eznc error so its traceback is kept as the cause
        raise ConnectionError(str(e)) from e
    return dev.rpc


def validate_netconf_config(config_doc):
    """
    validate a netconf config document against CONFIG_SCHEMA, then each
    interfaces/interface/unit element in it against UNIT_SCHEMA

    validation stops at the first document/element that fails: each of
    its schema errors is logged and all of them are joined (one per line,
    formatted as '<line>.<column>: <message>') into the raised exception

    :param config_doc: netconf lxml etree document
    :return: None
    :raises: NetconfHandlingError in case of validation errors
    """
    logger = logging.getLogger(__name__)

    def _compile(schema_text):
        # etree.XML needs bytes because the schema text has an xml declaration
        return etree.XMLSchema(etree.XML(schema_text.encode('utf-8')))

    def _check(schema, node):
        if not schema.validate(node):
            errors = [
                f'{err.line}.{err.column}: {err.message}'
                for err in schema.error_log
            ]
            for error in errors:
                logger.error(error)
            raise NetconfHandlingError('\n'.join(errors))

    _check(_compile(CONFIG_SCHEMA), config_doc)

    # CONFIG_SCHEMA doesn't look inside units, so check each one separately
    unit_schema = _compile(UNIT_SCHEMA)
    for unit in config_doc.xpath('//configuration/interfaces/interface/unit'):
        _check(unit_schema, unit)


def load_config(hostname, ssh_params, validate=True):
    """
    loads netconf data from the router, validates (by default) and
    returns as an lxml etree doc

    :param hostname: router hostname
    :param ssh_params: 'ssh' config element(cf. config.py:CONFIG_SCHEMA)
    :param validate: whether or not to validate netconf data (default True)
    :return: the result of the router's get_config rpc (lxml etree doc)
    :raises: ConnectionError if the router session can't be opened
    :raises: NetconfHandlingError from validate_netconf_config
    """
    logger = logging.getLogger(__name__)
    # lazy %-style args: the message is only formatted if INFO is enabled
    logger.info("capturing netconf data for '%s'", hostname)
    config = _rpc(hostname, ssh_params).get_config()
    if validate:
        validate_netconf_config(config)
    return config


def list_interfaces(netconf_config):
    """
    generator that parses netconf output and yields a dict per interface
    and per active unit

    each top-level interface is yielded followed by its units (named
    '<interface>.<unit>'); for logical-systems interfaces only the units
    are yielded. units with inactive="inactive" are skipped.

    each dict has the keys 'name', 'description' ('' when not set),
    'bundle' (texts of any <bundle> elements below the interface/unit),
    'ipv4' and 'ipv6' (lists of configured address strings)

    :param netconf_config: xml doc that was generated by load_config
    :return: generator of interface info dicts
    """

    def _ifc_info(e):
        # warning: this structure should match the default
        #          returned from routes.classifier.juniper_link_info
        name = e.find('name')
        assert name is not None, "expected interface 'name' child element"
        ifc = {
            'name': name.text,
            'description': '',
            'bundle': []
        }
        description = e.find('description')
        if description is not None:
            # an empty <description/> has text None: keep the '' default
            ifc['description'] = description.text or ''

        # search below `e` itself (not the enclosing loop's interface) so
        # a unit only reports bundles configured under that unit
        for b in e.iterfind(".//bundle"):
            ifc['bundle'].append(b.text)

        ifc['ipv4'] = e.xpath('./family/inet/address/name/text()')
        ifc['ipv6'] = e.xpath('./family/inet6/address/name/text()')

        return ifc

    def _units(base_name, node):
        for u in node.xpath('./unit'):
            if u.get('inactive', None) == 'inactive':
                continue
            unit_info = _ifc_info(u)
            unit_info['name'] = "%s.%s" % (base_name, unit_info['name'])
            yield unit_info

    # NOTE(review): unlike units, inactive interfaces themselves are not
    # skipped here - confirm that's intended
    for i in netconf_config.xpath('//configuration/interfaces/interface'):
        info = _ifc_info(i)
        yield info
        yield from _units(info['name'], i)

    for i in netconf_config.xpath(
            '//configuration/logical-systems/interfaces/interface'):
        name = i.find('name')
        assert name is not None, "expected interface 'name' child element"
        yield from _units(name.text, i)


def _system_bgp_peers(system_node):
    """
    generator yielding a dict for each active bgp neighbor configured
    under ``system_node``, and recursively under its routing instances

    every dict has 'address' (exploded ip string) and 'group'; 'remote-asn',
    'local-asn' and 'description' are present only when configured, and
    peers found inside a routing instance also get an 'instance' key.
    neighbors with inactive="inactive" are skipped.

    :param system_node: lxml element - the <configuration> root, a
        <logical-systems> element or a routing-instances <instance>
    :return: generator of peering dicts
    """

    def _neighbor_info(node):
        peering = {
            'address': ipaddress.ip_address(node.find('name').text).exploded
        }
        # lxml usage warning: an element's truth value depends on its
        # children, so always compare find() results against None
        peer_as_node = node.find('peer-as')
        if peer_as_node is not None:
            peering['remote-asn'] = int(peer_as_node.text)
        local_as_node = node.find('local-as')
        if local_as_node is not None:
            peering['local-asn'] = int(local_as_node.find('as-number').text)
        description_node = node.find('description')
        if description_node is not None:
            peering['description'] = description_node.text
        return peering

    for group_node in system_node.xpath('./protocols/bgp/group'):
        group_name = group_node.find('name').text
        active_neighbors = (
            n for n in group_node.xpath('./neighbor')
            if n.get('inactive') != 'inactive')
        for neighbor_node in active_neighbors:
            peering = _neighbor_info(neighbor_node)
            peering['group'] = group_name
            yield peering

    for instance_node in system_node.xpath('./routing-instances/instance'):
        instance_name = instance_node.find('name').text
        for peering in _system_bgp_peers(instance_node):
            peering['instance'] = instance_name
            yield peering


def all_bgp_peers(netconf_config):
    """
    generator yielding every active bgp peering in a router config

    peerings of the main system come first, followed by those of each
    logical system (tagged with a 'logical-system' key)

    :param netconf_config: netconf lxml etree document
    :return: generator of peering dicts (keys as in PEERING_LIST_SCHEMA)
    """
    # normally there's exactly one <configuration> element
    for config_root in netconf_config.xpath('//configuration'):
        for peering in _system_bgp_peers(config_root):
            yield peering

    ls_nodes = netconf_config.xpath('//configuration/logical-systems')
    for ls_node in ls_nodes:
        ls_name = ls_node.find('name').text
        for peering in _system_bgp_peers(ls_node):
            peering['logical-system'] = ls_name
            yield peering


def interface_addresses(netconf_config):
    """
    generator yielding one dict per address configured on each interface
    or unit reported by list_interfaces (ipv4 first, then ipv6);
    duplicate addresses are not removed

    each dict has the keys:
      - 'name': the exploded ip address, without any prefix length
      - 'interface address': the address string as found in the config
      - 'interface name': name of the interface/unit it's configured on

    :param netconf_config: netconf lxml etree document
    :return: generator of dicts
    """
    for ifc in list_interfaces(netconf_config):
        for configured in [*ifc['ipv4'], *ifc['ipv6']]:
            host = ipaddress.ip_interface(configured).ip
            yield {
                "name": host.exploded,
                "interface address": configured,
                "interface name": ifc['name'],
            }


def load_routers_from_netdash(url, timeout=60):
    """
    query url for a linefeed-delimited list of managed router hostnames

    surrounding whitespace is stripped from each line and blank lines
    are ignored

    :param url: url of alldevices.txt file
    :param timeout: passed to requests.get (seconds, or a (connect, read)
        tuple); without a timeout a stalled server would block this call
        forever. Pass None to wait indefinitely.
    :return: list of router hostnames
    :raises: requests.HTTPError for an error response status
    """
    r = requests.get(url=url, timeout=timeout)
    r.raise_for_status()
    return [
        line.strip() for line in r.text.splitlines() if line.strip()
    ]


def local_interfaces(
        type=netifaces.AF_INET,
        omit_link_local=True,
        omit_loopback=True):
    """
    generator yielding IPv4Interface or IPv6Interface objects for the
    addresses configured on this host, depending on the value of type

    :param type: hopefully AF_INET or AF_INET6
    :param omit_link_local: skip v6 fe80* addresses if true
    :param omit_loopback: skip loopback interfaces if true - 'lo' and
        'lo:<alias>' as well as numbered 'lo0', 'lo1', ...
    :return: generator of ipaddress interface objects
    """
    for n in netifaces.interfaces():
        # 'lo' must be followed by a digit, a ':' or nothing: a bare
        # r'^lo\d+' misses a plain 'lo' loopback (e.g. on linux), while
        # a plain prefix test would also drop e.g. 'lowpan0'
        if omit_loopback and re.match(r'^lo(\d|:|$)', n):
            continue
        am = netifaces.ifaddresses(n)
        for a in am.get(type, []):
            if omit_link_local and a['addr'].startswith('fe80:'):
                continue
            # drop any trailing '%...' suffix (e.g. a v6 scope such as
            # '%eth0') which ipaddress can't parse
            m = re.match(r'^(.+?)(%.*)?$', a['addr'])
            assert m
            addr = m.group(1)
            # use the prefix length when the netmask ends in '/<digits>',
            # otherwise pass the netmask through as-is
            m = re.match(r'.*/(\d+)$', a['netmask'])
            mask = m.group(1) if m else a['netmask']
            yield ipaddress.ip_interface('%s/%s' % (addr, mask))


def snmp_community_string(netconf_config):
    """
    find an snmp community this host is allowed to use

    returns the name of the first community with a ``clients`` entry whose
    network contains one of this host's addresses (as returned by
    local_interfaces() with its defaults). Entries marked ``restrict``
    deny access to those clients in junos, so they are never treated as
    a match (e.g. the common ``0.0.0.0/0 restrict`` catch-all).

    :param netconf_config: netconf lxml etree document
    :return: community name, or None if no community matches
    """
    my_addresses = [i.ip for i in local_interfaces()]
    for community in netconf_config.xpath('//configuration/snmp/community'):
        for subnet in community.xpath('./clients[not(restrict)]/name/text()'):
            allowed_network = ipaddress.ip_network(subnet, strict=False)
            # containment is simply False across ipv4/ipv6 versions
            if any(me in allowed_network for me in my_addresses):
                return community.xpath('./name/text()')[0]
    return None


def netconf_changed_timestamp(netconf_config):
    '''
    return the last change timestamp published by the config document,
    i.e. the ``changed-seconds`` attribute of the <configuration> element

    :param netconf_config: netconf lxml etree document
    :return: an epoch timestamp (integer number of seconds) or None
    '''
    # '//configuration' (like the other helpers in this module) rather than
    # '/configuration': the latter only matches when <configuration> is the
    # root of the document the element belongs to
    for ts in netconf_config.xpath('//configuration/@changed-seconds'):
        if re.match(r'^\d+$', ts):
            return int(ts)
    logger = logging.getLogger(__name__)
    logger.warning('no valid timestamp found in netconf configuration')
    return None


def logical_systems(netconf_config):
    """
    names of the logical systems defined in the router configuration

    a router without logical systems is fine: the result is just empty

    :param netconf_config: netconf lxml etree document
    :return: a list of strings
    """
    names_xpath = '//configuration/logical-systems/name/text()'
    return netconf_config.xpath(names_xpath)