import json
import logging
import os
import time

from celery import Task, states
from celery.result import AsyncResult
from collections import defaultdict
from lxml import etree
import jsonschema

from inventory_provider.tasks.app import app
from inventory_provider.tasks.common \
    import get_next_redis, latch_db, get_latch, set_latch, update_latch_status
from inventory_provider.tasks import data
from inventory_provider import config
from inventory_provider import environment
from inventory_provider.db import db, opsdb
from inventory_provider import snmp
from inventory_provider import juniper

FINALIZER_POLLING_FREQUENCY_S = 2.5
FINALIZER_TIMEOUT_S = 300

# TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks)  # noqa: E501

environment.setup_logging()

logger = logging.getLogger(__name__)


def log_task_entry_and_exit(f):
    # cf. https://stackoverflow.com/a/47663642
    def _w(*args, **kwargs):
        logger.debug(f'>>> {f.__name__}{args}')
        try:
            return f(*args, **kwargs)
        finally:
            logger.debug(f'<<< {f.__name__}{args}')
    return _w


class InventoryTaskError(Exception):
    pass


class InventoryTask(Task):

    config = None

    def __init__(self):
        if InventoryTask.config:
            return
        assert os.path.isfile(
            os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']), (
                'config file %r not found' %
                os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'])
        with open(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) as f:
            logging.info(
                "Initializing worker with config from: %r" %
                os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'])
            InventoryTask.config = config.load(f)
        logging.debug("loaded config: %r" % InventoryTask.config)

    def update_state(self, **kwargs):
        logger.debug(json.dumps(
            {'state': kwargs['state'], 'meta': str(kwargs['meta'])}
        ))
        super().update_state(**kwargs)

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        logger.exception(exc)
        super().on_failure(exc, task_id, args, kwargs, einfo)


@app.task(base=InventoryTask, bind=True, name='snmp_refresh_interfaces')
@log_task_entry_and_exit
def snmp_refresh_interfaces(self, hostname, community):
    value = list(snmp.get_router_snmp_indexes(hostname, community))
    r = get_next_redis(InventoryTask.config)
    r.set('snmp-interfaces:' + hostname, json.dumps(value))


@app.task(base=InventoryTask, bind=True, name='netconf_refresh_config')
@log_task_entry_and_exit
def netconf_refresh_config(self, hostname):
    netconf_doc = juniper.load_config(hostname, InventoryTask.config["ssh"])
    netconf_str = etree.tostring(netconf_doc, encoding='unicode')
    r = get_next_redis(InventoryTask.config)
    r.set('netconf:' + hostname, netconf_str)


@app.task(base=InventoryTask, bind=True, name='update_interfaces_to_services')
@log_task_entry_and_exit
def update_interfaces_to_services(self):
    interface_services = defaultdict(list)
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for service in opsdb.get_circuits(cx):
            equipment_interface = '%s:%s' % (
                service['equipment'], service['interface_name'])
            interface_services[equipment_interface].append(service)

    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for key in r.scan_iter('opsdb:interface_services:*'):
        rp.delete(key)
    rp.execute()
    rp = r.pipeline()
    for equipment_interface, services in interface_services.items():
        rp.set(
            f'opsdb:interface_services:{equipment_interface}',
            json.dumps(services))
    rp.execute()

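
# Illustrative sketch (not called anywhere in this module): several of
# the opsdb tasks above and below follow the same "replace everything
# under a key prefix" pattern — delete all keys matching a prefix in one
# pipeline, then repopulate them in a second.  The helper name and
# signature here are hypothetical, shown only to make the pattern explicit.
def _example_replace_key_prefix(r, prefix, items):
    """Delete all keys under `prefix` and re-set them from `items`.

    :param r: a redis client (e.g. from get_next_redis)
    :param prefix: key prefix, without the trailing ':'
    :param items: dict mapping key suffix -> json-serializable value
    """
    rp = r.pipeline()
    for key in r.scan_iter(prefix + ':*'):
        rp.delete(key)
    rp.execute()
    rp = r.pipeline()
    for suffix, value in items.items():
        rp.set(f'{prefix}:{suffix}', json.dumps(value))
    rp.execute()
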

@app.task(base=InventoryTask, bind=True, name='import_unmanaged_interfaces')
@log_task_entry_and_exit
def import_unmanaged_interfaces(self):

    def _convert(d):
        # the config file keys are more readable than
        # the keys used in redis
        return {
            'name': d['address'],
            'interface address': d['network'],
            'interface name': d['interface'].lower(),
            'router': d['router'].lower()
        }

    interfaces = [
        _convert(ifc)
        for ifc in InventoryTask.config.get('unmanaged-interfaces', [])
    ]

    if interfaces:
        r = get_next_redis(InventoryTask.config)
        rp = r.pipeline()
        for ifc in interfaces:
            rp.set(
                f'reverse_interface_addresses:{ifc["name"]}',
                json.dumps(ifc))
            rp.set(
                f'subnets:{ifc["interface address"]}',
                json.dumps([ifc]))
        rp.execute()


@app.task(base=InventoryTask, bind=True, name='update_access_services')
@log_task_entry_and_exit
def update_access_services(self):
    access_services = {}
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for service in opsdb.get_access_services(cx):
            if service['name'] in access_services:
                logger.warning(
                    'got multiple access services '
                    f'with name "{service["name"]}"')
            access_services[service['name']] = service

    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for key in r.scan_iter('opsdb:access_services:*'):
        rp.delete(key)
    rp.execute()
    rp = r.pipeline()
    for name, service in access_services.items():
        rp.set(
            f'opsdb:access_services:{name}',
            json.dumps(service))
    rp.execute()


@app.task(base=InventoryTask, bind=True, name='update_lg_routers')
@log_task_entry_and_exit
def update_lg_routers(self):
    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for k in r.scan_iter('opsdb:lg:*'):
        rp.delete(k)
    rp.execute()
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        rp = r.pipeline()
        for router in opsdb.lookup_lg_routers(cx):
            rp.set(f'opsdb:lg:{router["equipment name"]}', json.dumps(router))
        rp.execute()


@app.task(base=InventoryTask, bind=True, name='update_equipment_locations')
@log_task_entry_and_exit
def update_equipment_locations(self):
    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for k in r.scan_iter('opsdb:location:*'):
        rp.delete(k)
    rp.execute()
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        rp = r.pipeline()
        for h in data.derive_router_hostnames(InventoryTask.config):
            # lookup_pop_info returns a list of locations
            # (there can sometimes be more than one match)
            locations = list(opsdb.lookup_pop_info(cx, h))
            rp.set('opsdb:location:%s' % h, json.dumps(locations))
        rp.execute()


@app.task(base=InventoryTask, bind=True, name='update_circuit_hierarchy')
@log_task_entry_and_exit
def update_circuit_hierarchy(self):
    # TODO: integers are not JSON keys
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        child_to_parents = defaultdict(list)
        parent_to_children = defaultdict(list)
        for relation in opsdb.get_circuit_hierarchy(cx):
            parent_id = relation["parent_circuit_id"]
            child_id = relation["child_circuit_id"]
            parent_to_children[parent_id].append(relation)
            child_to_parents[child_id].append(relation)

        r = get_next_redis(InventoryTask.config)
        rp = r.pipeline()
        for key in r.scan_iter('opsdb:services:parents:*'):
            rp.delete(key)
        for key in r.scan_iter('opsdb:services:children:*'):
            rp.delete(key)
        rp.execute()

        rp = r.pipeline()
        for cid, parents in parent_to_children.items():
            rp.set('opsdb:services:parents:%d' % cid, json.dumps(parents))
        for cid, children in child_to_parents.items():
            rp.set('opsdb:services:children:%d' % cid, json.dumps(children))
        rp.execute()

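
# Illustrative sketch (hypothetical, not part of the worker): a client
# could walk upward through the circuit hierarchy stored above.  This
# assumes 'opsdb:services:children:<id>' holds the relations in which
# <id> appears as the child, each carrying a 'parent_circuit_id' field
# as used in update_circuit_hierarchy.
def _example_ancestor_circuit_ids(r, circuit_id):
    """Return the set of ancestor circuit ids of `circuit_id`."""
    queue = [circuit_id]
    seen = set()
    while queue:
        cid = queue.pop(0)
        if cid in seen:
            continue
        seen.add(cid)
        value = r.get('opsdb:services:children:%d' % cid)
        if not value:
            continue
        for relation in json.loads(value):
            queue.append(relation['parent_circuit_id'])
    seen.discard(circuit_id)
    return seen
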

@app.task(base=InventoryTask, bind=True, name='update_geant_lambdas')
@log_task_entry_and_exit
def update_geant_lambdas(self):
    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for key in r.scan_iter('opsdb:geant_lambdas:*'):
        rp.delete(key)
    rp.execute()
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        rp = r.pipeline()
        for ld in opsdb.get_geant_lambdas(cx):
            rp.set(
                'opsdb:geant_lambdas:%s' % ld['name'].lower(),
                json.dumps(ld))
        rp.execute()


@app.task(base=InventoryTask, bind=True,
          name='update_neteng_managed_device_list')
@log_task_entry_and_exit
def update_neteng_managed_device_list(self):
    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'update_neteng_managed_device_list',
            'message': 'querying netdash for managed routers'
        })

    routers = list(juniper.load_routers_from_netdash(
        InventoryTask.config['managed-routers']))

    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'update_neteng_managed_device_list',
            'message': f'found {len(routers)} routers, saving details'
        })

    r = get_next_redis(InventoryTask.config)
    r.set('netdash', json.dumps(routers).encode('utf-8'))

    return {
        'task': 'update_neteng_managed_device_list',
        'message': 'saved %d managed routers' % len(routers)
    }


def load_netconf_data(hostname):
    """
    this method should only be called from a task

    :param hostname:
    :return:
    """
    r = get_next_redis(InventoryTask.config)
    netconf = r.get('netconf:' + hostname)
    if not netconf:
        raise InventoryTaskError('no netconf data found for %r' % hostname)
    return etree.fromstring(netconf.decode('utf-8'))


def clear_cached_classifier_responses(hostname=None):
    if hostname:
        logger.debug(
            'removing cached classifier responses for %r' % hostname)
    else:
        logger.debug('removing all cached classifier responses')

    r = get_next_redis(InventoryTask.config)

    def _hostname_keys():
        for k in r.keys('classifier-cache:juniper:%s:*' % hostname):
            yield k

        # TODO: very inefficient ... but logically simplest at this point
        for k in r.keys('classifier-cache:peer:*'):
            value = r.get(k.decode('utf-8'))
            if not value:
                # deleted in another thread
                continue
            value = json.loads(value.decode('utf-8'))
            interfaces = value.get('interfaces', [])
            if hostname in [i['interface']['router'] for i in interfaces]:
                yield k

    def _all_keys():
        return r.keys('classifier-cache:*')

    keys_to_delete = _hostname_keys() if hostname else _all_keys()
    rp = r.pipeline()
    for k in keys_to_delete:
        rp.delete(k)
    rp.execute()

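
# Illustrative sketch, prompted by the TODO above: the blocking KEYS
# command could be replaced with the incremental scan_iter that the
# other tasks in this module already use.  This is an untested
# alternative with a hypothetical name, not the current implementation.
def _example_peer_cache_keys(r, hostname):
    """Yield 'classifier-cache:peer:*' keys that reference `hostname`."""
    for k in r.scan_iter('classifier-cache:peer:*'):
        value = r.get(k)
        if not value:
            # deleted in another thread
            continue
        interfaces = json.loads(value.decode('utf-8')).get('interfaces', [])
        if hostname in [i['interface']['router'] for i in interfaces]:
            yield k
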

def _refresh_peers(hostname, key_base, peers):
    logger.debug(
        'removing cached %s for %r' % (key_base, hostname))
    r = get_next_redis(InventoryTask.config)
    # WARNING (optimization): this is an expensive query if
    # the redis connection is slow, and we currently only
    # call this method during a full refresh
    # for k in r.scan_iter(key_base + ':*'):
    #     # potential race condition: another proc could have
    #     # deleted this element between the time we read the
    #     # keys and the next statement ... check for None below
    #     value = r.get(k.decode('utf-8'))
    #     if value:
    #         value = json.loads(value.decode('utf-8'))
    #         if value['router'] == hostname:
    #             r.delete(k)

    rp = r.pipeline()
    for peer in peers:
        peer['router'] = hostname
        rp.set(
            '%s:%s' % (key_base, peer['name']),
            json.dumps(peer))
    rp.execute()


def refresh_ix_public_peers(hostname, netconf):
    _refresh_peers(
        hostname,
        'ix_public_peer',
        juniper.ix_public_peers(netconf))


def refresh_vpn_rr_peers(hostname, netconf):
    _refresh_peers(
        hostname,
        'vpn_rr_peer',
        juniper.vpn_rr_peers(netconf))


def refresh_interface_address_lookups(hostname, netconf):
    _refresh_peers(
        hostname,
        'reverse_interface_addresses',
        juniper.interface_addresses(netconf))


def refresh_juniper_interface_list(hostname, netconf):
    logger.debug(
        'removing cached netconf-interfaces for %r' % hostname)

    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for k in r.scan_iter('netconf-interfaces:%s:*' % hostname):
        rp.delete(k)
    for k in r.keys('netconf-interface-bundles:%s:*' % hostname):
        rp.delete(k)
    rp.execute()

    all_bundles = defaultdict(list)
    rp = r.pipeline()
    for ifc in juniper.list_interfaces(netconf):
        # default to an empty list so a missing 'bundle' key
        # doesn't break the iteration below
        bundles = ifc.get('bundle', [])
        for bundle in bundles:
            if bundle:
                all_bundles[bundle].append(ifc['name'])
        rp.set(
            'netconf-interfaces:%s:%s' % (hostname, ifc['name']),
            json.dumps(ifc))
    for k, v in all_bundles.items():
        rp.set(
            'netconf-interface-bundles:%s:%s' % (hostname, k),
            json.dumps(v))
    rp.execute()

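
# Illustrative sketch (hypothetical helper, not used elsewhere in this
# module): reading back the bundle membership mapping written by
# refresh_juniper_interface_list.
def _example_bundle_members(r, hostname, bundle):
    """Return the member interface names of `bundle` on `hostname`."""
    value = r.get('netconf-interface-bundles:%s:%s' % (hostname, bundle))
    return json.loads(value) if value else []
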

@app.task(base=InventoryTask, bind=True, name='reload_router_config')
@log_task_entry_and_exit
def reload_router_config(self, hostname):
    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'reload_router_config',
            'hostname': hostname,
            'message': 'loading router netconf data'
        })

    # get the timestamp for the current netconf data
    current_netconf_timestamp = None
    try:
        netconf_doc = load_netconf_data(hostname)
        current_netconf_timestamp \
            = juniper.netconf_changed_timestamp(netconf_doc)
        logger.debug(
            'current netconf timestamp: %r' % current_netconf_timestamp)
    except InventoryTaskError:
        pass  # ok at this point if not found

    # load new netconf data
    netconf_refresh_config.apply(args=[hostname])

    netconf_doc = load_netconf_data(hostname)

    # return if the new timestamp is the same as the original timestamp
    new_netconf_timestamp = juniper.netconf_changed_timestamp(netconf_doc)
    assert new_netconf_timestamp, \
        'no timestamp available for new netconf data'
    if new_netconf_timestamp == current_netconf_timestamp:
        logger.debug('no netconf timestamp change, aborting')
        return {
            'task': 'reload_router_config',
            'hostname': hostname,
            'message': 'OK (no change)'
        }

    # clear cached classifier responses for this router, and
    # refresh peering data
    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'reload_router_config',
            'hostname': hostname,
            'message': 'refreshing peers & clearing cache'
        })
    refresh_ix_public_peers(hostname, netconf_doc)
    refresh_vpn_rr_peers(hostname, netconf_doc)
    refresh_interface_address_lookups(hostname, netconf_doc)
    refresh_juniper_interface_list(hostname, netconf_doc)
    # clear_cached_classifier_responses(hostname)

    # load snmp indexes
    community = juniper.snmp_community_string(netconf_doc)
    if not community:
        raise InventoryTaskError(
            'error extracting community string for %r' % hostname)
    else:
        self.update_state(
            state=states.STARTED,
            meta={
                'task': 'reload_router_config',
                'hostname': hostname,
                'message': 'refreshing snmp interface indexes'
            })
        snmp_refresh_interfaces.apply(args=[hostname, community])

    clear_cached_classifier_responses(None)

    return {
        'task': 'reload_router_config',
        'hostname': hostname,
        'message': 'OK'
    }


def _erase_next_db(config):
    """
    flush the next db, but first save the latch and restore it afterwards

    TODO: handle the no-latch scenario nicely
    :param config:
    :return:
    """
    r = get_next_redis(config)
    saved_latch = get_latch(r)
    r.flushdb()
    if saved_latch:
        set_latch(
            config,
            new_current=saved_latch['current'],
            new_next=saved_latch['next'])


def launch_refresh_cache_all(config):
    """
    utility function intended to be called outside of the worker process

    :param config: config structure as defined in config.py
    :return:
    """
    _erase_next_db(config)
    update_latch_status(config, pending=True)

    # first batch of subtasks: refresh cached opsdb data
    subtasks = [
        update_neteng_managed_device_list.apply_async(),
        update_interfaces_to_services.apply_async(),
        update_geant_lambdas.apply_async(),
        update_circuit_hierarchy.apply_async()
    ]
    [x.get() for x in subtasks]

    # second batch of subtasks:
    #   alarms db status cache
    #   juniper netconf & snmp data
    subtasks = [
        update_equipment_locations.apply_async(),
        update_lg_routers.apply_async(),
        update_access_services.apply_async(),
        import_unmanaged_interfaces.apply_async()
    ]
    for hostname in data.derive_router_hostnames(config):
        logger.debug('queueing router refresh jobs for %r' % hostname)
        subtasks.append(reload_router_config.apply_async(args=[hostname]))

    pending_task_ids = [x.id for x in subtasks]

    t = refresh_finalizer.apply_async(args=[json.dumps(pending_task_ids)])
    pending_task_ids.append(t.id)
    return pending_task_ids


def _wait_for_tasks(task_ids, update_callback=lambda s: None):

    all_successful = True

    start_time = time.time()
    while task_ids and time.time() - start_time < FINALIZER_TIMEOUT_S:
        update_callback('waiting for tasks to complete: %r' % task_ids)
        time.sleep(FINALIZER_POLLING_FREQUENCY_S)

        def _is_error(id):
            status = check_task_status(id)
            return status['ready'] and not status['success']

        if any([_is_error(id) for id in task_ids]):
            all_successful = False

        task_ids = [
            id for id in task_ids
            if not check_task_status(id)['ready']
        ]

    if task_ids:
        raise InventoryTaskError(
            'timeout waiting for pending tasks to complete')
    if not all_successful:
        raise InventoryTaskError(
            'some tasks finished with an error')

    update_callback('pending tasks completed in {} seconds'.format(
        time.time() - start_time))


@app.task(base=InventoryTask, bind=True, name='refresh_finalizer')
@log_task_entry_and_exit
def refresh_finalizer(self, pending_task_ids_json):

    input_schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "array",
        "items": {"type": "string"}
    }

    def _update(s):
        logger.debug(s)
        self.update_state(
            state=states.STARTED,
            meta={
                'task': 'refresh_finalizer',
                'message': s
            })

    try:
        task_ids = json.loads(pending_task_ids_json)
        logger.debug('task_ids: %r' % task_ids)
        jsonschema.validate(task_ids, input_schema)

        _wait_for_tasks(task_ids, update_callback=_update)
        _build_subnet_db(update_callback=_update)
        _build_service_category_interface_list(update_callback=_update)

    except (jsonschema.ValidationError,
            json.JSONDecodeError,
            InventoryTaskError) as e:
        update_latch_status(InventoryTask.config, failure=True)
        raise e

    latch_db(InventoryTask.config)
    _update('latched current/next dbs')

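
# Illustrative sketch (hypothetical, not part of the module): how an
# external process might combine launch_refresh_cache_all with
# check_task_status to run a full refresh and block until every task,
# including the finalizer, is ready.  The poll interval is arbitrary.
def _example_run_full_refresh(config, poll_interval_s=5):
    """Launch a full cache refresh and wait for all tasks to finish."""
    task_ids = launch_refresh_cache_all(config)
    while task_ids:
        time.sleep(poll_interval_s)
        task_ids = [
            id for id in task_ids
            if not check_task_status(id)['ready']
        ]
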

def _build_service_category_interface_list(update_callback=lambda s: None):

    def _classify(ifc):
        if ifc['description'].startswith('SRV_MDVPN'):
            return 'mdvpn'
        if 'LHCONE' in ifc['description']:
            return 'lhcone'
        return None

    update_callback('loading all known interfaces')
    interfaces = data.build_service_interface_user_list(InventoryTask.config)
    interfaces = list(interfaces)
    update_callback(f'loaded {len(interfaces)} interfaces, '
                    'saving by service category')
    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for ifc in interfaces:
        service_type = _classify(ifc)
        if not service_type:
            continue
        rp.set(
            f'interface-services:{service_type}'
            f':{ifc["router"]}:{ifc["interface"]}',
            json.dumps(ifc))
    rp.execute()


def _build_subnet_db(update_callback=lambda s: None):

    r = get_next_redis(InventoryTask.config)

    update_callback('loading all network addresses')
    subnets = {}
    for k in r.scan_iter('reverse_interface_addresses:*'):
        info = r.get(k.decode('utf-8')).decode('utf-8')
        info = json.loads(info)
        entry = subnets.setdefault(info['interface address'], [])
        entry.append(info)

    update_callback('saving {} subnets'.format(len(subnets)))
    rp = r.pipeline()
    for k, v in subnets.items():
        rp.set('subnets:' + k, json.dumps(v))
    rp.execute()


def check_task_status(task_id):
    r = AsyncResult(task_id, app=app)
    result = {
        'id': r.id,
        'status': r.status,
        'exception': r.status in states.EXCEPTION_STATES,
        'ready': r.status in states.READY_STATES,
        'success': r.status == states.SUCCESS,
    }
    if r.result:
        # TODO: only discovered this case by testing, is this the only one?
        #       ... otherwise need to pre-test json serialization
        if isinstance(r.result, Exception):
            result['result'] = {
                'error type': type(r.result).__name__,
                'message': str(r.result)
            }
        else:
            result['result'] = r.result
    return result

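
# Illustrative sketch (hypothetical helper): how a client might resolve
# the interface records sharing a subnet, using the 'subnets:<network>'
# keys written by _build_subnet_db and import_unmanaged_interfaces.
def _example_interfaces_on_subnet(r, network):
    """Return the interface records stored for `network`
    (e.g. a CIDR string), or an empty list if the subnet is unknown."""
    value = r.get('subnets:' + network)
    return json.loads(value) if value else []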