import json
import logging
import os
import re

from celery import Task, states
from celery.result import AsyncResult

from collections import defaultdict
from lxml import etree

from inventory_provider.tasks.app import app
from inventory_provider.tasks.common import get_redis
from inventory_provider import config
from inventory_provider import environment
from inventory_provider.db import db, opsdb
from inventory_provider import snmp
from inventory_provider import juniper

# TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks)  # noqa: E501

# logging is configured as a side effect of importing this module
# (presumably so that any worker process importing these tasks gets the
# project logging setup -- confirm)
environment.setup_logging()


class InventoryTaskError(Exception):
    """
    Error raised by the inventory tasks, e.g. when cached netconf data
    or an snmp community string cannot be found.
    """


class InventoryTask(Task):
    """
    Celery base task that loads the inventory provider config once.

    The file named by the INVENTORY_PROVIDER_CONFIG_FILENAME environment
    variable is parsed with ``config.load`` and kept in the class
    attribute ``InventoryTask.config``, which is shared by every instance;
    later instantiations reuse the already-loaded value.  Every
    ``update_state`` call is also logged at debug level.
    """

    # parsed config, shared by all InventoryTask instances
    # (None until the first instance has been constructed)
    config = None

    def __init__(self):
        """
        Load ``InventoryTask.config`` if it has not been loaded yet.

        :raises KeyError: if INVENTORY_PROVIDER_CONFIG_FILENAME is not set
        :raises FileNotFoundError: if the named config file does not exist
        """
        super().__init__()

        if InventoryTask.config:
            return

        filename = os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']
        # explicit check rather than an assert, which would be stripped
        # when python runs with -O
        if not os.path.isfile(filename):
            raise FileNotFoundError('config file %r not found' % filename)

        logger = logging.getLogger(__name__)
        logger.info('Initializing worker with config from: %r', filename)
        with open(filename) as f:
            # ``config`` here is the imported config module, not the class
            # attribute: class scope is not visible inside methods
            InventoryTask.config = config.load(f)
        logger.debug('loaded config: %r', InventoryTask.config)

    def update_state(self, task_id=None, state=None, meta=None, **kwargs):
        """
        Log the new state and meta at debug level, then delegate to Celery.

        The explicit parameters mirror ``celery.Task.update_state``, so
        callers may omit ``meta`` or pass arguments positionally.

        :param task_id: passed through to ``Task.update_state``
        :param state: new task state
        :param meta: state metadata; must be json-serializable because it
            is dumped for the debug log
        """
        logger = logging.getLogger(__name__)
        logger.debug(json.dumps({'state': state, 'meta': meta}))
        super().update_state(
            task_id=task_id, state=state, meta=meta, **kwargs)


def _save_value(key, value):
    assert isinstance(value, str), \
        "sanity failure: expected string data as value"
    r = get_redis(InventoryTask.config)
    r.set(name=key, value=value)
    # InventoryTask.logger.debug("saved %s" % key)
    return "OK"


def _save_value_json(key, data_obj):
    """
    Serialize ``data_obj`` as JSON and store it in redis under ``key``.

    :param key: redis key
    :param data_obj: any json-serializable object
    """
    serialized = json.dumps(data_obj)
    _save_value(key, serialized)


def _save_value_etree(key, xml_doc):
    """
    Serialize an lxml element/tree to a unicode string and store it in
    redis under ``key``.

    :param key: redis key
    :param xml_doc: element or tree accepted by ``etree.tostring``
    """
    xml_text = etree.tostring(xml_doc, encoding='unicode')
    _save_value(key, xml_text)


@app.task
def snmp_refresh_interfaces(hostname, community):
    """
    Query a router's snmp interface indexes and cache them in redis.

    The values yielded by ``snmp.get_router_snmp_indexes`` are collected
    into a list and saved as JSON under ``snmp-interfaces:<hostname>``.

    :param hostname: router hostname
    :param community: snmp community string used for the query
    """
    logger = logging.getLogger(__name__)
    # the community string acts as a credential, so it is deliberately
    # kept out of the log messages
    logger.debug('>>> snmp_refresh_interfaces(%r)', hostname)

    _save_value_json(
        'snmp-interfaces:' + hostname,
        list(snmp.get_router_snmp_indexes(hostname, community)))

    logger.debug('<<< snmp_refresh_interfaces(%r)', hostname)


@app.task
def netconf_refresh_config(hostname):
    """
    Fetch a router's configuration over netconf and cache it in redis.

    The document returned by ``juniper.load_config`` (called with the
    ``ssh`` section of the config) is stored as xml text under
    ``netconf:<hostname>``.

    :param hostname: router hostname
    """
    logger = logging.getLogger(__name__)
    logger.debug('>>> netconf_refresh_config(%r)' % hostname)

    redis_key = 'netconf:' + hostname
    ssh_params = InventoryTask.config["ssh"]
    router_config = juniper.load_config(hostname, ssh_params)
    _save_value_etree(redis_key, router_config)

    logger.debug('<<< netconf_refresh_config(%r)' % hostname)


@app.task
def update_interfaces_to_services():
    """
    Cache the opsdb circuits, grouped by equipment interface, in redis.

    Circuits from ``opsdb.get_circuits`` are grouped by
    ``<equipment>:<interface_name>``.  All existing
    ``opsdb:interface_services:*`` keys are deleted, then each group is
    stored as a JSON list under
    ``opsdb:interface_services:<equipment>:<interface_name>``.
    """
    logger = logging.getLogger(__name__)
    logger.debug('>>> update_interfaces_to_services')

    services_by_interface = defaultdict(list)
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for circuit in opsdb.get_circuits(cx):
            interface_key = '%s:%s' % (
                circuit['equipment'], circuit['interface_name'])
            services_by_interface[interface_key].append(circuit)

    redis_client = get_redis(InventoryTask.config)
    for stale_key in redis_client.scan_iter('opsdb:interface_services:*'):
        redis_client.delete(stale_key)
    for interface_key, circuits in services_by_interface.items():
        redis_client.set(
            'opsdb:interface_services:' + interface_key,
            json.dumps(circuits))

    logger.debug('<<< update_interfaces_to_services')


@app.task
def update_equipment_locations():
    """
    Cache the opsdb pop location(s) of each known router in redis.

    All ``opsdb:location:*`` keys are deleted, then for each hostname from
    ``_derive_router_hostnames`` the matches returned by
    ``opsdb.lookup_pop_info`` are stored as a JSON list under
    ``opsdb:location:<hostname>``.
    """
    logger = logging.getLogger(__name__)
    logger.debug('>>> update_equipment_locations')

    redis_client = get_redis(InventoryTask.config)
    for stale_key in redis_client.scan_iter('opsdb:location:*'):
        redis_client.delete(stale_key)

    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for router in _derive_router_hostnames(InventoryTask.config):
            # there can sometimes be more than one matching location,
            # so the whole list is kept
            matches = list(opsdb.lookup_pop_info(cx, router))
            redis_client.set(
                'opsdb:location:%s' % router, json.dumps(matches))

    logger.debug('<<< update_equipment_locations')


@app.task
def update_circuit_hierarchy():
    """
    Cache the opsdb circuit parent/child relations in redis.

    Each relation row from ``opsdb.get_circuit_hierarchy`` is stored twice:

      - ``opsdb:services:parents:<child_circuit_id>`` holds the JSON list
        of relations in which that circuit is the child,
      - ``opsdb:services:children:<parent_circuit_id>`` holds the JSON list
        of relations in which that circuit is the parent.

    Existing keys with either prefix are deleted before writing.
    """
    logger = logging.getLogger(__name__)
    logger.debug('>>> update_circuit_hierarchy')

    # TODO: integers are not JSON keys
    child_to_parents = defaultdict(list)
    parent_to_children = defaultdict(list)
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for relation in opsdb.get_circuit_hierarchy(cx):
            parent_id = relation["parent_circuit_id"]
            child_id = relation["child_circuit_id"]
            parent_to_children[parent_id].append(relation)
            child_to_parents[child_id].append(relation)

    # the relations are fully collected, so the db connection does not
    # need to stay open while redis is updated
    r = get_redis(InventoryTask.config)
    for key in r.scan_iter('opsdb:services:parents:*'):
        r.delete(key)
    for cid, parents in child_to_parents.items():
        r.set('opsdb:services:parents:%d' % cid, json.dumps(parents))

    for key in r.scan_iter('opsdb:services:children:*'):
        r.delete(key)
    # keyed by parent id: the relations where this circuit is the parent
    for cid, children in parent_to_children.items():
        r.set('opsdb:services:children:%d' % cid, json.dumps(children))

    logger.debug('<<< update_circuit_hierarchy')


@app.task
def update_geant_lambdas():
    """
    Cache the geant lambdas from opsdb in redis.

    All ``opsdb:geant_lambdas:*`` keys are deleted, then each lambda
    returned by ``opsdb.get_geant_lambdas`` is stored as JSON under
    ``opsdb:geant_lambdas:<lower-cased name>``.
    """
    logger = logging.getLogger(__name__)
    logger.debug('>>> update_geant_lambdas')

    r = get_redis(InventoryTask.config)
    for key in r.scan_iter('opsdb:geant_lambdas:*'):
        r.delete(key)
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for ld in opsdb.get_geant_lambdas(cx):
            r.set('opsdb:geant_lambdas:%s' % ld['name'].lower(),
                  json.dumps(ld))

    # exit message names the task, matching the '>>>' entry message
    logger.debug('<<< update_geant_lambdas')


@app.task(base=InventoryTask, bind=True)
def update_junosspace_device_list(self):
    """
    Cache the list of routers managed by junosspace in redis.

    Each device returned by ``juniper.load_routers_from_junosspace`` is
    stored as utf-8 encoded JSON under ``junosspace:<hostname>``.  The
    device list is fetched before any existing ``junosspace:*`` key is
    deleted, so a failing query leaves the previous cache in place.

    :return: dict with 'task' and 'message' keys
    """
    logger = logging.getLogger(__name__)
    logger.debug('>>> update_junosspace_device_list')

    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'update_junosspace_device_list',
            'message': 'querying junosspace for managed routers'
        })

    r = get_redis(InventoryTask.config)

    routers = {}
    for d in juniper.load_routers_from_junosspace(
            InventoryTask.config['junosspace']):
        routers['junosspace:' + d['hostname']] = json.dumps(d).encode('utf-8')

    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'update_junosspace_device_list',
            'message': 'found %d routers, saving details' % len(routers)
        })

    # scan_iter rather than keys: KEYS blocks the redis server while it
    # walks the whole keyspace (and the other refresh tasks here use scan)
    for k in r.scan_iter('junosspace:*'):
        r.delete(k)
    for k, v in routers.items():
        r.set(k, v)

    logger.debug('<<< update_junosspace_device_list')

    return {
        'task': 'update_junosspace_device_list',
        'message': 'saved %d managed routers' % len(routers)
    }


def load_netconf_data(hostname):
    """
    Return the cached netconf configuration of a router as an xml element.

    Intended to be called only from within a task: it reads redis using
    ``InventoryTask.config``, which is loaded by ``InventoryTask``.

    :param hostname: router hostname
    :return: element parsed from the ``netconf:<hostname>`` redis value
    :raises InventoryTaskError: if no (or an empty) value is cached
    """
    cached = get_redis(InventoryTask.config).get('netconf:' + hostname)
    if cached:
        return etree.fromstring(cached.decode('utf-8'))
    raise InventoryTaskError('no netconf data found for %r' % hostname)


def clear_cached_classifier_responses(hostname=None):
    """
    Delete cached classifier responses from redis.

    :param hostname: if truthy, delete only the
        ``classifier-cache:juniper:<hostname>:*`` keys plus every
        ``classifier-cache:peer:*`` entry whose ``interfaces`` list
        references this router; otherwise delete every
        ``classifier-cache:*`` key
    """
    logger = logging.getLogger(__name__)
    if not hostname:
        logger.debug('removing all cached classifier responses')
    else:
        logger.debug(
            'removing cached classifier responses for %r' % hostname)

    r = get_redis(InventoryTask.config)

    def _keys_for_router():
        # generator: keys are produced lazily, so each one is deleted by
        # the loop below before the next peer entry is examined
        yield from r.keys('classifier-cache:juniper:%s:*' % hostname)

        # TODO: very inefficient ... but logically simplest at this point
        for peer_key in r.keys('classifier-cache:peer:*'):
            raw = r.get(peer_key.decode('utf-8'))
            if not raw:
                # deleted in another thread
                continue
            cached = json.loads(raw.decode('utf-8'))
            interfaces = cached.get('interfaces', [])
            routers = [i['interface']['router'] for i in interfaces]
            if hostname in routers:
                yield peer_key

    if hostname:
        doomed = _keys_for_router()
    else:
        doomed = r.keys('classifier-cache:*')

    for key in doomed:
        r.delete(key)


def _refresh_peers(hostname, key_base, peers):
    """
    Replace the cached ``<key_base>:<name>`` entries belonging to a router.

    Existing entries whose stored ``router`` field equals ``hostname`` are
    deleted.  Then each item of ``peers`` gets ``router`` set to
    ``hostname`` (the dicts are modified in place) and is stored as JSON
    under ``<key_base>:<peer['name']>``.

    :param hostname: router hostname
    :param key_base: redis key prefix, e.g. ``'ix_public_peer'``
    :param peers: iterable of dicts, each with a ``'name'`` key
    """
    logger = logging.getLogger(__name__)
    logger.debug(
        'removing cached %s for %r' % (key_base, hostname))
    redis_client = get_redis(InventoryTask.config)
    for existing_key in redis_client.keys(key_base + ':*'):
        raw = redis_client.get(existing_key.decode('utf-8'))
        # another process may have deleted the key after it was listed;
        # get() then returns a falsy value and the key is skipped
        if not raw:
            continue
        if json.loads(raw.decode('utf-8'))['router'] == hostname:
            redis_client.delete(existing_key)

    for peer in peers:
        peer['router'] = hostname
        redis_client.set(
            '%s:%s' % (key_base, peer['name']), json.dumps(peer))


def refresh_ix_public_peers(hostname, netconf):
    """
    Re-cache a router's ix public peers under the ``ix_public_peer`` prefix.

    :param hostname: router hostname
    :param netconf: netconf document, as returned by load_netconf_data
    """
    peers = juniper.ix_public_peers(netconf)
    _refresh_peers(hostname, 'ix_public_peer', peers)


def refresh_vpn_rr_peers(hostname, netconf):
    """
    Re-cache a router's vpn rr peers under the ``vpn_rr_peer`` prefix.

    :param hostname: router hostname
    :param netconf: netconf document, as returned by load_netconf_data
    """
    peers = juniper.vpn_rr_peers(netconf)
    _refresh_peers(hostname, 'vpn_rr_peer', peers)


def refresh_interface_address_lookups(hostname, netconf):
    """
    Re-cache a router's interface addresses under the
    ``reverse_interface_addresses`` prefix.

    :param hostname: router hostname
    :param netconf: netconf document, as returned by load_netconf_data
    """
    addresses = juniper.interface_addresses(netconf)
    _refresh_peers(hostname, 'reverse_interface_addresses', addresses)


def refresh_juniper_interface_list(hostname, netconf):
    """
    Re-cache a router's interfaces and bundle memberships in redis.

    All ``netconf-interfaces:<hostname>:*`` and
    ``netconf-interface-bundles:<hostname>:*`` keys are deleted.  Each
    interface from ``juniper.list_interfaces`` is stored as JSON under
    ``netconf-interfaces:<hostname>:<name>``.  For every non-empty bundle
    name an interface lists, the JSON list of member interface names is
    stored under ``netconf-interface-bundles:<hostname>:<bundle>``.

    :param hostname: router hostname
    :param netconf: netconf document, as returned by load_netconf_data
    """
    logger = logging.getLogger(__name__)
    logger.debug(
        'removing cached netconf-interfaces for %r' % hostname)

    r = get_redis(InventoryTask.config)
    for k in r.keys('netconf-interfaces:%s:*' % hostname):
        r.delete(k)

    for k in r.keys('netconf-interface-bundles:%s:*' % hostname):
        r.delete(k)

    all_bundles = defaultdict(list)
    for ifc in juniper.list_interfaces(netconf):
        # guard against a missing or None 'bundle' value, which would
        # otherwise make the loop raise TypeError
        for bundle in ifc.get('bundle') or []:
            if bundle:
                all_bundles[bundle].append(ifc['name'])

        r.set(
            'netconf-interfaces:%s:%s' % (hostname, ifc['name']),
            json.dumps(ifc))

    for bundle, members in all_bundles.items():
        r.set(
            'netconf-interface-bundles:%s:%s' % (hostname, bundle),
            json.dumps(members))


@app.task(base=InventoryTask, bind=True)
def reload_router_config(self, hostname):
    """
    Refresh all cached data derived from a router's netconf configuration.

    Steps:
      1. note the change timestamp of the currently cached netconf data,
         if there is any,
      2. reload the netconf data from the router; if the change timestamp
         is unchanged, stop and report 'OK (no change)',
      3. otherwise refresh the cached ix public peers, vpn rr peers,
         interface address lookups and interface/bundle lists,
      4. reload the snmp interface indexes using the community string
         found in the new config,
      5. clear the classifier response cache.

    :param hostname: router hostname
    :return: dict with 'task', 'hostname' and 'message' keys
    :raises InventoryTaskError: if no netconf data is cached after the
        reload, it has no change timestamp, or no snmp community string
        can be extracted from it
    """
    logger = logging.getLogger(__name__)
    logger.debug('>>> reload_router_config')

    def _report_progress(message):
        # publish an intermediate STARTED state carrying a progress message
        self.update_state(
            state=states.STARTED,
            meta={
                'task': 'reload_router_config',
                'hostname': hostname,
                'message': message
            })

    _report_progress('loading router netconf data')

    # get the timestamp for the current netconf data
    current_netconf_timestamp = None
    try:
        netconf_doc = load_netconf_data(hostname)
        current_netconf_timestamp \
            = juniper.netconf_changed_timestamp(netconf_doc)
        logger.debug(
            'current netconf timestamp: %r' % current_netconf_timestamp)
    except InventoryTaskError:
        pass  # ok at this point if not found

    # load new netconf data.  throw=True: without it Task.apply records a
    # failure in the returned result instead of raising (unless
    # task_eager_propagates is set), and the stale cached data would then
    # be reported below as 'OK (no change)'
    netconf_refresh_config.apply(args=[hostname], throw=True)

    netconf_doc = load_netconf_data(hostname)

    # return if new timestamp is the same as the original timestamp
    new_netconf_timestamp = juniper.netconf_changed_timestamp(netconf_doc)
    if not new_netconf_timestamp:
        # explicit check rather than an assert (stripped under -O): a
        # missing timestamp must not be mistaken for 'no change'
        raise InventoryTaskError(
            'no timestamp available for new netconf data')
    if new_netconf_timestamp == current_netconf_timestamp:
        logger.debug('no netconf change timestamp change, aborting')
        logger.debug('<<< reload_router_config')
        return {
            'task': 'reload_router_config',
            'hostname': hostname,
            'message': 'OK (no change)'
        }

    # refresh peering data derived from the new config
    _report_progress('refreshing peers & clearing cache')
    refresh_ix_public_peers(hostname, netconf_doc)
    refresh_vpn_rr_peers(hostname, netconf_doc)
    refresh_interface_address_lookups(hostname, netconf_doc)
    refresh_juniper_interface_list(hostname, netconf_doc)

    # load snmp indexes
    community = juniper.snmp_community_string(netconf_doc)
    if not community:
        raise InventoryTaskError(
            'error extracting community string for %r' % hostname)

    _report_progress('refreshing snmp interface indexes')
    # throw=True for the same reason as the netconf refresh above
    snmp_refresh_interfaces.apply(args=[hostname, community], throw=True)

    # note: this clears the whole classifier cache, not only the entries
    # that reference this router
    clear_cached_classifier_responses(None)

    logger.debug('<<< reload_router_config')

    return {
        'task': 'reload_router_config',
        'hostname': hostname,
        'message': 'OK'
    }


def _derive_router_hostnames(config):
    """
    Return the hostnames present in both the junosspace and opsdb caches.

    Hostnames are taken from the ``junosspace:<hostname>`` keys and from
    the ``opsdb:interface_services:<hostname>:...`` keys; only names found
    in both sets are returned.

    :param config: config structure as defined in config.py
    :return: set of hostname strings
    """
    logger = logging.getLogger(__name__)
    redis_client = get_redis(config)

    junosspace_equipment = set()
    for raw_key in redis_client.keys('junosspace:*'):
        match = re.match('^junosspace:(.*)$', raw_key.decode('utf-8'))
        assert match
        junosspace_equipment.add(match.group(1))

    opsdb_equipment = set()
    for raw_key in redis_client.keys('opsdb:interface_services:*'):
        key_text = raw_key.decode('utf-8')
        match = re.match('opsdb:interface_services:([^:]+):.*$', key_text)
        if not match:
            logger.info("Unable to derive router name from %s" % key_text)
            continue
        opsdb_equipment.add(match.group(1))

    return junosspace_equipment & opsdb_equipment


def launch_refresh_cache_all(config):
    """
    Refresh all cached inventory data by queueing the worker tasks.

    utility function intended to be called outside of the worker process:
    it blocks until every queued subtask has finished.

    The first batch (junosspace device list and opsdb derived caches)
    must complete before the second batch is queued: the router hostnames
    used by the second batch come from _derive_router_hostnames, which
    reads the ``junosspace:*`` and ``opsdb:interface_services:*`` keys
    written by the first batch.

    :param config: config structure as defined in config.py
    :return: list of task ids of the second batch of subtasks
    :raises: whatever exception ``AsyncResult.get`` re-raises for the
        first subtask (in queue order) that failed
    """
    logger = logging.getLogger(__name__)

    # first batch of subtasks: refresh cached junosspace & opsdb data
    subtasks = [
        update_junosspace_device_list.apply_async(),
        update_interfaces_to_services.apply_async(),
        update_geant_lambdas.apply_async(),
        update_circuit_hierarchy.apply_async()
    ]
    for subtask in subtasks:
        subtask.get()

    # second batch of subtasks:
    #   opsdb equipment locations
    #   juniper netconf & snmp data
    subtasks = [update_equipment_locations.apply_async()]
    for hostname in _derive_router_hostnames(config):
        logger.debug(
            'queueing router refresh jobs for %r' % hostname)
        subtasks.append(reload_router_config.apply_async(args=[hostname]))

    for subtask in subtasks:
        subtask.get()

    return [subtask.id for subtask in subtasks]


def check_task_status(task_id):
    """
    Summarize the state of a celery task as a json-friendly dict.

    :param task_id: celery task id
    :return: dict with keys 'id', 'status', 'exception', 'ready' and
        'success'.  When the task result is truthy it is added under
        'result'; an exception result is represented as
        ``{'error type': <class name>, 'message': <str(exception)>}``
    """
    r = AsyncResult(task_id, app=app)
    # AsyncResult.status / .result may query the result backend on every
    # access until the task is ready, so read each only once; otherwise
    # the flags below could be derived from different states
    status = r.status
    summary = {
        'id': r.id,
        'status': status,
        'exception': status in states.EXCEPTION_STATES,
        'ready': status in states.READY_STATES,
        'success': status == states.SUCCESS,
    }
    task_result = r.result
    if task_result:
        # TODO: only discovered this case by testing, is this the only one?
        #       ... otherwise need to pre-test json serialization
        if isinstance(task_result, Exception):
            summary['result'] = {
                'error type': type(task_result).__name__,
                'message': str(task_result)
            }
        else:
            summary['result'] = task_result
    return summary