import logging
import re
import ipaddress

from jnpr.junos import Device
from jnpr.junos import exception as EzErrors
from lxml import etree
import netifaces
import requests

# XML schema used by validate_netconf_config to check a whole netconf
# <configuration> document.  Top-level children of <configuration> must be
# one of the stanzas listed in the 'configuration' element below; most of
# those stanzas are typed 'generic-sequence', which accepts any (lax)
# content.  'unit' elements are not checked here -- validate_netconf_config
# checks them separately against UNIT_SCHEMA.
CONFIG_SCHEMA = """<?xml version="1.1" encoding="UTF-8" ?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">

  <xs:complexType name="generic-sequence">
    <xs:sequence>
        <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded"/>
    </xs:sequence>
    <xs:anyAttribute processContents="skip" />
  </xs:complexType>


  <!-- NOTE: 'unit' content isn't validated -->
  <xs:complexType name="juniper-interface">
    <xs:sequence>
        <xs:choice minOccurs="1" maxOccurs="unbounded">
          <xs:element name="name" minOccurs="1" maxOccurs="1" type="xs:string" />
          <xs:element name="description" minOccurs="0" maxOccurs="1">
              <xs:complexType>
                <xs:simpleContent>
                  <xs:extension base="xs:string">
                    <xs:attribute name="inactive" type="xs:string" />
                  </xs:extension>
                </xs:simpleContent>
              </xs:complexType>
          </xs:element>
          <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded" />
        </xs:choice>
    </xs:sequence>
    <xs:attribute name="inactive" type="xs:string" />
  </xs:complexType>

  <xs:element name="configuration">
    <xs:complexType>
      <xs:sequence>
        <xs:choice minOccurs="1" maxOccurs="unbounded">
          <xs:element name="version" minOccurs="0" type="xs:string" />
          <xs:element name="groups" minOccurs="0" type="generic-sequence" />
          <xs:element name="apply-groups" minOccurs="0" type="xs:string" />
          <xs:element name="system" minOccurs="0" type="generic-sequence" />
          <xs:element name="logical-systems" minOccurs="0" type="generic-sequence" />
          <xs:element name="chassis" minOccurs="0" type="generic-sequence" />
          <xs:element name="services" minOccurs="0" type="generic-sequence" />
          <xs:element name="interfaces" minOccurs="0">
            <xs:complexType>
              <xs:sequence>
                <xs:choice minOccurs="1" maxOccurs="unbounded">
                    <xs:element name="apply-groups" minOccurs="0" type="xs:string" />
                    <xs:element name="interface-range" minOccurs="0" type="generic-sequence" />
                    <xs:element name="interface" minOccurs="1" maxOccurs="unbounded" type="juniper-interface" />
                </xs:choice>
              </xs:sequence>
            </xs:complexType>
          </xs:element>
          <xs:element name="snmp" minOccurs="0" type="generic-sequence" />
          <xs:element name="forwarding-options" minOccurs="0" type="generic-sequence" />
          <xs:element name="routing-options" minOccurs="0" type="generic-sequence" />
          <xs:element name="protocols" minOccurs="0" type="generic-sequence" />
          <xs:element name="policy-options" minOccurs="0" type="generic-sequence" />
          <xs:element name="class-of-service" minOccurs="0" type="generic-sequence" />
          <xs:element name="firewall" minOccurs="0" type="generic-sequence" />
          <xs:element name="routing-instances" minOccurs="0" type="generic-sequence" />
          <xs:element name="bridge-domains" minOccurs="0" type="generic-sequence" />
          <xs:element name="virtual-chassis" minOccurs="0" type="generic-sequence" />
          <xs:element name="vlans" minOccurs="0" type="generic-sequence" />
          <xs:element name="comment" minOccurs="0" type="xs:string" />
        </xs:choice>
      </xs:sequence>
      <xs:attribute name="changed-seconds" type="xs:string" />
      <xs:attribute name="changed-localtime" type="xs:string" />
    </xs:complexType>
  </xs:element>

</xs:schema>
"""  # noqa: E501

# XML schema used by validate_netconf_config to check each
# interfaces/interface/unit element individually.  A unit must start with
# an integer 'name', optionally followed by a 'description'; any further
# content is accepted (lax).
UNIT_SCHEMA = """<?xml version="1.1" encoding="UTF-8" ?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">

  <xs:complexType name="generic-sequence">
    <xs:sequence>
        <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded"/>
    </xs:sequence>
    <xs:anyAttribute processContents="skip" />
  </xs:complexType>

  <xs:element name="unit">
    <xs:complexType>
      <xs:sequence>
        <xs:element name="name" minOccurs="1" maxOccurs="1" type="xs:int" />
        <xs:element name="description" minOccurs="0" maxOccurs="1" type="xs:string" />
        <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded"/>
      </xs:sequence>
      <xs:attribute name="inactive" type="xs:string" />
    </xs:complexType>
  </xs:element>

</xs:schema>
"""  # noqa: E501


def _rpc(hostname, ssh):
    """
    open a PyEZ session to a router and return its rpc proxy

    :param hostname: router hostname
    :param ssh: mapping providing 'username' and 'private-key'
    :return: the ``rpc`` attribute of the opened Device
             (the Device itself is not closed by this function)
    :raises ConnectionError: if opening the session raises a PyEZ
                             ConnectError or RpcError
    """
    device = Device(
        host=hostname,
        user=ssh['username'],
        ssh_private_key_file=ssh['private-key'],
    )
    try:
        device.open()
    except (EzErrors.ConnectError, EzErrors.RpcError) as err:
        raise ConnectionError(str(err))
    else:
        return device.rpc


def validate_netconf_config(config_doc):
    """
    validate a netconf configuration document against CONFIG_SCHEMA,
    then validate every interfaces/interface/unit element against
    UNIT_SCHEMA

    Schema errors are logged (line.column: message) before raising.

    :param config_doc: lxml configuration element/document to validate
    :raises AssertionError: if any schema validation fails
    """
    logger = logging.getLogger(__name__)

    def _validate(schema, doc):
        if schema.validate(doc):
            return
        for e in schema.error_log:
            logger.error("%d.%d: %s", e.line, e.column, e.message)
        # raise explicitly: a bare `assert False` is stripped when python
        # runs with -O, which would let invalid documents through silently
        raise AssertionError('netconf schema validation failed')

    schema_doc = etree.XML(CONFIG_SCHEMA.encode('utf-8'))
    config_schema = etree.XMLSchema(schema_doc)

    _validate(config_schema, config_doc)

    # CONFIG_SCHEMA does not check 'unit' content, so validate each
    # interfaces/interface/unit element separately.
    # NOTE(review): logical-systems units are not covered by this xpath.
    schema_doc = etree.XML(UNIT_SCHEMA.encode('utf-8'))
    unit_schema = etree.XMLSchema(schema_doc)
    for i in config_doc.xpath('//configuration/interfaces/interface'):
        for u in i.xpath('./unit'):
            _validate(unit_schema, u)


def load_config(hostname, ssh_params):
    """
    loads netconf data from the router, validates and
    returns as an lxml etree doc

    The NETCONF session is always closed before returning, whether or
    not the get-config rpc succeeds.

    :param hostname: router hostname
    :param ssh_params: 'ssh' config element(cf. config.py:CONFIG_SCHEMA)
    :return: the validated configuration returned by get_config
    :raises ConnectionError: if the session to the router can't be opened
    :raises AssertionError: if the configuration fails schema validation
    """
    logger = logging.getLogger(__name__)
    logger.info("capturing netconf data for '%s'", hostname)

    # the device is opened here (rather than via _rpc) so that the
    # session can be closed once the configuration has been fetched;
    # otherwise every call leaves an open NETCONF/ssh session behind
    dev = Device(
        host=hostname,
        user=ssh_params['username'],
        ssh_private_key_file=ssh_params['private-key'])
    try:
        dev.open()
    except (EzErrors.ConnectError, EzErrors.RpcError) as e:
        raise ConnectionError(str(e)) from e

    try:
        config = dev.rpc.get_config()
    finally:
        dev.close()

    validate_netconf_config(config)
    return config


def list_interfaces(netconf_config):
    """
    generator that parses netconf output and
    yields a list of interfaces

    Each //configuration/interfaces/interface element is yielded,
    followed by its units that are not marked inactive (named
    '<interface>.<unit>').  For logical-systems interfaces only the
    (non-inactive) units are yielded.

    :param netconf_config: xml doc that was generated by load_config
    :return: generator of dicts with keys 'name', 'description',
             'bundle', 'ipv4' and 'ipv6'
    """

    def _ifc_info(e):
        # warning: this structure should match the default
        #          returned from routes.classifier.juniper_link_info
        name = e.find('name')
        assert name is not None, "expected interface 'name' child element"
        ifc = {
            'name': name.text,
            'description': '',
            'bundle': []
        }
        description = e.find('description')
        if description is not None:
            # an empty <description/> has no text: keep the '' default
            ifc['description'] = description.text or ''

        # search below *this* element only, so that a unit reports its
        # own bundle membership rather than its parent interface's
        ifc['bundle'] = [b.text for b in e.iterfind('.//bundle')]

        ifc['ipv4'] = e.xpath('./family/inet/address/name/text()')
        ifc['ipv6'] = e.xpath('./family/inet6/address/name/text()')

        return ifc

    def _units(base_name, node):
        for u in node.xpath('./unit'):
            # skip deactivated units
            if u.get('inactive', None) == 'inactive':
                continue
            unit_info = _ifc_info(u)
            unit_info['name'] = "%s.%s" % (base_name, unit_info['name'])
            yield unit_info

    for i in netconf_config.xpath('//configuration/interfaces/interface'):
        info = _ifc_info(i)
        yield info
        yield from _units(info['name'], i)

    for i in netconf_config.xpath(
            '//configuration/logical-systems/interfaces/interface'):
        name = i.find('name')
        assert name is not None, "expected interface 'name' child element"
        yield from _units(name.text, i)


def list_bgp_routes(netconf_config):
    """
    yields the BGP neighbors of the GEANT-IX* groups in the IAS
    routing instance

    :param netconf_config: xml doc that was generated by load_config
    :return: generator of dicts, e.g.
        {'name': <neighbor name>, 'description': <str or None>,
         'as': {'local': <int or None>, 'peer': <int or None>}}
        None is used when the neighbor has no such element
    """

    def _int_text(elem):
        # a missing element gives None rather than an AttributeError
        return int(elem.text) if elem is not None else None

    for r in netconf_config.xpath(
            '//configuration/routing-instances/'
            'instance[name/text()="IAS"]/protocols/bgp/'
            'group[starts-with(name/text(), "GEANT-IX")]/'
            'neighbor'):
        name = r.find('name')
        description = r.find('description')
        yield {
            'name': name.text,
            'description': (
                description.text if description is not None else None),
            'as': {
                # a neighbor may have no local-as (or peer-as) element
                'local': _int_text(r.find('local-as/as-number')),
                'peer': _int_text(r.find('peer-as'))
            }
        }


def ix_public_peers(netconf_config):
    """
    yields the BGP neighbors of the GEANT-IX* groups in the IAS
    routing instance, with neighbor addresses in exploded form

    :param netconf_config: xml doc that was generated by load_config
    :return: generator of dicts, e.g.
        {'name': <exploded ip address>, 'description': <str or None>,
         'as': {'local': <int or None>, 'peer': <int or None>}}
        None is used when the neighbor has no such element
    """

    def _int_text(elem):
        # a missing element gives None rather than an AttributeError
        return int(elem.text) if elem is not None else None

    for r in netconf_config.xpath(
            '//configuration/routing-instances/'
            'instance[name/text()="IAS"]/protocols/bgp/'
            'group[starts-with(name/text(), "GEANT-IX")]/'
            'neighbor'):
        name = r.find('name')
        description = r.find('description')
        yield {
            'name': ipaddress.ip_address(name.text).exploded,
            'description': (
                description.text if description is not None else None),
            'as': {
                # a neighbor may have no local-as (or peer-as) element
                'local': _int_text(r.find('local-as/as-number')),
                'peer': _int_text(r.find('peer-as'))
            }
        }


def vpn_rr_peers(netconf_config):
    """
    yields the BGP neighbors of the VPN-RR and VPN-RR-INTERNAL groups
    in the VRR logical system

    :param netconf_config: xml doc that was generated by load_config
    :return: generator of dicts with keys 'name' (exploded ip address)
             and 'description' (None when absent), plus 'peer-as' (int)
             only when the neighbor has a peer-as element
    """
    # NOTE(review): the concatenated xpath below contains
    # '...]//protocols/...', i.e. protocols is matched at any depth below
    # the VRR logical system -- confirm whether that is intended
    for r in netconf_config.xpath(
            '//configuration/logical-systems[name/text()="VRR"]/'
            '/protocols/bgp/'
            'group[name/text()="VPN-RR" or name/text()="VPN-RR-INTERNAL"]/'
            'neighbor'):
        description = r.find('description')
        neighbor = {
            'name': ipaddress.ip_address(r.find('name').text).exploded,
            'description': (
                description.text if description is not None else None),
        }
        peer_as = r.find('peer-as')
        if peer_as is not None:
            neighbor['peer-as'] = int(peer_as.text)
        yield neighbor


def interface_addresses(netconf_config):
    """
    yields one record per address configured on each interface/unit
    reported by list_interfaces (no de-duplication is performed)

    :param netconf_config: xml doc that was generated by load_config
    :return: generator of dicts with keys 'name' (the exploded ip
             address), 'interface address' (the address string as found
             in the config) and 'interface name'
    """
    for ifc in list_interfaces(netconf_config):
        ifc_name = ifc['name']
        for cidr in [*ifc['ipv4'], *ifc['ipv6']]:
            host_ip = ipaddress.ip_interface(cidr).ip
            yield {
                "name": host_ip.exploded,
                "interface address": cidr,
                "interface name": ifc_name,
            }


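# NOTE: sketch kept for reference, for enabling VRR data parsing.  It
# fetches the VRR bgp configuration over ssh as json and relies on
# helpers (_DISABLE_PAGING_COMMAND, ssh_exec_commands, _loads) that are
# not imported in this module.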
# def fetch_vrr_config(hostname, ssh_params):
#
#     commands = [
#         _DISABLE_PAGING_COMMAND,
#         ('show configuration logical-systems '
#          'VRR protocols bgp | display json')
#     ]
#
#     output = list(ssh_exec_commands(hostname, ssh_params, commands))
#     assert len(output) == len(commands)
#
#     return _loads(output[1]) if output[1] else {}
#


def load_routers_from_netdash(url, timeout=30.0):
    """
    query url for a linefeed-delimited list of managed router hostnames

    :param url: url of alldevices.txt file
    :param timeout: seconds to wait for the server (passed to
                    requests.get); without a timeout a stalled server
                    would block the caller forever
    :return: list of router hostnames, with whitespace stripped and
             blank lines omitted
    :raises requests.HTTPError: if the server returns an error status
    """
    r = requests.get(url=url, timeout=timeout)
    r.raise_for_status()
    return [
        ln.strip() for ln in r.text.splitlines() if ln.strip()
    ]


def local_interfaces(
        type=netifaces.AF_INET,
        omit_link_local=True,
        omit_loopback=True):
    """
    generator yielding IPv4Interface or IPv6Interface objects,
    depending on the value of type

    :param type: hopefully AF_INET or AF_INET6
    :param omit_link_local: skip v6 fe80* addresses if true
    :param omit_loopback: skip loopback interfaces if true, i.e. plain
                          'lo' as well as 'lo<N>' (e.g. 'lo0')
    :return: generator of ipaddress interface objects; an address with
             no reported netmask is yielded as a single-host interface
    """
    for n in netifaces.interfaces():
        # '^lo\d+' alone missed a loopback named just 'lo'
        if omit_loopback and re.match(r'^lo(\d|$)', n):
            continue
        am = netifaces.ifaddresses(n)
        for a in am.get(type, []):
            if omit_link_local and a['addr'].startswith('fe80:'):
                continue
            # drop any '%<zone>' suffix from the address
            m = re.match(r'^(.+?)(%.*)?$', a['addr'])
            assert m
            addr = m.group(1)
            netmask = a.get('netmask')
            if netmask is None:
                # no mask reported: avoid a KeyError, treat as a host
                yield ipaddress.ip_interface(addr)
                continue
            # the netmask may carry a '/<prefixlen>' suffix; prefer it
            m = re.match(r'.*/(\d+)$', netmask)
            mask = m.group(1) if m else netmask
            yield ipaddress.ip_interface('%s/%s' % (addr, mask))


def snmp_community_string(netconf_config):
    """
    return the name of the first snmp community whose (non-restricted)
    clients admit one of this host's local IPv4 addresses

    :param netconf_config: netconf lxml etree document
    :return: community name, or None if no community matches
    """
    my_addresses = [i.ip for i in local_interfaces()]
    for community in netconf_config.xpath('//configuration/snmp/community'):
        # clients flagged 'restrict' are *denied* access (typically
        # '0.0.0.0/0 restrict'), so they must not count as allowed
        for subnet in community.xpath(
                './clients[not(restrict)]/name/text()'):
            allowed_network = ipaddress.ip_network(subnet, strict=False)
            if any(me in allowed_network for me in my_addresses):
                return community.xpath('./name/text()')[0]
    return None


def netconf_changed_timestamp(netconf_config):
    """
    return the last change timestamp published by the config document

    :param netconf_config: netconf lxml etree document
    :return: the first all-digit 'changed-seconds' attribute of the root
             'configuration' element, as an int (epoch seconds); None,
             after logging a warning, if there is no such value
    """
    candidates = netconf_config.xpath('/configuration/@changed-seconds')
    numeric = (int(ts) for ts in candidates if re.match(r'^\d+$', ts))
    timestamp = next(numeric, None)
    if timestamp is None:
        logging.getLogger(__name__).warning(
            'no valid timestamp found in netconf configuration')
    return timestamp