import contextlib
import logging
import re
import ipaddress

from jnpr.junos import Device
from jnpr.junos import exception as EzErrors
from lxml import etree
import netifaces
import requests

# XSD applied by validate_netconf_config() to the whole <configuration>
# document returned by the router.  It enumerates the top-level sections that
# may appear; most of them are typed 'generic-sequence', i.e. any child
# content is accepted (lax processing, attributes skipped).  <interfaces> is
# described in slightly more detail via the 'juniper-interface' type.  As the
# schema's own NOTE says, 'unit' content is not validated here -
# validate_netconf_config() checks interface units separately against
# UNIT_SCHEMA.
CONFIG_SCHEMA = """<?xml version="1.1" encoding="UTF-8" ?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">

  <xs:complexType name="generic-sequence">
    <xs:sequence>
        <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded"/>
    </xs:sequence>
    <xs:anyAttribute processContents="skip" />
  </xs:complexType>


  <!-- NOTE: 'unit' content isn't validated -->
  <xs:complexType name="juniper-interface">
    <xs:sequence>
        <xs:choice minOccurs="1" maxOccurs="unbounded">
          <xs:element name="name" minOccurs="1" maxOccurs="1" type="xs:string" />
          <xs:element name="description" minOccurs="0" maxOccurs="1">
              <xs:complexType>
                <xs:simpleContent>
                  <xs:extension base="xs:string">
                    <xs:attribute name="inactive" type="xs:string" />
                  </xs:extension>
                </xs:simpleContent>
              </xs:complexType>
          </xs:element>
          <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded" />
        </xs:choice>
    </xs:sequence>
    <xs:attribute name="inactive" type="xs:string" />
  </xs:complexType>

  <xs:element name="configuration">
    <xs:complexType>
      <xs:sequence>
        <xs:choice minOccurs="1" maxOccurs="unbounded">
          <xs:element name="transfer-on-commit" minOccurs="0" type="xs:string" />
          <xs:element name="archive-sites" minOccurs="0" type="generic-sequence" />
          <xs:element name="version" minOccurs="0" type="xs:string" />
          <xs:element name="groups" minOccurs="0" type="generic-sequence" />
          <xs:element name="apply-groups" minOccurs="0" type="xs:string" />
          <xs:element name="system" minOccurs="0" type="generic-sequence" />
          <xs:element name="logical-systems" minOccurs="0" type="generic-sequence" />
          <xs:element name="chassis" minOccurs="0" type="generic-sequence" />
          <xs:element name="services" minOccurs="0" type="generic-sequence" />
          <xs:element name="interfaces" minOccurs="0">
            <xs:complexType>
              <xs:sequence>
                <xs:choice minOccurs="1" maxOccurs="unbounded">
                    <xs:element name="apply-groups" minOccurs="0" type="xs:string" />
                    <xs:element name="interface-range" minOccurs="0" type="generic-sequence" />
                    <xs:element name="interface" minOccurs="1" maxOccurs="unbounded" type="juniper-interface" />
                </xs:choice>
              </xs:sequence>
            </xs:complexType>
          </xs:element>
          <xs:element name="snmp" minOccurs="0" type="generic-sequence" />
          <xs:element name="forwarding-options" minOccurs="0" type="generic-sequence" />
          <xs:element name="routing-options" minOccurs="0" type="generic-sequence" />
          <xs:element name="protocols" minOccurs="0" type="generic-sequence" />
          <xs:element name="policy-options" minOccurs="0" type="generic-sequence" />
          <xs:element name="class-of-service" minOccurs="0" type="generic-sequence" />
          <xs:element name="firewall" minOccurs="0" type="generic-sequence" />
          <xs:element name="routing-instances" minOccurs="0" type="generic-sequence" />
          <xs:element name="bridge-domains" minOccurs="0" type="generic-sequence" />
          <xs:element name="virtual-chassis" minOccurs="0" type="generic-sequence" />
          <xs:element name="vlans" minOccurs="0" type="generic-sequence" />
          <xs:element name="comment" minOccurs="0" type="xs:string" />
        </xs:choice>
      </xs:sequence>
      <xs:attribute name="changed-seconds" type="xs:string" />
      <xs:attribute name="changed-localtime" type="xs:string" />
      <xs:attribute name="inactive" type="xs:string" />
    </xs:complexType>
  </xs:element>

</xs:schema>
"""  # noqa: E501

# XSD applied by validate_netconf_config() to each <unit> element found
# directly under //configuration/interfaces/interface.  A unit must start with
# exactly one integer <name>, optionally followed by a string <description>;
# any further child elements are accepted (lax processing), and an 'inactive'
# attribute is allowed on the unit itself.
UNIT_SCHEMA = """<?xml version="1.1" encoding="UTF-8" ?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">

  <xs:complexType name="generic-sequence">
    <xs:sequence>
        <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded"/>
    </xs:sequence>
    <xs:anyAttribute processContents="skip" />
  </xs:complexType>

  <xs:element name="unit">
    <xs:complexType>
      <xs:sequence>
        <xs:element name="name" minOccurs="1" maxOccurs="1" type="xs:int" />
        <xs:element name="description" minOccurs="0" maxOccurs="1" type="xs:string" />
        <xs:any processContents="lax" minOccurs="0" maxOccurs="unbounded"/>
      </xs:sequence>
      <xs:attribute name="inactive" type="xs:string" />
    </xs:complexType>
  </xs:element>

</xs:schema>
"""  # noqa: E501


class NetconfHandlingError(Exception):
    """
    Raised for problems with router netconf data, e.g. when
    validate_netconf_config finds schema violations (the message then
    lists each violation on its own line).
    """


@contextlib.contextmanager
def _rpc(hostname, ssh):
    """
    Context manager yielding the rpc proxy of an opened PyEZ Device.

    The device is closed on exit in every case: when the caller's block
    finishes, when it raises, and also when ``open()`` itself fails.

    :param hostname: router hostname
    :param ssh: mapping providing the 'username' and 'private-key' entries
    """
    connection_args = {
        'host': hostname,
        'user': ssh['username'],
        'ssh_private_key_file': ssh['private-key'],
    }
    device = Device(**connection_args)
    try:
        device.open()
        yield device.rpc
    finally:
        device.close()


def validate_netconf_config(config_doc):
    """
    Check netconf configuration data against this module's XML schemas.

    The document as a whole is validated against CONFIG_SCHEMA first.  If
    that passes, every ``unit`` child of each
    ``//configuration/interfaces/interface`` element is validated
    individually against UNIT_SCHEMA.  Checking stops at the first failure:
    each entry of the failing schema's error log is logged at ERROR level
    and included (one per line, as ``line.column: message``) in the raised
    exception.

    :param config_doc: lxml etree document, e.g. as returned by load_config
    :return: None
    :raises: NetconfHandlingError in case of validation errors
    """
    logger = logging.getLogger(__name__)

    def _compile(schema_text):
        return etree.XMLSchema(etree.XML(schema_text.encode('utf-8')))

    def _assert_valid(schema, node):
        if schema.validate(node):
            return
        problems = [
            f'{entry.line}.{entry.column}: {entry.message}'
            for entry in schema.error_log
        ]
        for problem in problems:
            logger.error(problem)
        raise NetconfHandlingError('\n'.join(problems))

    _assert_valid(_compile(CONFIG_SCHEMA), config_doc)

    # unit elements are validated one at a time
    unit_schema = _compile(UNIT_SCHEMA)
    for interface in config_doc.xpath('//configuration/interfaces/interface'):
        for unit in interface.xpath('./unit'):
            _assert_valid(unit_schema, unit)


def load_config(hostname, ssh_params, validate=True):
    """
    loads netconf data from the router, validates (by default) and
    returns as an lxml etree doc

    :param hostname: router hostname
    :param ssh_params: 'ssh' config element(cf. config.py:CONFIG_SCHEMA);
        its 'username' and 'private-key' entries are used to connect
    :param validate: whether or not to validate netconf data (default True)
    :return: the configuration returned by the router's get_config rpc
    :raises: NetconfHandlingError if validation fails, or ConnectionError
        (chained from the underlying PyEZ exception) if connecting to the
        router or running the rpc fails
    """
    logger = logging.getLogger(__name__)
    # lazy %-style args: the message is only formatted if INFO is enabled
    logger.info("capturing netconf data for '%s'", hostname)
    try:
        with _rpc(hostname, ssh_params) as router:
            config = router.get_config()
    except (EzErrors.ConnectError, EzErrors.RpcError) as e:
        # keep the original PyEZ error as the explicit cause for debugging
        raise ConnectionError(str(e)) from e
    if validate:
        validate_netconf_config(config)
    return config


def list_interfaces(netconf_config):
    """
    generator that parses netconf output and
    yields a list of interfaces

    Interfaces under ``configuration/interfaces`` that are not marked
    inactive are yielded first, each immediately followed by its active
    units (named ``<interface>.<unit>``).  Then, for every logical system,
    the active units of its interfaces are yielded with an extra
    ``'logical-system'`` key (the logical-system interfaces themselves are
    not yielded).

    Each dict has the keys ``'name'``, ``'description'``, ``'bundle'``
    (texts of all ``bundle`` elements below the interface/unit node itself),
    ``'ipv4'`` and ``'ipv6'`` (texts of ``family/inet[6]/address/name``).

    :param netconf_config: xml doc that was generated by load_config
    :return: generator of interface dicts
    """

    def _ifc_info(e):
        # warning: this structure should match the default
        #          returned from routes.classifier.juniper_link_info
        name = e.find('name')
        assert name is not None, "expected interface 'name' child element"
        ifc = {
            'name': name.text,
            'description': '',
            'bundle': []
        }
        description = e.find('description')
        if description is not None:
            ifc['description'] = description.text

        # search below ``e`` itself - when ``e`` is a unit, the enclosing
        # loop's interface variable would also pick up sibling units' bundles
        ifc['bundle'].extend(b.text for b in e.iterfind('.//bundle'))

        ifc['ipv4'] = e.xpath('./family/inet/address/name/text()')
        ifc['ipv6'] = e.xpath('./family/inet6/address/name/text()')

        return ifc

    def _inactive(interface_node):
        return interface_node.get('inactive', None) == 'inactive'

    def _units(base_name, interface_node):
        for u in interface_node.xpath('./unit'):
            if _inactive(u):
                continue
            unit_info = _ifc_info(u)
            unit_info['name'] = f'{base_name}.{unit_info["name"]}'
            yield unit_info

    for i in netconf_config.xpath('//configuration/interfaces/interface'):
        if _inactive(i):
            continue
        info = _ifc_info(i)
        yield info
        yield from _units(info['name'], i)

    for ls_node in netconf_config.xpath('//configuration/logical-systems'):
        logical_system = ls_node.xpath('./name/text()')
        assert logical_system, 'no logical-system name found'
        # NOTE(review): unlike the loop above, the logical-system interface
        # itself is not checked for the 'inactive' attribute - confirm
        # whether that is intended
        for i in ls_node.xpath('.//interfaces/interface'):
            name = i.xpath('./name/text()')
            assert name, "expected interface 'name' child element"
            for u in _units(name[0], i):
                u['logical-system'] = logical_system[0]
                yield u


def asn_to_int(asn_string: str) -> int:
    """
    Convert a possibly dotted ASN to an integer.

    Accepts plain notation (``'65000'``) or asdot notation, where
    ``'1.10'`` means ``(1 << 16) | 10 == 65546``.  Only ASCII digits are
    accepted and the whole string must match: no surrounding whitespace
    and no trailing newline.

    Args:
    asn_string (str): ASN to be converted, can be in dot notation or not.

    Returns:
    int: ASN in integer format, in the range 0 .. 0xffffffff.

    Raises:
    ValueError: If the ASN string is not in the expected format or exceeds valid range.
    """
    # fullmatch (rather than '^...$') so a trailing '\n' is rejected, and
    # [0-9] (rather than \d / isdigit) so non-ASCII digits are rejected
    dotted = re.fullmatch(r'([0-9]+)\.([0-9]+)', asn_string)
    if dotted:
        high_order, low_order = map(int, dotted.groups())

        if high_order > 0xffff or low_order > 0xffff:
            raise ValueError(f'Invalid ASN format: {asn_string}. Both components must be <= 0xffff.')

        return (high_order << 16) | low_order

    if re.fullmatch(r'[0-9]+', asn_string):
        asn = int(asn_string)
        # ASNs are 32-bit values
        if asn > 0xffffffff:
            raise ValueError(f'Invalid ASN: {asn_string}. Value must be <= 0xffffffff.')
        return asn

    raise ValueError(f'Unable to parse ASN string: {asn_string}. Expected either a pure integer or a dot notation.')


def _system_bgp_peers(system_node):
    """
    Yield the active BGP neighbors configured below ``system_node``.

    Neighbors are read from ``protocols/bgp/group/neighbor``; routing
    instances (``routing-instances/instance``) below the node are visited
    recursively and their peers get an extra ``'instance'`` key.

    Nodes marked ``inactive="inactive"`` are skipped together with
    everything beneath them, so neighbors of an inactive group or routing
    instance are not reported as active sessions.

    Each yielded dict contains ``'address'`` (exploded IP address string)
    and ``'group'``, plus ``'remote-asn'``, ``'local-asn'`` and
    ``'description'`` when those are configured.

    :param system_node: lxml element (e.g. ``configuration`` or a
        ``logical-systems`` entry)
    :return: generator of peer dicts
    """

    def _inactive(node):
        return node.get('inactive') == 'inactive'

    def _peering_params(neighbor_node):
        address = neighbor_node.find('name').text
        info = {'address': ipaddress.ip_address(address).exploded}
        peer_as = neighbor_node.find('peer-as')
        if peer_as is not None:
            # lxml usage warning: can't just test `if peer_as:`
            info['remote-asn'] = asn_to_int(peer_as.text)
        local_as = neighbor_node.find('local-as')
        if local_as is not None:
            asn_value_node = local_as.find('as-number')
            info['local-asn'] = asn_to_int(asn_value_node.text)
        description = neighbor_node.find('description')
        if description is not None:
            # lxml usage warning: can't just test `if description:`
            info['description'] = description.text
        return info

    def _neighbors(group_node):
        for neighbor in group_node.xpath('./neighbor'):
            if _inactive(neighbor):
                continue
            yield _peering_params(neighbor)

    for group in system_node.xpath('./protocols/bgp/group'):
        # an inactive group deactivates all of its neighbors
        if _inactive(group):
            continue
        group_name = group.find('name').text
        for peer in _neighbors(group):
            peer['group'] = group_name
            yield peer

    for instance in system_node.xpath(
            './routing-instances/instance'):
        # likewise for every session inside an inactive routing instance
        if _inactive(instance):
            continue
        instance_name = instance.find('name').text
        for peer in _system_bgp_peers(instance):
            peer['instance'] = instance_name
            yield peer


def all_bgp_peers(netconf_config):
    """
    Return all active bgp peering sessions defined for this router.

    Sessions of the base system come first; sessions defined inside each
    logical system follow, tagged with a 'logical-system' key.

    The response will be a generator, which renders a list
    formatted according to the following schema:

    .. asjson::
       inventory_provider.routes.msr.PEERING_LIST_SCHEMA

    EXCEPT: the 'hostname' parameter is not present

    :param netconf_config: netconf lxml etree document
    :return: yields active peering sessions
    """
    # base system - normally exactly one <configuration> element
    for config_root in netconf_config.xpath('//configuration'):
        yield from _system_bgp_peers(config_root)

    ls_xpath = '//configuration/logical-systems'
    for ls_node in netconf_config.xpath(ls_xpath):
        ls_name = ls_node.find('name').text
        for session in _system_bgp_peers(ls_node):
            session['logical-system'] = ls_name
            yield session


def interface_addresses(netconf_config):
    """
    Yield one dict per address configured on an interface or unit.

    Interfaces are enumerated with list_interfaces; for each of them the
    'ipv4' addresses are emitted before the 'ipv6' ones.  No de-duplication
    across interfaces is performed.

    :param netconf_config: netconf lxml etree document
    :return: generator of dicts with keys "name" (exploded host address),
        "interface address" (address with prefix length, as configured) and
        "interface name"
    """
    for ifc in list_interfaces(netconf_config):
        ifc_name = ifc['name']
        for cidr in [*ifc['ipv4'], *ifc['ipv6']]:
            host = ipaddress.ip_interface(cidr).ip
            yield {
                "name": host.exploded,
                "interface address": cidr,
                "interface name": ifc_name,
            }


def load_routers_from_netdash(url, timeout=60):
    """
    Query url for a linefeed-delimited list of managed router hostnames.

    :param url: url of alldevices.txt file
    :param timeout: seconds to wait for the server, passed to requests.get;
        without a timeout requests can block indefinitely on an
        unresponsive server
    :return: list of router hostnames, whitespace-stripped, with blank
        lines dropped
    :raises: requests.HTTPError on an error status code, or another
        requests.RequestException (e.g. Timeout) if the request fails
    """
    r = requests.get(url=url, timeout=timeout)
    r.raise_for_status()
    return [
        ln.strip() for ln in r.text.splitlines() if ln.strip()
    ]


def local_interfaces(
        type=netifaces.AF_INET,
        omit_link_local=True,
        omit_loopback=True):
    """
    Generator yielding IPv4Interface or IPv6Interface objects for
    the interfaces present on the local system,
    depending on the value of type.

    :param type: hopefully AF_INET or AF_INET6
    :param omit_link_local: skip v6 fe80* addresses if true
    :param omit_loopback: skip loopback interfaces (``lo``, ``lo0``,
        ``lo1``, ...) if true
    :return: generator of ipaddress interface objects
    """
    for n in netifaces.interfaces():
        # the digit is optional: Linux typically names its loopback plain
        # 'lo', BSD-style systems use 'lo0'; names such as 'lowpan0' are kept
        if omit_loopback and re.match(r'^lo(\d|$)', n):
            continue
        am = netifaces.ifaddresses(n)
        for a in am.get(type, []):
            if omit_link_local and a['addr'].startswith('fe80:'):
                continue
            # drop any '%<zone>' scope suffix from the address
            m = re.match(r'^(.+?)(%.*)?$', a['addr'])
            assert m
            addr = m.group(1)
            netmask = a.get('netmask')
            if netmask is None:
                # netifaces may omit the netmask for some entries; report
                # the bare address (host prefix) instead of raising KeyError
                yield ipaddress.ip_interface(addr)
                continue
            # the netmask may be reported in '<mask>/<prefixlen>' form
            m = re.match(r'.*/(\d+)$', netmask)
            mask = m.group(1) if m else netmask
            yield ipaddress.ip_interface(f'{addr}/{mask}')


def snmp_community_string(netconf_config):
    """
    Return the name of the first SNMP community whose client list covers
    one of this host's addresses, or None when there is no such community.

    Local addresses come from local_interfaces() called with its default
    arguments.  Communities, and the ``clients/name`` subnets of each, are
    checked in document order.

    NOTE(review): every ``clients/name`` entry is treated as an allowed
    subnet; any other markers on a client entry are not inspected here -
    confirm that matches the routers' configuration.
    """
    local_addresses = [ifc.ip for ifc in local_interfaces()]
    for community in netconf_config.xpath('//configuration/snmp/community'):
        client_subnets = community.xpath('./clients/name/text()')
        allowed = (
            ipaddress.ip_network(subnet, strict=False)
            for subnet in client_subnets)
        if any(addr in net for net in allowed for addr in local_addresses):
            return community.xpath('./name/text()')[0]
    return None


def netconf_changed_timestamp(netconf_config):
    """
    Extract the last-change time from the root configuration element.

    The first ``changed-seconds`` value consisting only of digits is
    converted to an int; if there is none, a warning is logged.

    :param netconf_config: netconf lxml etree document
    :return: an epoch timestamp (integer number of seconds) or None
    """
    candidates = netconf_config.xpath('/configuration/@changed-seconds')
    timestamp = next(
        (int(value) for value in candidates if re.match(r'^\d+$', value)),
        None)
    if timestamp is None:
        logging.getLogger(__name__).warning(
            'no valid timestamp found in netconf configuration')
    return timestamp


def logical_systems(netconf_config):
    """
    List the names of the logical systems configured on the router.

    Routers without logical systems are a normal case, not an error.

    :param netconf_config: netconf lxml etree document
    :return: a list of strings
    """
    names_xpath = '//configuration/logical-systems/name/text()'
    names = netconf_config.xpath(names_xpath)
    return names