import concurrent.futures
import functools
import ipaddress
import itertools
import json
import logging
import os
import re

import ncclient.operations
import ncclient.transport.errors
from celery import Task, states, chord
from celery.result import AsyncResult
from celery import signals

from collections import defaultdict, namedtuple

from kombu.exceptions import KombuError
from lxml import etree
from ncclient.transport import TransportError

from inventory_provider.db import ims_data
from inventory_provider.db.ims import IMS
from inventory_provider.routes.poller import load_error_report_interfaces, load_interfaces_to_poll
from inventory_provider.tasks.app import app
from inventory_provider.tasks.common \
    import get_next_redis, get_current_redis, \
    latch_db, get_latch, set_latch, update_latch_status, \
    ims_sorted_service_type_key, set_single_latch
from inventory_provider import config, nokia, gap
from inventory_provider import environment
from inventory_provider import snmp
from inventory_provider import juniper
from redis import RedisError
from requests import HTTPError

# Finalizer timing knobs, in seconds (per the _S suffix). They are not
# referenced in this part of the file; presumably they control how often and
# for how long the refresh finalizer polls -- TODO confirm at the use site.
FINALIZER_POLLING_FREQUENCY_S = 2.5
FINALIZER_TIMEOUT_S = 300

# TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks)  # noqa: E501

logger = logging.getLogger(__name__)
# Decorator bound to this module's logger; presumably logs entry to and exit
# from the wrapped function (see environment.log_entry_and_exit).
log_task_entry_and_exit = functools.partial(
    environment.log_entry_and_exit, logger=logger)


@signals.after_setup_logger.connect
def setup_logging(conf=None, **kwargs):
    """
    Celery ``after_setup_logger`` signal receiver.

    Re-applies this application's logging configuration via
    ``environment.setup_logging()`` after Celery has configured its own
    logger. The signal's arguments are accepted and ignored.
    """
    environment.setup_logging()


class InventoryTaskError(Exception):
    """Raised by the refresh helpers when a step fails and no fallback data exists."""


class InventoryTask(Task):
    """
    Celery base class for the inventory tasks in this module.

    The first instantiation loads the application config from the file named
    by the INVENTORY_PROVIDER_CONFIG_FILENAME environment variable and stores
    it on the class. Later instances, and the module-level helpers that read
    ``InventoryTask.config``, reuse it.
    """

    # parsed application config; set once by the first __init__ call
    config = None

    def __init__(self):
        """
        Load the shared config if it has not been loaded yet.

        :raises KeyError: if INVENTORY_PROVIDER_CONFIG_FILENAME is not set
        :raises FileNotFoundError: if the named config file does not exist
        """
        self.args = []

        if InventoryTask.config:
            return

        filename = os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']
        # explicit check instead of ``assert``, which is stripped under -O
        if not os.path.isfile(filename):
            raise FileNotFoundError(f'config file {filename!r} not found')

        with open(filename) as f:
            logger.info('Initializing worker with config from: %r', filename)
            # ``config`` here is the inventory_provider.config module, not the
            # class attribute (class scope is not visible inside methods)
            InventoryTask.config = config.load(f)
        logger.debug('loaded config: %r', InventoryTask.config)

    def log_info(self, message):
        """Log ``message`` at DEBUG level and emit it as a 'task-info' event."""
        logger.debug(message)
        self.send_event('task-info', message=message)

    def log_warning(self, message):
        """Log ``message`` as a warning and emit it as a 'task-warning' event."""
        logger.warning(message)
        self.send_event('task-warning', message=message)

    def log_error(self, message):
        """Log ``message`` as an error and emit it as a 'task-error' event."""
        logger.error(message)
        self.send_event('task-error', message=message)


def _unmanaged_interfaces():
    """
    Yield the 'unmanaged-interfaces' config entries as subnet-db records.

    Each config entry (keys 'address', 'network', 'interface', 'router') is
    re-keyed to the field names used in redis. Interface and router names are
    lower-cased.
    """
    for entry in InventoryTask.config.get('unmanaged-interfaces', []):
        # config keys favour readability; these are the keys stored in redis
        yield {
            'name': entry['address'],
            'interface address': entry['network'],
            'interface name': entry['interface'].lower(),
            'router': entry['router'].lower(),
        }


def _nokia_community_strings(config_):
    return {
        'inventory-provider': config_['nokia-community-inventory-provider'],
        'dashboard': config_['nokia-community-dashboard'],
        'brian': config_['nokia-community-brian']
    }


def _general_community_strings(community):
    return {
        'inventory-provider': community,
        'dashboard': community,
        'brian': community
    }


@log_task_entry_and_exit
def refresh_juniper_bgp_peers(hostname, netconf):
    """
    Store every BGP peering found in a Juniper netconf doc.

    The peerings are saved as a JSON list in the next redis db under
    'juniper-peerings:hosts:<hostname>'.
    """
    peerings = list(juniper.all_bgp_peers(netconf))
    redis_key = f'juniper-peerings:hosts:{hostname}'
    get_next_redis(InventoryTask.config).set(redis_key, json.dumps(peerings))


@log_task_entry_and_exit
def refresh_juniper_interface_list(hostname, netconf, interface_info, lab=False):
    """
    Replace the cached interface data for one Juniper router in the next db.

    Existing keys for the router are deleted first. The function then writes:

    - 'netconf-interfaces-hosts:<hostname>': JSON list from
      ``juniper.interface_addresses(netconf)``
    - 'netconf-interfaces:<hostname>:<ifc name>': one JSON object per
      interface, with a 'speed' field added ('' when unknown)
    - 'netconf-interface-bundles:<hostname>:<bundle>': JSON list of the names
      of the interfaces in each bundle

    All keys get a 'lab:' prefix when ``lab`` is true.

    :param hostname: router hostname, used in the redis keys
    :param netconf: parsed netconf config doc
    :param interface_info: parsed interface-info doc used to look up speeds,
        or a falsy value to leave all speeds blank
    :param lab: store under the 'lab:' key namespace
    """
    logger.debug('removing cached netconf-interfaces for %r', hostname)

    r = get_next_redis(InventoryTask.config)

    interfaces_keybase = f'netconf-interfaces:{hostname}'
    bundles_keybase = f'netconf-interface-bundles:{hostname}'
    interfaces_all_key = f'netconf-interfaces-hosts:{hostname}'
    if lab:
        interfaces_keybase = f'lab:{interfaces_keybase}'
        interfaces_all_key = f'lab:{interfaces_all_key}'
        bundles_keybase = f'lab:{bundles_keybase}'

    rp = r.pipeline()
    rp.delete(interfaces_all_key)
    # scan with bigger batches, to mitigate network latency effects
    for k in r.scan_iter(f'{interfaces_keybase}:*', count=1000):
        rp.delete(k)
    for k in r.scan_iter(f'{bundles_keybase}:*', count=1000):
        rp.delete(k)
    rp.execute()

    all_bundles = defaultdict(list)

    rp = r.pipeline()

    rp.set(
        interfaces_all_key,
        json.dumps(list(juniper.interface_addresses(netconf))))

    interface_speeds = {}
    if interface_info:
        interface_speeds = juniper.get_interface_speeds(interface_info)

    for ifc in juniper.list_interfaces(netconf):

        ifc['speed'] = interface_speeds.get(ifc['name'], '')

        # a missing or None 'bundle' means "not in any bundle"; iterating
        # None directly would raise TypeError
        for bundle in ifc.get('bundle') or []:
            if bundle:
                all_bundles[bundle].append(ifc['name'])
        rp.set(
            f'{interfaces_keybase}:{ifc["name"]}',
            json.dumps(ifc))

    for k, v in all_bundles.items():
        rp.set(
            f'{bundles_keybase}:{k}',
            json.dumps(v))

    rp.execute()


def _erase_next_db(config):
    """
    flush next db, but first save latch and then restore afterwards

    The flush and the re-write of the saved latch are queued in one
    MULTI/EXEC pipeline, so the latch is restored in the same transaction
    as the flush. ``set_latch`` is then called with the same values to keep
    the latch consistent across the dbs.

    If the next db has no latch, nothing is flushed at all.

    TODO: handle the no latch scenario nicely
    :param config: application config, used to locate the redis dbs
    :return: None
    """
    r = get_next_redis(config)
    saved_latch = get_latch(r)

    if saved_latch:
        # execute as transaction to ensure that latch is always available in
        # db that is being flushed
        rp = r.pipeline()
        rp.multi()
        rp.flushdb()
        set_single_latch(
            rp,
            saved_latch['this'],
            saved_latch['current'],
            saved_latch['next'],
            saved_latch.get('timestamp', 0),
            saved_latch.get('update-started', 0)
        )
        rp.execute()

        # ensure latch is consistent in all dbs
        set_latch(
            config,
            new_current=saved_latch['current'],
            new_next=saved_latch['next'],
            timestamp=saved_latch.get('timestamp', 0),
            update_timestamp=saved_latch.get('update-started', 0))


def populate_poller_cache(interface_services, r):
    """
    Rebuild the 'poller_cache:<host>' keys from a per-interface service map.

    :param interface_services: mapping whose values are lists of service
        dicts for one interface. Each dict has 'equipment', 'port', 'id',
        'name', 'service_type' and 'status' keys; the first entry's
        'equipment'/'port' identify the interface. Empty lists are skipped.
    :param r: redis client to write to

    Each 'poller_cache:<host>' value is a JSON object mapping port to a list
    of {'id', 'name', 'type', 'status'} dicts.
    """
    host_services = defaultdict(dict)
    for services in interface_services.values():
        if not services:
            continue
        host = services[0]['equipment']
        port = services[0]['port']
        host_services[host][port] = [{
            'id': s['id'],
            'name': s['name'],
            'type': s['service_type'],
            'status': s['status']
        } for s in services]

    # Delete the stale entries and write the new ones in a single pipeline.
    # With redis-py's default transactional pipeline this is applied
    # atomically, so readers never see an empty or half-written cache.
    rp = r.pipeline()
    for k in r.scan_iter('poller_cache:*', count=1000):
        rp.delete(k)
    for host, port_services in host_services.items():
        rp.set(f'poller_cache:{host}', json.dumps(port_services))
    rp.execute()


def _build_subnet_db(update_callback=lambda s: None):
    """
    Group every known interface address record by its 'interface address'.

    Reads the per-host address lists stored under 'netconf-interfaces-hosts:'
    and 'lab:netconf-interfaces-hosts:' in the next redis db, tags each
    record with its router hostname, and adds the unmanaged interfaces from
    config. Each group is then written as a JSON list to
    'subnets:<interface address>'.

    :param update_callback: receives progress messages
    """
    r = get_next_redis(InventoryTask.config)

    update_callback('loading all network addresses')
    subnets = defaultdict(list)
    for prefix in ('netconf-interfaces-hosts:', 'lab:netconf-interfaces-hosts:'):
        # scan with bigger batches, to mitigate network latency effects
        for raw_key in r.scan_iter(f"{prefix}*", count=1000):
            key = raw_key.decode('utf-8')
            router = key[len(prefix):]
            for record in json.loads(r.get(key).decode('utf-8')):
                record['router'] = router
                subnets[record['interface address']].append(record)

    for record in _unmanaged_interfaces():
        subnets[record['interface address']].append(record)

    update_callback(f'saving {len(subnets)} subnets')

    pipe = r.pipeline()
    for address, records in subnets.items():
        pipe.set(f'subnets:{address}', json.dumps(records))
    pipe.execute()


def _build_router_peering_db(update_callback=lambda s: None):
    """
    Build the pivoted 'router-peerings:*' indexes in the next redis db.

    Reads the per-host BGP peering lists saved under 'juniper-peerings:hosts:'
    and 'nokia-peerings:hosts:', tags each peering with its hostname, and
    writes:

    - 'router-peerings:all': every peering (for use with /msr/bgp)
    - 'router-peerings:remote:<address>': peerings per remote address
    - 'router-peerings:ix-groups:<keyword>': remote addresses of IX peerings,
      grouped by the first space-separated word of their description
    - 'router-peerings:peer-asn:<asn>': peerings per 'remote-asn'
    - 'router-peerings:logical-system:<name>': peerings per logical system
    - 'router-peerings:group:<name>': peerings per BGP group
    - 'router-peerings:routing-instance:<name>': peerings per 'instance'

    :param update_callback: receives progress messages
    """
    def _is_ix(peering_info):
        # an IX peering is in the 'IAS' instance, in a group whose name
        # starts with 'GEANT-IX', and carries the fields used to describe it
        if peering_info.get('instance', '') != 'IAS':
            return False
        if not peering_info.get('group', '').startswith('GEANT-IX'):
            return False

        expected_keys = ('description', 'local-asn', 'remote-asn')
        if any(peering_info.get(x, None) is None for x in expected_keys):
            logger.error('internal data error, looks like ix peering but '
                         f'some expected keys are missing: {peering_info}')
            return False
        return True

    r = get_next_redis(InventoryTask.config)

    update_callback('loading all juniper network peerings')
    peerings_per_address = {}
    ix_peerings = []
    peerings_per_asn = {}
    peerings_per_logical_system = {}
    peerings_per_group = {}
    peerings_per_routing_instance = {}
    all_peerings = []

    def _build_peering_details(key_prefix):
        # scan with bigger batches, to mitigate network latency effects
        for _k in r.scan_iter(f'{key_prefix}*', count=1000):
            key_name = _k.decode('utf-8')
            hostname = key_name[len(key_prefix):]
            host_peerings = r.get(key_name).decode('utf-8')
            host_peerings = json.loads(host_peerings)
            for _p in host_peerings:
                _p['hostname'] = hostname
                peerings_per_address.setdefault(_p['address'], []).append(_p)
                if _is_ix(_p):
                    ix_peerings.append(_p)
                asn = _p.get('remote-asn', None)
                if asn:
                    peerings_per_asn.setdefault(asn, []).append(_p)
                logical_system = _p.get('logical-system', None)
                if logical_system:
                    peerings_per_logical_system.setdefault(
                        logical_system, []).append(_p)
                group = _p.get('group', None)
                if group:
                    peerings_per_group.setdefault(group, []).append(_p)
                routing_instance = _p.get('instance', None)
                if routing_instance:
                    peerings_per_routing_instance.setdefault(
                        routing_instance, []).append(_p)
                all_peerings.append(_p)

    key_prefixes = ('juniper-peerings:hosts:', 'nokia-peerings:hosts:')
    for k in key_prefixes:
        _build_peering_details(k)

    # sort ix peerings by group
    ix_groups = {}
    for p in ix_peerings:
        description = p['description']
        keyword = description.split(' ')[0]  # regex needed??? (e.g. tabs???)
        ix_groups.setdefault(keyword, set()).add(p['address'])

    rp = r.pipeline()

    # for use with /msr/bgp
    rp.set('router-peerings:all', json.dumps(all_peerings))

    # create peering entries, keyed by remote addresses
    update_callback(f'saving {len(peerings_per_address)} remote peers')
    for k, v in peerings_per_address.items():
        rp.set(f'router-peerings:remote:{k}', json.dumps(v))

    # create pivoted ix group name lists
    update_callback(f'saving {len(ix_groups)} remote ix peering groups')
    for k, v in ix_groups.items():
        group_addresses = list(v)
        rp.set(f'router-peerings:ix-groups:{k}', json.dumps(group_addresses))

    # create pivoted asn peering lists
    update_callback(f'saving {len(peerings_per_asn)} asn peering lists')
    for k, v in peerings_per_asn.items():
        rp.set(f'router-peerings:peer-asn:{k}', json.dumps(v))

    # create pivoted logical-systems peering lists
    update_callback(
        f'saving {len(peerings_per_logical_system)}'
        ' logical-system peering lists')
    for k, v in peerings_per_logical_system.items():
        rp.set(f'router-peerings:logical-system:{k}', json.dumps(v))

    # create pivoted group peering lists
    update_callback(
        f'saving {len(peerings_per_group)} group peering lists')
    for k, v in peerings_per_group.items():
        rp.set(f'router-peerings:group:{k}', json.dumps(v))

    # create pivoted routing instance peering lists
    update_callback(
        f'saving {len(peerings_per_routing_instance)}'
        ' routing-instance peering lists')
    for k, v in peerings_per_routing_instance.items():
        rp.set(f'router-peerings:routing-instance:{k}', json.dumps(v))

    rp.execute()


def _build_snmp_peering_db(update_callback=lambda s: None):
    """
    Pivot the per-host SNMP peering lists into per-remote-address lists.

    Reads every 'snmp-peerings:hosts:*' key from the next redis db and tags
    each peering with the hostname, taken as the last ':'-separated component
    of the key. One 'snmp-peerings:remote:<remote>' JSON list is written per
    distinct 'remote' value.

    :param update_callback: receives progress messages
    """
    r = get_next_redis(InventoryTask.config)

    update_callback('loading all snmp network peerings')
    by_remote = defaultdict(list)

    key_prefix = 'snmp-peerings:hosts:'
    # scan with bigger batches, to mitigate network latency effects
    for raw_key in r.scan_iter(f'{key_prefix}*', count=1000):
        key = raw_key.decode('utf-8')
        host = key.rpartition(':')[2]
        for peering in json.loads(r.get(key).decode('utf-8')):
            peering['hostname'] = host
            by_remote[peering['remote']].append(peering)

    update_callback(f'saving {len(by_remote)} remote peers')

    pipe = r.pipeline()
    for remote, peerings in by_remote.items():
        pipe.set(f'snmp-peerings:remote:{remote}', json.dumps(peerings))
    pipe.execute()


def check_task_status(task_id, parent=None, forget=False):
    """
    Generate status dicts for a Celery task and, recursively, its children.

    Each yielded dict has 'id', 'status', 'exception', 'ready', 'success',
    'parent' and 'result' keys. The dicts of all descendant tasks are yielded
    before this task's own dict, which comes last.

    An exception result is replaced by {'error type': ..., 'message': ...};
    per the TODO below, this keeps the dict JSON-serializable.

    :param task_id: id of the task to inspect
    :param parent: id of the parent task, copied into the 'parent' field
    :param forget: if true and this task is ready, forget its result in the
        result backend after reading it.
        NOTE(review): ``forget`` is not passed to the recursive calls, so
        child results are never forgotten -- confirm this is intended.
    """
    r = AsyncResult(task_id, app=app)
    assert r.id == task_id  # sanity

    result = {
        'id': task_id,
        'status': r.status,
        'exception': r.status in states.EXCEPTION_STATES,
        'ready': r.status in states.READY_STATES,
        'success': r.status == states.SUCCESS,
        'parent': parent
    }

    # TODO: only discovered this case by testing, is this the only one?
    #       ... otherwise need to pre-test json serialization
    if isinstance(r.result, Exception):
        result['result'] = {
            'error type': type(r.result).__name__,
            'message': str(r.result)
        }
    else:
        result['result'] = r.result

    def child_taskids(children):
        # reverse-engineered, can't find documentation on this
        # flattens nested lists of children into task ids; entries may be
        # ids (str) or AsyncResult objects, and falsy entries are skipped
        for child in children:
            if not child:
                continue
            if isinstance(child, list):
                logger.debug(f'list: {child}')
                yield from child_taskids(child)
                continue
            if isinstance(child, str):
                yield child
                continue
            assert isinstance(child, AsyncResult)
            yield child.id

    # depth-first: report descendants before this task
    for child_id in child_taskids(getattr(r, 'children', []) or []):
        yield from check_task_status(child_id, parent=task_id)

    if forget and result['ready']:
        r.forget()

    yield result


@app.task(base=InventoryTask, bind=True, name='update_entry_point')
@log_task_entry_and_exit
def update_entry_point(self):
    """
    Celery task that launches a full inventory refresh into the "next" db.

    Steps:
    1. Flush the next db, keeping its latch.
    2. Mark the latch as pending.
    3. Load and persist the managed device list.
    4. Launch a chord without waiting for it. Its header runs ``ims_task``
       plus one sub-chord per router vendor ('juniper', 'nokia', 'unknown'),
       for both production and lab routers; ``final_task`` runs after the
       whole header completes.

    ``routers`` and the 'lab-routers' config entry are iterated as
    hostname -> vendor mappings.

    :return: this task's id
    :raises RedisError, KombuError: re-raised after the latch has been
        marked as failed and no longer pending
    """
    try:
        _erase_next_db(InventoryTask.config)
        update_latch_status(InventoryTask.config, pending=True)
        self.log_info("Starting update")

        routers = retrieve_and_persist_neteng_managed_device_list(
            info_callback=self.log_info,
            warning_callback=self.log_warning
        )
        lab_routers = InventoryTask.config.get('lab-routers', {})
        # each sub-chord's empty_task callback only logs that its group of
        # router tasks has finished
        chord(
            (
                ims_task.s().on_error(task_error_handler.s()),
                chord(
                    (reload_router_config_juniper.s(r) for r, v in routers.items()
                     if v == 'juniper'),
                    empty_task.si('juniper router tasks complete')
                ),
                chord(
                    (reload_lab_router_config_juniper.s(r)
                     for r, v in lab_routers.items() if v == 'juniper'),
                    empty_task.si('juniper lab router tasks complete')
                ),
                chord(
                    (reload_router_config_nokia.s(r) for r, v in routers.items()
                     if v == 'nokia'),
                    empty_task.si('nokia router tasks complete')
                ),
                chord(
                    (reload_router_config_nokia.s(r, True)
                     for r, v in lab_routers.items() if v == 'nokia'),
                    empty_task.si('nokia lab router tasks complete')
                ),
                chord(
                    (reload_router_config_try_all.s(r) for r, v in routers.items()
                     if v == 'unknown'),
                    empty_task.si('unknown router tasks complete')
                ),
                chord(
                    (reload_router_config_try_all.s(r, True)
                     for r, v in lab_routers.items() if v == 'unknown'),
                    empty_task.si('unknown lab router tasks complete')
                )
            ),
            final_task.si().on_error(task_error_handler.s())
        )()
        return self.request.id
    except (RedisError, KombuError):
        update_latch_status(InventoryTask.config, pending=False, failure=True)
        logger.exception('error launching refresh subtasks')
        raise


@app.task
def task_error_handler(request, exc, traceback):
    """
    Celery errback: mark the latch as failed (no longer pending) and log a
    warning naming the failed task and its exception.

    ``traceback`` is part of the errback signature and is not used.
    """
    update_latch_status(InventoryTask.config, pending=False, failure=True)
    logger.warning(f'Task {request.id!r} raised error: {exc!r}')


@app.task(base=InventoryTask, bind=True, name='empty_task')
def empty_task(self, message):
    """
    No-op task used as a chord callback; it only logs ``message`` as a
    warning so that completion of the preceding group shows up in the logs.
    """
    text = f'message from empty task: {message}'
    logger.warning(text)


def retrieve_and_persist_neteng_managed_device_list(
        info_callback=lambda s: None,
        warning_callback=lambda s: None):
    """
    Load the managed device list and save it as 'netdash' in the next db.

    The list is fetched from the orchestrator. If that fails or returns
    nothing, the 'netdash' value cached in the current db is used instead.

    :param info_callback: receives progress messages
    :param warning_callback: receives warnings and error messages
    :return: the device list. update_entry_point iterates it as a
        hostname -> vendor mapping.
    :raises InventoryTaskError: if no fresh list was retrieved and the
        cached list is missing or empty
    :raises Exception: any error reading or writing redis is re-raised.
        Before re-raising an exception, this function marks the latch as
        failed and no longer pending.
    """
    netdash_equipment = None
    try:
        info_callback('querying netdash for managed routers')
        netdash_equipment = gap.load_routers_from_orchestrator(
            InventoryTask.config)
    except Exception as e:
        # best effort: fall back to the previously saved list below
        warning_callback(f'Error retrieving device list: {e}')

    if netdash_equipment:
        info_callback(f'found {len(netdash_equipment)} routers')
    else:
        warning_callback('No devices retrieved, using previous list')
        try:
            current_r = get_current_redis(InventoryTask.config)
            cached = current_r.get('netdash')
            # report a missing key clearly rather than as an AttributeError
            # from calling .decode() on None
            if cached is None:
                raise InventoryTaskError(
                    'No previous device list found in redis')
            netdash_equipment = json.loads(cached.decode('utf-8'))
            if not netdash_equipment:
                raise InventoryTaskError(
                    'No equipment retrieved from previous list')
        except Exception as e:
            warning_callback(str(e))
            update_latch_status(
                InventoryTask.config, pending=False, failure=True)
            raise

    try:
        next_r = get_next_redis(InventoryTask.config)
        next_r.set('netdash', json.dumps(netdash_equipment))
        info_callback(f'saved {len(netdash_equipment)} managed routers')
    except Exception as e:
        warning_callback(str(e))
        update_latch_status(InventoryTask.config, pending=False, failure=True)
        raise
    return netdash_equipment


@app.task(base=InventoryTask, bind=True, name='reload_router_config_try_all')
@log_task_entry_and_exit
def reload_router_config_try_all(self, hostname, lab=False):
    """
    Celery task: refresh a router whose vendor is unknown.

    The Nokia loader is tried first. If it raises, the Juniper loader (lab or
    production variant) is tried. If that also raises, the error is logged,
    the latch is flagged as failed (still pending) and a 'task-error' event
    is emitted; the exception is not re-raised.

    NOTE(review): when Nokia docs cannot be fetched and none are cached,
    _reload_router_config_nokia returns normally instead of raising, so the
    Juniper fallback is never attempted -- confirm this is intended.
    """
    try:
        _reload_router_config_nokia(
            hostname, lab, self.log_info, self.log_warning)
    except Exception as nokia_exc:
        self.log_warning(
            f'error loading {hostname} info: {nokia_exc} - trying juniper')
        juniper_loader = (
            _reload_lab_router_config_juniper if lab
            else _reload_router_config_juniper)
        try:
            juniper_loader(hostname, self.log_info, self.log_warning)
        except Exception as juniper_exc:
            message = f'unhandled exception loading {hostname} info: {juniper_exc}'
            logger.exception(message)
            update_latch_status(
                InventoryTask.config, pending=True, failure=True)
            self.log_error(message)


@app.task(base=InventoryTask, bind=True, name='reload_router_config_nokia')
@log_task_entry_and_exit
def reload_router_config_nokia(self, hostname, lab=False):
    """
    Celery task: refresh all cached data for one Nokia router.

    Any exception is logged, flags the latch as failed (still pending) and is
    emitted as a 'task-error' event; it is not re-raised.
    """
    try:
        _reload_router_config_nokia(
            hostname, lab, self.log_info, self.log_warning)
    except Exception as exc:
        message = f'unhandled exception loading {hostname} info: {exc}'
        logger.exception(message)
        update_latch_status(InventoryTask.config, pending=True, failure=True)
        self.log_error(message)
        # TODO: re-raise and handle in some common way for all tasks


@log_task_entry_and_exit
def _reload_router_config_nokia(
        hostname, lab=False,
        info_callback=lambda s: None,
        warning_callback=lambda s: None):
    """
    Refresh all cached data for one Nokia router in the next redis db.

    Loads the netconf config and state docs, falling back to the cached
    copies, then refreshes the interface list and the SNMP interface indexes.

    BGP peers and SNMP peerings are refreshed only for production routers.
    Their redis keys have no 'lab:' prefix, and the peering indexes scan
    them, so lab peerings would otherwise be mixed into the production
    peering data. The Juniper lab path skips them for the same reason.

    If neither fresh nor cached docs are available, the router is skipped;
    retrieve_and_persist_config_nokia reports this via warning_callback.
    """
    lab_prefix = 'lab ' if lab else ''
    info_callback(f'loading netconf data for {lab_prefix}{hostname}')
    netconf_doc, state_doc = retrieve_and_persist_config_nokia(
        hostname, lab, warning_callback)
    if netconf_doc is None or state_doc is None:
        return
    r = get_next_redis(InventoryTask.config)
    refresh_nokia_interface_list(hostname, netconf_doc, r, lab)
    communities = _nokia_community_strings(InventoryTask.config)
    snmp_refresh_interfaces_nokia(
        hostname, state_doc, communities, r, info_callback)
    if lab:
        return
    refresh_nokia_bgp_peers(hostname, netconf_doc)
    snmp_refresh_peerings_nokia(hostname, communities)


def snmp_refresh_peerings_nokia(
        hostname, communities, update_callback=lambda s: None):
    """
    Refresh the SNMP BGP peer state of a Nokia router.

    SNMP walks use the 'inventory-provider' community and polls use the
    'dashboard' community. The results are stored by snmp_refresh_peerings
    under the 'nokia' key group.
    """
    walk_community = communities['inventory-provider']
    poll_community = communities['dashboard']
    fetch_peerings = functools.partial(
        snmp.get_peer_state_info_nokia,
        walk_community=walk_community,
        poll_community=poll_community,
    )
    snmp_refresh_peerings(
        fetch_peerings,
        hostname,
        'nokia',
        update_callback=update_callback,
    )


def snmp_refresh_peerings(
        get_peerings_func,
        hostname,
        redis_key_group,
        update_callback=lambda s: None):
    """
    Save a router's SNMP peerings in the next redis db.

    The peerings are stored as a JSON list under
    'snmp-peerings:hosts:<redis_key_group>:<hostname>'. If fetching raises
    ConnectionError, the copy cached under the same key in the current db is
    used instead.

    :param get_peerings_func: called with ``hostname``; returns an iterable
        of peerings
    :param hostname: router hostname
    :param redis_key_group: key namespace (e.g. the vendor name)
    :param update_callback: receives progress messages
    :raises InventoryTaskError: if SNMP fails and nothing is cached
    """
    redis_key = f'snmp-peerings:hosts:{redis_key_group}:{hostname}'
    try:
        peerings = list(get_peerings_func(hostname))
    except ConnectionError:
        msg = f'error loading snmp peering data from {hostname}'
        logger.exception(msg)
        update_callback(msg)
        cached = get_current_redis(InventoryTask.config).get(redis_key)
        if cached is None:
            raise InventoryTaskError(
                f'snmp error with {hostname} and no cached peering data found')
        # unnecessary json encode/decode here ... could be optimized
        peerings = json.loads(cached.decode('utf-8'))
        update_callback(f'using cached snmp peering data for {hostname}')

    get_next_redis(InventoryTask.config).set(redis_key, json.dumps(peerings))

    update_callback(f'snmp peering info loaded from {hostname}')


def retrieve_and_persist_config_nokia(
        hostname, lab=False, update_callback=lambda s: None
):
    """
    Fetch a Nokia router's netconf config and state docs and save them.

    The serialized docs are stored in the next redis db under
    'netconf-nokia:<hostname>' and 'nokia-state:<hostname>', with a 'lab:'
    prefix when ``lab`` is true. If fetching raises ConnectionError or
    TransportError, the docs cached under the same keys in the current db are
    parsed (with XML namespaces removed) and used instead.

    :param hostname: router hostname
    :param lab: use the 'lab:' key namespace
    :param update_callback: receives progress/warning messages
    :return: (netconf_config, state) element trees, or (None, None) if the
        fetch failed and either doc is missing from the cache
    """
    redis_netconf_key = f'netconf-nokia:{hostname}'
    redis_state_key = f'nokia-state:{hostname}'
    if lab:
        redis_netconf_key = f'lab:{redis_netconf_key}'
        redis_state_key = f'lab:{redis_state_key}'

    try:
        netconf_config, state = nokia.load_docs(
            hostname, InventoryTask.config['nokia-ssh']
        )
    except (ConnectionError, TransportError):
        msg = f'error loading nokia data from {hostname}'
        # one logger.exception call records both the message and the
        # traceback; logging the exception separately duplicated it
        logger.exception(msg)
        update_callback(msg)
        r = get_current_redis(InventoryTask.config)
        netconf_str = r.get(redis_netconf_key)
        state_str = r.get(redis_state_key)
        failed_docs = []
        failed_keys = []
        if not netconf_str:
            failed_docs.append('netconf')
            failed_keys.append(redis_netconf_key)
        if not state_str:
            failed_docs.append('port state')
            failed_keys.append(redis_state_key)
        if failed_docs:
            update_callback(
                f'no cached info for {failed_keys}. Ignoring this host')
            logger.warning(
                f'Nokia doc error with {hostname}'
                f' and no cached data found for {failed_docs}. Ignoring this host'
            )
            return None, None
        netconf_config = nokia.remove_xml_namespaces(
            etree.fromstring(netconf_str))
        state = nokia.remove_xml_namespaces(etree.fromstring(state_str))
        update_callback(f'Returning cached nokia data for {hostname}')
    else:
        netconf_str = etree.tostring(netconf_config)
        state_str = etree.tostring(state)

    rw = get_next_redis(InventoryTask.config)
    rwp = rw.pipeline()
    rwp.set(redis_netconf_key, netconf_str)
    rwp.set(redis_state_key, state_str)
    rwp.execute()
    logger.info(f'Nokia docs info loaded from {hostname}')
    return netconf_config, state


@log_task_entry_and_exit
def snmp_refresh_interfaces_nokia(
        hostname, state_doc, communities, redis, update_callback=lambda s: None):
    """
    Save the SNMP interface indexes of a Nokia router.

    Interfaces, ports and LAGs from the state doc each become a
    {'name', 'index', 'communities'} record. The full list is stored under
    'snmp-interfaces:<hostname>'. Each record, with 'hostname' added, is also
    stored under 'snmp-interfaces-single:<hostname>:<name>'.
    """
    def _record(item, name_field):
        return {
            'name': item[name_field],
            'index': int(item['if-index']),
            'communities': communities
        }

    state_interfaces = nokia.get_interfaces_state(state_doc)
    state_ports = nokia.get_ports_state(state_doc)
    state_lags = nokia.get_lags_state(state_doc)
    all_interfaces = [
        *(_record(item, "interface-name") for item in state_interfaces),
        *(_record(item, "port-id") for item in state_ports),
        *(_record(item, "name") for item in state_lags),
    ]

    pipe = redis.pipeline()
    # serialized before 'hostname' is added below, so the list entries
    # do not carry it
    pipe.set(f'snmp-interfaces:{hostname}', json.dumps(all_interfaces))
    for record in all_interfaces:
        record['hostname'] = hostname
        pipe.set(
            f'snmp-interfaces-single:{hostname}:{record["name"]}',
            json.dumps(record))
    pipe.execute()

    update_callback(f'snmp interface info loaded from {hostname}')


def refresh_nokia_interface_list(hostname, netconf_config, redis, lab=False):
    """
    Replace the cached interface data for one Nokia router in ``redis``.

    Only admin-enabled ports, LAGs and interfaces are stored; an element
    whose 'admin-state' is missing counts as enabled. Existing keys are
    deleted first, then the function writes:

    - 'netconf-interfaces:<hostname>:<name>': one JSON object per port,
      interface and LAG (name, description, bundle, speed, ipv4, ipv6)
    - 'netconf-interface-bundles:<hostname>:<lag name>': JSON list of the
      LAG's configured ports
    - 'netconf-interfaces-hosts:<hostname>': JSON list of the addresses
      configured on the interfaces

    All keys get a 'lab:' prefix when ``lab`` is true.

    :param hostname: router hostname, used in the redis keys
    :param netconf_config: parsed netconf config doc
    :param redis: redis client to update
    :param lab: use the 'lab:' key namespace
    """
    bundles_keybase = f'netconf-interface-bundles:{hostname}'
    interfaces_all_key = f'netconf-interfaces-hosts:{hostname}'
    interfaces_key_base = f'netconf-interfaces:{hostname}'
    if lab:
        bundles_keybase = f'lab:{bundles_keybase}'
        interfaces_all_key = f'lab:{interfaces_all_key}'
        interfaces_key_base = f'lab:{interfaces_key_base}'

    logger.debug(f'removing cached netconf-interfaces for {hostname}')
    rp = redis.pipeline()
    rp.delete(interfaces_all_key)
    for k in redis.scan_iter(f'{bundles_keybase}:*', count=1000):
        rp.delete(k)
    for k in redis.scan_iter(f'{interfaces_key_base}:*', count=1000):
        rp.delete(k)
    rp.execute()

    def is_enabled(e):
        return e.get('admin-state', 'enable') == 'enable'

    ports_by_port_id = {
        p['port-id']: p
        for p in nokia.get_ports_config(netconf_config) if is_enabled(p)}
    lags_by_name = {
        lag['name']: lag
        for lag in nokia.get_lags_config(netconf_config) if is_enabled(lag)}
    interfaces_by_name = {
        ifc['interface-name']: ifc
        for ifc in nokia.get_interfaces_config(netconf_config)
        if is_enabled(ifc)}

    ports_to_lag = {}
    for lag_name, lag in lags_by_name.items():
        for port in lag['ports']:
            ports_to_lag[port] = [lag_name]

    def _save_interfaces_details(_details, _rp):
        _rp.set(
            f'{interfaces_key_base}:{_details["name"]}',
            json.dumps(_details))

    rp = redis.pipeline()
    for port in ports_by_port_id.values():
        details = {
            'name': port['port-id'],
            'description': port['description'],
            'bundle': ports_to_lag.get(port['port-id'], []),
            'ipv4': [],
            'ipv6': [],
        }
        if 'speed' in port and 'speed-unit' in port:
            details['speed'] = f'{port["speed"]}{port["speed-unit"]}'
        else:
            details['speed'] = ''
        _save_interfaces_details(details, rp)

    for interface in interfaces_by_name.values():
        details = {
            'name': interface['interface-name'],
            'description': interface['description'],
            'bundle': ports_to_lag.get(interface['interface-name'], []),
            'speed': '',
            'ipv4': interface['ipv4'],
            'ipv6': interface['ipv6'],
        }
        _save_interfaces_details(details, rp)

    def _get_lag_speed(_lag):
        # LAG speed is the sum of its enabled member ports' speeds. Return ''
        # (as for ports without speed info) when it can't be determined,
        # instead of aborting the whole router refresh: no enabled members,
        # a member without speed info, or members with different units.
        _ports = [
            ports_by_port_id[p] for p in _lag['ports']
            if p in ports_by_port_id]
        if not _ports or any(
                'speed' not in p or 'speed-unit' not in p for p in _ports):
            return ''
        units = {p['speed-unit'] for p in _ports}
        if len(units) != 1:
            logger.warning(
                f'{hostname}: members of lag {_lag["name"]} report'
                f' different speed units {sorted(units)}')
            return ''
        return f'{sum(p["speed"] for p in _ports)}{units.pop()}'

    for lag in lags_by_name.values():
        details = {
            'name': lag['name'],
            'description': lag['description'],
            'bundle': [],
            'speed': _get_lag_speed(lag),
            'ipv4': [],
            'ipv6': [],
        }
        _save_interfaces_details(details, rp)

    for lag in lags_by_name.values():
        rp.set(
            f'{bundles_keybase}:{lag["name"]}',
            json.dumps(lag['ports']))

    interfaces = []
    for interface in interfaces_by_name.values():
        for addr in itertools.chain(interface['ipv4'], interface['ipv6']):
            interfaces.append({
                'name': ipaddress.ip_interface(addr).ip.exploded,
                'interface address': addr,
                'interface name': interface['interface-name'],
            })
    rp.set(
        interfaces_all_key,
        json.dumps(interfaces))
    rp.execute()


@log_task_entry_and_exit
def refresh_nokia_bgp_peers(hostname, netconf):
    """
    Store every BGP peering found in a Nokia netconf doc.

    The peerings are saved as a JSON list in the next redis db under
    'nokia-peerings:hosts:<hostname>'.
    """
    peerings = list(nokia.get_all_bgp_peers(netconf))
    redis_key = f'nokia-peerings:hosts:{hostname}'
    get_next_redis(InventoryTask.config).set(redis_key, json.dumps(peerings))


@app.task(base=InventoryTask, bind=True, name='reload_lab_router_juniper')
@log_task_entry_and_exit
def reload_lab_router_config_juniper(self, hostname):
    """
    Celery task: refresh the cached data for one Juniper lab router.

    Any exception is logged, flags the latch as failed (still pending) and is
    emitted as a 'task-error' event; it is not re-raised.
    """
    try:
        _reload_lab_router_config_juniper(
            hostname, self.log_info, self.log_warning)
    except Exception as exc:
        message = f'unhandled exception loading {hostname} info: {exc}'
        logger.exception(message)
        update_latch_status(InventoryTask.config, pending=True, failure=True)
        self.log_error(message)
        # TODO: re-raise and handle in some common way for all tasks


@log_task_entry_and_exit
def _reload_lab_router_config_juniper(
        hostname,
        info_callback=lambda s: None,
        warning_callback=lambda s: None
):
    """
    Refresh the cached interface data for one Juniper lab router.

    Loads the netconf config (lab namespace) and the interface info, then
    refreshes the 'lab:' interface list and the SNMP interface indexes.
    BGP peers and SNMP peerings are not refreshed for lab routers.

    NOTE(review): the interface info is fetched without a lab flag -- confirm
    its cache key cannot collide with a production router's.

    :raises InventoryTaskError: if no SNMP community string is configured
    """
    info_callback(f'loading netconf data for lab {hostname}')

    # load new netconf data, in this thread
    netconf_doc = etree.fromstring(
        retrieve_and_persist_netconf_config_juniper(
            hostname, lab=True, update_callback=warning_callback))
    interface_info_str = retrieve_and_persist_interface_info_juniper(
        hostname, update_callback=warning_callback)
    interface_info = (
        etree.fromstring(interface_info_str) if interface_info_str else None)

    refresh_juniper_interface_list(
        hostname, netconf_doc, interface_info, lab=True)

    # load snmp indexes
    community = juniper.snmp_community_string(netconf_doc)
    if not community:
        raise InventoryTaskError(
            f'error extracting community string for {hostname}')

    info_callback(f'refreshing snmp interface indexes for {hostname}')
    logical_systems = juniper.logical_systems(netconf_doc)
    # load snmp data, in this thread
    snmp_refresh_interfaces_juniper(
        hostname, community, logical_systems, info_callback)

    info_callback(f'updated configuration for lab {hostname}')


@app.task(base=InventoryTask, bind=True, name='reload_router_config_juniper')
@log_task_entry_and_exit
def reload_router_config_juniper(self, hostname):
    """
    Celery task: refresh all cached data for one Juniper router.

    Any exception is logged, flags the latch as failed (still pending) and is
    emitted as a 'task-error' event; it is not re-raised.
    """
    try:
        _reload_router_config_juniper(
            hostname, self.log_info, self.log_warning)
    except Exception as exc:
        message = f'unhandled exception loading {hostname} info: {exc}'
        logger.exception(message)
        update_latch_status(InventoryTask.config, pending=True, failure=True)
        self.log_error(message)
        # TODO: re-raise and handle in some common way for all tasks


@log_task_entry_and_exit
def _reload_router_config_juniper(
        hostname,
        info_callback=lambda s: None,
        warning_callback=lambda s: None
):
    """
    Refresh the cached data for one (non-lab) Juniper router.

    Steps, in order: netconf config, interface-info, bgp peers,
    interface list, snmp interface indexes, snmp peering state.

    :param hostname: router hostname
    :param info_callback: called with progress messages
    :param warning_callback: called with warnings from the retrieve steps
    :raises InventoryTaskError: if no snmp community string can be
        extracted from the netconf config
    """
    info_callback(f'loading netconf data for {hostname}')
    netconf_doc = etree.fromstring(
        retrieve_and_persist_netconf_config_juniper(
            hostname, update_callback=warning_callback))

    interface_info_str = retrieve_and_persist_interface_info_juniper(
        hostname, update_callback=warning_callback)
    interface_info = (
        etree.fromstring(interface_info_str) if interface_info_str else None)

    # clear cached classifier responses for this router, and
    # refresh peering data
    logger.info(f'refreshing peers & clearing cache for {hostname}')
    refresh_juniper_bgp_peers(hostname, netconf_doc)
    refresh_juniper_interface_list(hostname, netconf_doc, interface_info)

    # snmp indexes need the community string from the config
    community = juniper.snmp_community_string(netconf_doc)
    if not community:
        raise InventoryTaskError(
            f'error extracting community string for {hostname}')

    info_callback(f'refreshing snmp interface indexes for {hostname}')
    logical_systems = juniper.logical_systems(netconf_doc)
    # snmp data is loaded synchronously, in this thread
    snmp_refresh_interfaces_juniper(
        hostname, community, logical_systems, info_callback)
    snmp_refresh_peerings_juniper(hostname, community, logical_systems)

    logger.info(f'updated configuration for {hostname}')


def retrieve_and_persist_netconf_config_juniper(
        hostname, lab=False, update_callback=lambda s: None):
    """
    Load the netconf config of a Juniper router and store it in the next db.

    If loading fails with ``ConnectionError``,
    ``juniper.NetconfHandlingError`` or ``InventoryTaskError``, the copy
    previously stored in the current redis db is used instead.

    :param hostname: router hostname
    :param lab: if True the redis key is ``lab:netconf:<hostname>``,
        otherwise ``netconf:<hostname>``
    :param update_callback: called with status/warning messages
    :return: the netconf config as a string, or the raw value read from
        redis when the cached copy is used
    :raises InventoryTaskError: if loading failed and there is no cached copy
    """
    redis_key = f'netconf:{hostname}'
    if lab:
        redis_key = f'lab:{redis_key}'

    try:
        netconf_doc = juniper.load_config(
            hostname, InventoryTask.config["ssh"])
        netconf_str = etree.tostring(netconf_doc, encoding='unicode')
    except (ConnectionError, juniper.NetconfHandlingError,
            InventoryTaskError) as e:
        msg = f'error loading netconf data from {hostname}'
        # logger.exception already records the active exception and its
        # traceback, so a single call is enough (it used to be logged twice)
        logger.exception(msg)
        update_callback(msg)
        r = get_current_redis(InventoryTask.config)

        netconf_str = r.get(redis_key)
        if not netconf_str:
            update_callback(f'no cached netconf for {redis_key}')
            raise InventoryTaskError(
                f'netconf error with {hostname}'
                f' and no cached netconf data found') from e
        logger.info(f'Returning cached netconf data for {hostname}')
        update_callback(f'Returning cached netconf data for {hostname}')

    r = get_next_redis(InventoryTask.config)
    r.set(redis_key, netconf_str)
    logger.info(f'netconf info loaded from {hostname}')
    return netconf_str


def retrieve_and_persist_interface_info_juniper(
        hostname, lab=False, update_callback=lambda s: None):
    """
    Fetch the interface-info rpc output of a Juniper router and store it
    in the next db.

    If the rpc fails (connection, timeout, ssh or ``InventoryTaskError``),
    the copy previously stored in the current redis db is used instead.

    :param hostname: router hostname
    :param lab: if True the redis key is ``lab:intinfo:<hostname>``,
        otherwise ``intinfo:<hostname>``
    :param update_callback: called with status/warning messages
    :return: the interface info (from the router, or the raw cached value),
        or None if neither is available - nothing is stored in that case
    """
    redis_key = f'lab:intinfo:{hostname}' if lab else f'intinfo:{hostname}'

    try:
        info = juniper.get_interface_info_for_router(
            hostname, InventoryTask.config["ssh"])
        logger.info(f'interface-info rpc success from {hostname}')
    except (ConnectionError, juniper.TimeoutError, InventoryTaskError,
            ncclient.transport.errors.SSHError,
            ncclient.operations.errors.TimeoutExpiredError):
        msg = f'error loading interface-info data from {hostname}'
        logger.exception(msg)
        update_callback(msg)

        info = get_current_redis(InventoryTask.config).get(redis_key)
        if not info:
            update_callback(f'no cached interface info for {redis_key}')
            logger.warning(
                f'interface-info could not be retrieved from {hostname}, '
                f'ignoring this host')
            return None
        logger.info(f'Returning cached interface info data for {hostname}')
        update_callback(f'Returning cached interface info data for {hostname}')

    get_next_redis(InventoryTask.config).set(redis_key, info)
    logger.info(f'interface info loaded from {hostname}')
    return info


@log_task_entry_and_exit
def snmp_refresh_interfaces_juniper(
        hostname, community, logical_systems, update_callback=lambda s: None):
    """
    Refresh the snmp interface indexes of a Juniper router in the next db.

    The indexes are polled via snmp; on ``ConnectionError`` the list
    previously stored in the current db is used instead.

    Writes ``snmp-interfaces:<hostname>`` (the whole list as json) and,
    for DBOARD3-372, one ``snmp-interfaces-single:<hostname>:<name>`` key
    per interface.  Each entry gets a ``communities`` field; ``hostname``
    is added only to the single-interface entries, after the whole list
    has been serialized.

    :raises InventoryTaskError: if polling failed and nothing is cached
    """
    list_key = f'snmp-interfaces:{hostname}'

    try:
        interfaces = list(
            snmp.get_router_snmp_indexes(hostname, community, logical_systems))
    except ConnectionError:
        msg = f'error loading snmp interface data from {hostname}'
        logger.exception(msg)
        update_callback(msg)
        cached = get_current_redis(InventoryTask.config).get(list_key)
        if not cached:
            raise InventoryTaskError(
                f'snmp error with {hostname}'
                f' and no cached snmp interface data found')
        # unnecessary json encode/decode here ... could be optimized
        interfaces = json.loads(cached.decode('utf-8'))
        update_callback(f'using cached snmp interface data for {hostname}')

    for entry in interfaces:
        entry['communities'] = _general_community_strings(community)

    pipe = get_next_redis(InventoryTask.config).pipeline()
    # serialized now, before 'hostname' is added to the entries below
    pipe.set(list_key, json.dumps(interfaces))

    # entries look like {'name': str, 'index': int, ...}
    for entry in interfaces:
        entry['hostname'] = hostname
        pipe.set(
            f'snmp-interfaces-single:{hostname}:{entry["name"]}',
            json.dumps(entry))

    pipe.execute()

    update_callback(f'snmp interface info loaded from {hostname}')


@log_task_entry_and_exit
def snmp_refresh_peerings_juniper(
        hostname, community, logical_systems, update_callback=lambda s: None):
    """
    Refresh the snmp peering state of a Juniper router.

    Binds the community string and logical systems into
    ``snmp.get_peer_state_info_juniper`` and delegates the rest to
    ``snmp_refresh_peerings`` with router type ``'juniper'``.
    """
    peerings_loader = functools.partial(
        snmp.get_peer_state_info_juniper,
        community=community,
        logical_systems=logical_systems)
    snmp_refresh_peerings(
        peerings_loader, hostname, 'juniper', update_callback=update_callback)


def cache_extracted_ims_data(extracted_data, use_current=False):
    """
    Store each item of the raw extracted IMS data as json under
    ``ims:cache:<key>``.

    :param extracted_data: dict of data set name -> json-serializable data
    :param use_current: write to the current redis db instead of the next
    """
    pick_redis = get_current_redis if use_current else get_next_redis
    r = pick_redis(InventoryTask.config)

    for name, payload in extracted_data.items():
        r.set(f'ims:cache:{name}', json.dumps(payload))


@app.task(base=InventoryTask, bind=True, name='ims_task')
@log_task_entry_and_exit
def ims_task(self, use_current=False):
    """
    Celery task: extract, cache, transform and persist the IMS data.

    Failures are not propagated: the exception is logged, reported through
    ``self.log_error`` and the update latch is flagged as failed.

    :param use_current: if True, both the raw extract cache and the
        transformed data are written to the current redis db, otherwise
        to the next one
    """
    try:
        extracted_data = extract_ims_data(log_warning=self.log_warning)
        # the raw cache must go to the same db as the transformed data:
        # the FlexILS fallback in _extract_ims_data reads
        # 'ims:cache:flexils_data' from the current db, so an IMS-only
        # (use_current) update must refresh the cache there too
        cache_extracted_ims_data(extracted_data, use_current)
        transformed_data = transform_ims_data(extracted_data)

        persist_ims_data(transformed_data, use_current)
    except Exception as e:
        errmsg = f'Error in IMS task {e}'
        logger.exception('Error in IMS task:')
        self.log_error(errmsg)
        update_latch_status(InventoryTask.config, pending=True, failure=True)


def extract_ims_data(log_warning=lambda s: None):
    """
    Extract the raw IMS data using the ``ims`` section of the task config.

    ``verify-ssl`` is optional and defaults to False.

    :param log_warning: callback for non-fatal warnings
    :return: dict of extracted data (see ``_extract_ims_data``)
    """
    ims_config = InventoryTask.config["ims"]
    api_url = ims_config['api']
    username = ims_config['username']
    password = ims_config['password']
    verify = ims_config.get('verify-ssl', False)
    return _extract_ims_data(
        ims_api_url=api_url,
        ims_username=username,
        ims_password=password,
        verify_ssl=verify,
        log_warning=log_warning
    )


def _extract_ims_data(
    ims_api_url, ims_username, ims_password, verify_ssl, log_warning
):
    """
    convenient entry point for testing ...

    Pulls the raw IMS data in two batches of concurrent lookups and
    returns everything as a single dict.

    :param ims_api_url: IMS api url, passed to ``IMS``
    :param ims_username: IMS username, passed to ``IMS``
    :param ims_password: IMS password, passed to ``IMS``
    :param verify_ssl: passed to ``IMS``
    :param log_warning: callback for non-fatal warnings (used when the
        FlexILS data falls back to the cached copy)
    :return: dict of the extracted data, keyed by data set name
    :raises InventoryTaskError: if any lookup in a batch failed; the
        message is a json object mapping lookup name to error text
    """

    def _ds() -> IMS:
        return IMS(ims_api_url, ims_username, ims_password, verify_ssl)

    _ds().clear_dynamic_context_cache()

    locations = {}
    site_locations = {}
    lg_routers = []
    geant_nodes = []
    customer_contacts = {}
    planned_work_contacts = {}
    circuit_ids_to_monitor = []
    circuit_ids_and_sids = {}
    circuit_ids_and_third_party_ids = {}
    additional_circuit_customers = {}
    flexils_data = {}
    customers = {}
    customer_regions = {}
    equipment_details = []

    hierarchy = {}
    port_id_details = defaultdict(list)
    port_id_services = defaultdict(list)

    @log_task_entry_and_exit
    def _populate_locations():
        nonlocal locations
        locations = {k: v for k, v in ims_data.get_node_locations(ds=_ds())}

    @log_task_entry_and_exit
    def _populate_site_locations():
        nonlocal site_locations
        site_locations = {k: v for k, v in ims_data.get_site_locations(ds=_ds())}

    @log_task_entry_and_exit
    def _populate_lg_routers():
        nonlocal lg_routers
        lg_routers = list(ims_data.lookup_lg_routers(ds=_ds()))

    @log_task_entry_and_exit
    def _populate_geant_nodes():
        nonlocal geant_nodes
        geant_nodes = list(ims_data.lookup_geant_nodes(ds=_ds()))

    @log_task_entry_and_exit
    def _populate_customer_contacts():
        nonlocal customer_contacts
        customer_contacts = \
            {k: v for k, v in ims_data.get_customer_tts_contacts(ds=_ds())}

    @log_task_entry_and_exit
    def _populate_customer_planned_work_contacts():
        nonlocal planned_work_contacts
        planned_work_contacts = \
            {k: v for k, v in
             ims_data.get_customer_planned_work_contacts(ds=_ds())}

    @log_task_entry_and_exit
    def _populate_circuit_ids_to_monitor():
        nonlocal circuit_ids_to_monitor
        circuit_ids_to_monitor = \
            list(ims_data.get_monitored_circuit_ids(ds=_ds()))

    @log_task_entry_and_exit
    def _populate_sids():
        nonlocal circuit_ids_and_sids
        circuit_ids_and_sids = \
            {cid: sid for cid, sid in ims_data.get_ids_and_sids(ds=_ds())}

    @log_task_entry_and_exit
    def _populate_third_party_ids():
        nonlocal circuit_ids_and_third_party_ids
        circuit_ids_and_third_party_ids = \
            {cid: tpid for cid, tpid in ims_data.get_ids_and_third_party_ids(ds=_ds())}

    @log_task_entry_and_exit
    def _populate_additional_circuit_customers():
        nonlocal additional_circuit_customers
        additional_circuit_customers = \
            ims_data.get_circuit_related_customers(ds=_ds())

    # first batch: collect every failure, then raise them together
    exceptions = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(_populate_locations): 'locations',
            executor.submit(_populate_site_locations): 'site_locations',
            executor.submit(_populate_geant_nodes): 'geant_nodes',
            executor.submit(_populate_lg_routers): 'lg_routers',
            executor.submit(_populate_customer_contacts): 'customer_contacts',
            executor.submit(_populate_customer_planned_work_contacts):
                'planned_work_contacts',
            executor.submit(_populate_circuit_ids_to_monitor):
                'circuit_ids_to_monitor',
            executor.submit(_populate_sids): 'sids',
            executor.submit(_populate_third_party_ids): 'third_party_ids',
            executor.submit(_populate_additional_circuit_customers):
                'additional_circuit_customers'
        }

        for future in concurrent.futures.as_completed(futures):
            if future.exception():
                exceptions[futures[future]] = str(future.exception())

    if exceptions:
        raise InventoryTaskError(json.dumps(exceptions, indent=2))

    @log_task_entry_and_exit
    def _populate_customers():
        nonlocal customers
        customers = {c['id']: c for c in _ds().get_all_entities('customer')}

    @log_task_entry_and_exit
    def _populate_customer_regions():
        nonlocal customer_regions
        customer_regions = {c['id']: c for c in ims_data.get_customer_regions(ds=_ds())}

    @log_task_entry_and_exit
    def _populate_flexils_data():
        nonlocal flexils_data

        def _transform_key(key):
            # json object keys are always strings: restore None / int ids
            if key == "null":
                return None
            try:
                return int(key)
            except ValueError:
                return key

        try:
            flexils_data = ims_data.get_flexils_by_circuitid(ds=_ds())
        except HTTPError as e:
            log_warning("Failure reading FlexILS data from IMS. Using cache instead")
            r = get_current_redis(InventoryTask.config)
            cached = r.get("ims:cache:flexils_data")
            if not cached:
                # without this check json.loads(None) raises an opaque
                # TypeError that hides the real cause
                raise InventoryTaskError(
                    'failure reading FlexILS data from IMS'
                    ' and no cached FlexILS data found') from e
            raw_data = json.loads(cached)
            flexils_data = {_transform_key(k): v for k, v in raw_data.items()}

    @log_task_entry_and_exit
    def _populate_hierarchy():
        nonlocal hierarchy
        hierarchy = {
            d['id']: d for d in ims_data.get_circuit_hierarchy(ds=_ds())}
        logger.debug("hierarchy complete")

    @log_task_entry_and_exit
    def _populate_equipment_details():
        nonlocal equipment_details
        equipment_details = list(ims_data.get_equipment_details(ds=_ds()))

    @log_task_entry_and_exit
    def _populate_port_id_details():
        nonlocal port_id_details
        for x in ims_data.get_port_details(ds=_ds()):
            pd = port_id_details[x['port_id']]
            pd.append(x)
        logger.debug("Port details complete")

    @log_task_entry_and_exit
    def _populate_circuit_info():
        for x in ims_data.get_port_id_services(ds=_ds()):
            port_id_services[x['port_a_id']].append(x)
        logger.debug("port circuits complete")

    # second batch
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(_populate_hierarchy): 'hierarchy',
            executor.submit(_populate_port_id_details): 'port_id_details',
            executor.submit(_populate_circuit_info): 'circuit_info',
            executor.submit(_populate_flexils_data): 'flexils_data',
            executor.submit(_populate_customers): 'customers',
            executor.submit(_populate_customer_regions): 'customer_regions',
            executor.submit(_populate_equipment_details): 'equipment_details'
        }

        for future in concurrent.futures.as_completed(futures):
            if future.exception():
                exceptions[futures[future]] = str(future.exception())

    if exceptions:
        raise InventoryTaskError(json.dumps(exceptions, indent=2))

    return {
        'locations': locations,
        'site_locations': site_locations,
        'lg_routers': lg_routers,
        'customer_contacts': customer_contacts,
        'planned_work_contacts': planned_work_contacts,
        'circuit_ids_to_monitor': circuit_ids_to_monitor,
        'circuit_ids_sids': circuit_ids_and_sids,
        'circuit_ids_third_party_ids': circuit_ids_and_third_party_ids,
        'additional_circuit_customers': additional_circuit_customers,
        'hierarchy': hierarchy,
        'port_id_details': port_id_details,
        'port_id_services': port_id_services,
        'geant_nodes': geant_nodes,
        'flexils_data': flexils_data,
        'customers': customers,
        'customer_regions': customer_regions,
        'equipment_details': equipment_details
    }


def _convert_to_bits(value, unit):
    """
    Convert ``value`` expressed in ``unit`` to bits.

    The unit is matched case-insensitively: 'm' and 'mb' multiply by 2**20,
    'g' and 'gbe' by 2**30.

    NOTE(review): these are binary (1024-based) multipliers, whereas link
    speeds are usually quoted in decimal units - confirm this is intended.

    :raises ValueError: if ``value`` is not an integer
    :raises KeyError: for any other unit
    """
    normalized_unit = unit.lower()
    mebi = 1 << 20
    gibi = 1 << 30
    multipliers = {'m': mebi, 'mb': mebi, 'g': gibi, 'gbe': gibi}
    return int(value) * multipliers[normalized_unit]


# '<digits><unit>', e.g. '10G' or '100Mb'; the unit is letters only
# (a class written as [a-zA-z] would also accept '[', '\\', ']', '^',
# '_' and '`', which sit between 'Z' and 'a' in ASCII)
_SPEED_PATTERN = re.compile(r'^(\d+)([a-zA-Z]+)$')


def _get_speed(circuit_id, hierarchy):
    """
    Return the calculated speed of a circuit, in bits (see ``_convert_to_bits``).

    - unknown or non-operational circuit: 0
    - speed of the form ``<digits><unit>``: converted with
      ``_convert_to_bits``; an unsupported unit gives 0
    - otherwise, for service circuits or 'ethernet' products, the sum of
      the speeds of the carrier circuits; anything else gives 0

    :param circuit_id: key into ``hierarchy``
    :param hierarchy: dict of circuit id -> circuit hierarchy entry
    """
    c = hierarchy.get(circuit_id)
    if c is None:
        return 0
    if c['status'] != 'operational':
        return 0
    m = _SPEED_PATTERN.match(c['speed'])
    if m:
        try:
            return _convert_to_bits(m[1], m[2])
        except KeyError as e:
            logger.debug(f'Could not find key: {e} '
                         f'for circuit: {circuit_id}')
            return 0
    if c['circuit-type'] == 'service' \
            or c['product'].lower() == 'ethernet':
        return sum(
            _get_speed(x, hierarchy) for x in c['carrier-circuits'])
    return 0


def transform_ims_data(data):
    """
    Build the classifier/lookup structures from the raw extracted IMS data.

    ``data['hierarchy']``, ``data['port_id_details']`` and
    ``data['port_id_services']`` are modified in place.

    :param data: dict of extracted IMS data (as returned by
        ``extract_ims_data``)
    :return: dict with keys ``hierarchy``, ``interface_services``,
        ``services_by_type``, ``node_pair_services``, ``sid_services``,
        ``pop_nodes``, ``locations``, ``site_locations``, ``lg_routers``
        and ``circuit_ids_to_monitor``
    """
    locations = data['locations']
    customer_contacts = data['customer_contacts']
    planned_work_contacts = data['planned_work_contacts']
    # set copy for O(1) membership tests (done many times per circuit
    # below); data['circuit_ids_to_monitor'] itself is returned unchanged
    monitored_circuit_ids = set(data['circuit_ids_to_monitor'])
    additional_circuit_customers = data['additional_circuit_customers']

    # This is only used within this function and the populate_mic_with_third_party_data
    # The following data gets added to this
    #  contacts
    #  planned_work_contacts
    #  sid
    #  third_party_id - only if it exists for the circuit
    hierarchy = data['hierarchy']  # data in this gets modified

    # These two just get the flex ILS data added to them.
    # They are also used for building interface_services
    port_id_details = data['port_id_details']  # data in this gets modified
    port_id_services = data['port_id_services']  # data in this gets modified

    circuit_ids_and_sids = data['circuit_ids_sids']
    circuit_ids_and_third_party_ids = data['circuit_ids_third_party_ids']
    geant_nodes = data['geant_nodes']
    flexils_data = data['flexils_data']
    customers = data['customers']

    # this is the data that most classification data will be built from
    # the keys are in the format hosts:interfaces
    # e.g.
    # MX1.FRA.DE:ET-1/0/2
    # MX1.LON.UK:AE12
    # MX1.LON.UK:AE12.123
    interface_services = {}

    services_by_type = {}
    node_pair_services = {}
    sid_services = {}
    pop_nodes = {}

    # defined once here rather than on every (recursive) call of
    # _get_fibre_routes
    FibreRoute = namedtuple('FibreRoute', 'id name status')

    # populate pop_nodes
    for location in locations.values():
        pop_nodes.setdefault(location['pop']['name'], []).append(location['equipment-name'])

    def _get_circuit_contacts(_circuit):
        _customer_ids = {_circuit['customerid']}
        _customer_ids.update(ac['id'] for ac in additional_circuit_customers.get(_circuit['id'], ()))

        _trouble_ticket_contacts = set()
        _planned_work_contacts = set()
        for cid in _customer_ids:
            _trouble_ticket_contacts.update(customer_contacts.get(cid, ()))
            _planned_work_contacts.update(planned_work_contacts.get(cid, ()))

        return _trouble_ticket_contacts, _planned_work_contacts

    def _add_details_to_hierarchy():
        # this updates the values in the hierarchy dict of the outer scope
        # it is an internal function just to clearly group the logic
        for circuit_id, hierarchy_entry in hierarchy.items():
            # add contacts
            tt_contacts, pw_contacts = _get_circuit_contacts(hierarchy_entry)
            hierarchy_entry['contacts'] = sorted(tt_contacts)
            hierarchy_entry['planned_work_contacts'] = sorted(pw_contacts)
            # add SIDs
            hierarchy_entry['sid'] = circuit_ids_and_sids.get(circuit_id, '')
            # add third party ids to hierarchy - iff one exists
            if circuit_id in circuit_ids_and_third_party_ids:
                hierarchy_entry['third_party_id'] = circuit_ids_and_third_party_ids[circuit_id]

    _add_details_to_hierarchy()

    # both flex ILS helpers use only their arguments (they used to read the
    # enclosing loop variables, which only worked because of where they
    # were called from)
    def _add_flex_ils_details_to_port_id_details(_flex_ils_details):
        port_id_details[_flex_ils_details['key']] = [{
            'port_id': _flex_ils_details['key'],
            'equipment_name': _flex_ils_details['node_name'],
            'interface_name': _flex_ils_details['full_port_name']
        }]

    def _add_flex_ils_details_to_port_id_services(_flex_ils_details, _circuit_id):
        circuit = hierarchy[_circuit_id]
        port_id_services[_flex_ils_details['key']] = [{
            'id': circuit['id'],
            'name': circuit['name'],
            'project': circuit['project'],
            'port_a_id': _flex_ils_details['key'],
            'circuit_type': circuit['circuit-type'],
            'status': circuit['status'],
            'service_type': circuit['product'],
            'customerid': circuit['customerid'],
            'customer': customers.get(circuit['customerid'], ''),
            'contacts': circuit['contacts'],
            'planned_work_contacts': circuit['planned_work_contacts']
        }]

    # add flex ils details to port_id_details and port_id_services
    # flex_ils_data is a dict of circuit_ids to list of flex_ils_details
    for circuit_id, flex_ils_details_per_circuit_id in flexils_data.items():
        if circuit_id:  # add details iff there is a circuit (there can be a None key)
            # an example of flex_ils_details:
            # {'node_name': 'PAR01-MTC6-1', 'full_port_name': '5-A-1-L1-1', 'key': 'PAR01-MTC6-1:5-A-1-L1-1'}
            for flex_ils_details in flex_ils_details_per_circuit_id:
                _add_flex_ils_details_to_port_id_details(flex_ils_details)
                _add_flex_ils_details_to_port_id_services(flex_ils_details, circuit_id)

    def _get_related_services(_circuit_id):
        # this may want to go outside the transform function so that it can be used
        # by other functions, in which case we would need to also pass
        # - hierarchy
        # - monitored_circuit_ids
        # - circuit_ids_and_sids

        # using a dict as an easy way to ensure unique services
        related_services = {}

        if _circuit_id and _circuit_id in hierarchy:
            circuit = hierarchy[_circuit_id]

            if circuit['circuit-type'] == 'service':
                related_services[circuit['id']] = {
                    'id': circuit['id'],
                    'name': circuit['name'],
                    'circuit_type': circuit['circuit-type'],
                    'service_type': circuit['product'],
                    'project': circuit['project'],
                    'contacts': circuit['contacts'],
                    'planned_work_contacts': circuit['planned_work_contacts']
                }
                if circuit['id'] in monitored_circuit_ids:
                    related_services[circuit['id']]['status'] = circuit['status']
                else:
                    related_services[circuit['id']]['status'] = 'non-monitored'

                if circuit['id'] in circuit_ids_and_sids:
                    related_services[circuit['id']]['sid'] = circuit_ids_and_sids[circuit['id']]

            if circuit['sub-circuits']:
                for sub_circuit_id in circuit['sub-circuits']:
                    temp_parents = _get_related_services(sub_circuit_id)
                    related_services.update({t['id']: t for t in temp_parents})
        return related_services.values()

    def _format_circuit(circuit):
        circuit['additional_customers'] = additional_circuit_customers.get(circuit['id'], [])
        circuit['original_status'] = circuit['status']
        circuit['monitored'] = True
        if circuit['circuit_type'] == 'service' and circuit['id'] not in monitored_circuit_ids:
            circuit['monitored'] = False
            circuit['status'] = 'non-monitored'

        # there is only ever 1 item in the list, next refactor should remove the list
        port_details_a = port_id_details[circuit['port_a_id']][0]
        location_a = locations.get(port_details_a['equipment_name'], None)
        if location_a:
            loc_a = location_a['pop']
        else:
            loc_a = locations['UNKNOWN_LOC']['pop']
            logger.warning(
                f'Unable to find location for {port_details_a["equipment_name"]} - '
                f'Service ID {circuit["id"]}')
        circuit['pop_name'] = loc_a['name']
        circuit['pop_abbreviation'] = loc_a['abbreviation']
        circuit['equipment'] = port_details_a['equipment_name']
        circuit['card_id'] = ''  # this is redundant I believe
        circuit['port'] = port_details_a['interface_name']
        circuit['logical_unit'] = ''  # this is redundant I believe
        if 'port_b_id' in circuit:

            # there is only ever 1 item in the list, next refactor should remove the list
            pd_b = port_id_details[circuit['port_b_id']][0]
            location_b = locations.get(pd_b['equipment_name'], None)
            if location_b:
                loc_b = location_b['pop']
            else:
                loc_b = locations['UNKNOWN_LOC']['pop']
                logger.warning(
                    f'Unable to find location for {pd_b["equipment_name"]} - '
                    f'Service ID {circuit["id"]}')

            circuit['other_end_pop_name'] = loc_b['name']
            circuit['other_end_pop_abbreviation'] = loc_b['abbreviation']
            circuit['other_end_equipment'] = pd_b['equipment_name']
            circuit['other_end_port'] = pd_b['interface_name']
        else:
            circuit['other_end_pop_name'] = ''
            circuit['other_end_pop_abbreviation'] = ''
            circuit['other_end_equipment'] = ''
            circuit['other_end_port'] = ''

        circuit.pop('port_a_id', None)
        circuit.pop('port_b_id', None)

    # recursive function which iterates the circuit tree to find all the fibre routes
    # the given circuit is carried by
    def _get_fibre_routes(_circuit_id):
        _circuit = hierarchy.get(_circuit_id, None)
        if _circuit is None:
            return
        if _circuit['speed'].lower() == 'fibre_route':
            yield FibreRoute(_circuit['id'], _circuit['name'], _circuit['status'])
        else:
            for carrier_circuit_id in _circuit['carrier-circuits']:
                yield from _get_fibre_routes(carrier_circuit_id)

    def _build_interface_services():
        for port_details in port_id_details.values():
            # there is only ever 1 item in the list, next refactor should remove the list
            details = port_details[0]
            circuits = port_id_services.get(details['port_id'], [])

            # the physical port circuit, if it exists - the same for every
            # circuit on this port, so looked up once
            # https://jira.software.geant.org/browse/POL1-687
            port_circuit = next(
                (c for c in circuits if c.get('port_type') == 'ports'), None)

            # NOTE(review): these sets are shared by all circuits on this
            # port, so each circuit also receives the contacts gathered for
            # the circuits processed before it (order-dependent) - confirm
            # whether this is intended.
            _contacts = set()
            _planned_work_contacts = set()
            for _circuit in circuits:
                # add fibre routes
                _circuit['fibre-routes'] = [fr._asdict() for fr in set(_get_fibre_routes(_circuit['id']))]

                # add related services
                _circuit['related-services'] = list(_get_related_services(_circuit['id']))

                # update contact list bases on related services
                if _circuit['status'] == 'operational':
                    for related_service in _circuit['related-services']:
                        if related_service['status'] == 'operational' \
                                and related_service['id'] in monitored_circuit_ids:
                            _contacts.update(related_service.get('contacts', []))
                            _planned_work_contacts.update(related_service.get('planned_work_contacts', []))
                _circuit['contacts'] = sorted(_contacts)
                _circuit['planned_work_contacts'] = sorted(_planned_work_contacts)

                # add speed
                _circuit['calculated-speed'] = _get_speed(_circuit['id'], hierarchy)

                # add third party ids
                if _circuit['id'] in circuit_ids_and_third_party_ids:
                    _circuit['third_party_id'] = circuit_ids_and_third_party_ids[_circuit['id']]

                # need to do this before the sid processing as that can skip records
                _format_circuit(_circuit)

                # add sid info
                sid = None
                if _circuit['id'] in circuit_ids_and_sids:
                    sid = circuit_ids_and_sids[_circuit['id']]
                elif 'sid' in details:
                    # identity check: port_circuit is one of these circuits
                    if len(circuits) > 1 and port_circuit is not _circuit:
                        # if this is not the physical port circuit
                        # related to this port, then we don't want to
                        # assign the SID to this circuit, so skip.
                        continue

                    # assign the SID from the port to this circuit
                    sid = details['sid']

                if sid is None:
                    continue

                _circuit['sid'] = sid

            key = f"{details['equipment_name']}:{details['interface_name']}"
            interface_services.setdefault(key, []).extend(circuits)
        # end of build interface_services

    _build_interface_services()

    def _add_to_sid_services(_circuit):
        if 'sid' in _circuit:
            sid = _circuit['sid']
            sid_info = {
                'circuit_id': _circuit['id'],
                'sid': sid,
                'status': _circuit['original_status'],
                'monitored': _circuit['monitored'],
                'name': _circuit['name'],
                'speed': _circuit['calculated-speed'],
                'service_type': _circuit['service_type'],
                'project': _circuit['project'],
                'customer': _circuit['customer'],
                'equipment': _circuit['equipment'],
                'port': _circuit['port'],
                'geant_equipment': _circuit['equipment'] in geant_nodes
            }
            if sid_info not in sid_services.setdefault(sid, []):
                sid_services[sid].append(sid_info)

    def _add_to_services_by_type(_circuit):
        type_key = ims_sorted_service_type_key(_circuit['service_type'])
        services_by_type.setdefault(type_key, {})[_circuit['id']] = _circuit

    def _add_to_node_pair_services(_circuit):
        if _circuit['other_end_equipment']:
            node_pair_key = f"{_circuit['equipment']}/{_circuit['other_end_equipment']}"
            node_pair_services.setdefault(node_pair_key, {})[_circuit['id']] = _circuit

    for circuits in interface_services.values():
        for circuit in circuits:
            _add_to_sid_services(circuit)
            _add_to_services_by_type(circuit)
            _add_to_node_pair_services(circuit)

    return {
        'hierarchy': hierarchy,
        'interface_services': interface_services,
        'services_by_type': services_by_type,
        'node_pair_services': node_pair_services,
        'sid_services': sid_services,
        'pop_nodes': pop_nodes,
        'locations': data['locations'],  # unchanged
        'site_locations': data['site_locations'],  # unchanged and unused
        'lg_routers': data['lg_routers'],  # unchanged and unused
        'circuit_ids_to_monitor': data['circuit_ids_to_monitor'],  # unchanged
    }


def persist_ims_data(data, use_current=False):
    """
    Write processed IMS data into redis.

    :param data: dict of processed IMS data; the keys read here are
        'hierarchy', 'locations', 'site_locations', 'lg_routers',
        'interface_services', 'services_by_type', 'node_pair_services',
        'sid_services', 'pop_nodes' and 'circuit_ids_to_monitor'
    :param use_current: if True, write to the current redis db after
        deleting the existing keys that match a fixed set of ims:* patterns;
        otherwise write to the next redis db
    """
    hierarchy = data['hierarchy']
    locations = data['locations']
    site_locations = data['site_locations']
    lg_routers = data['lg_routers']
    interface_services = data['interface_services']
    services_by_type = data['services_by_type']
    node_pair_services = data['node_pair_services']
    sid_services = data['sid_services']
    pop_nodes = data['pop_nodes']
    circuit_ids_to_monitor = data['circuit_ids_to_monitor']

    def _get_sites():
        # de-dupe the sites (by abbreviation); for duplicate abbreviations
        # the last site_location seen wins
        sites = {
            site_location['abbreviation']: site_location
            for site_location in site_locations.values()
        }

        return sites.values()

    if use_current:
        r = get_current_redis(InventoryTask.config)

        r.delete('ims:sid_services')

        # only need to delete the individual keys if it's just an IMS update
        # rather than a complete update (the db will have been flushed)
        rp = r.pipeline()
        for key_pattern in [
            'ims:location:*',
            'ims:lg:*',
            'ims:circuit_hierarchy:*',
            'ims:interface_services:*',
            'ims:access_services:*',
            'ims:gws_indirect:*',
            'ims:node_pair_services:*',
            'ims:pop_nodes:*'
        ]:
            for k in r.scan_iter(key_pattern, count=1000):
                rp.delete(k)
        # the queued deletes only take effect once the pipeline is executed
        rp.execute()
    else:
        r = get_next_redis(InventoryTask.config)

    r.set('ims:sid_services', json.dumps(sid_services))
    rp = r.pipeline()
    for h, d in locations.items():
        rp.set(f'ims:location:{h}', json.dumps([d]))
    for site in _get_sites():
        rp.set(f'ims:site:{site["abbreviation"]}', json.dumps(site))

    rp.execute()
    rp = r.pipeline()
    for router in lg_routers:
        rp.set(f'ims:lg:{router["equipment name"]}', json.dumps(router))
    rp.execute()
    rp = r.pipeline()
    for circ in hierarchy.values():
        rp.set(f'ims:circuit_hierarchy:{circ["id"]}', json.dumps([circ]))
    rp.execute()
    rp = r.pipeline()
    for k, v in interface_services.items():
        rp.set(
            f'ims:interface_services:{k}',
            json.dumps(v))
    rp.execute()
    rp = r.pipeline()
    for k, v in node_pair_services.items():
        rp.set(
            f'ims:node_pair_services:{k}',
            json.dumps(list(v.values())))
    rp.execute()
    rp = r.pipeline()
    for k, v in pop_nodes.items():
        rp.set(
            f'ims:pop_nodes:{k}',
            json.dumps(sorted(v)))
    rp.execute()

    populate_poller_cache(interface_services, r)
    populate_mic_cache(interface_services, r)
    populate_mic_with_third_party_data(interface_services, hierarchy, circuit_ids_to_monitor, r)

    # one summary document per service, keyed by service type and name
    rp = r.pipeline()
    for service_type, services in services_by_type.items():
        for v in services.values():
            rp.set(
                f'ims:services:{service_type}:{v["name"]}',
                json.dumps({
                    'id': v['id'],
                    'name': v['name'],
                    'project': v['project'],
                    'here': {
                        'pop': {
                            'name': v['pop_name'],
                            'abbreviation': v['pop_abbreviation']
                        },
                        'equipment': v['equipment'],
                        'port': v['port'],
                    },
                    'there': {
                        'pop': {
                            'name': v['other_end_pop_name'],
                            'abbreviation': v['other_end_pop_abbreviation']
                        },
                        'equipment': v['other_end_equipment'],
                        'port': v['other_end_port'],
                    },
                    'speed_value': v['calculated-speed'],
                    'speed_unit': 'n/a',
                    'status': v['status'],
                    'type': v['service_type']
                }))

    rp.execute()


def populate_mic_with_third_party_data(interface_services, hierarchy, circuit_ids_to_monitor, r):
    """
    Build the MIC third-party impact document and store it in redis.

    The document is written as UTF-8 encoded JSON under
    ``mic:impact:third-party-data`` and has two sections:

    * ``circuit_hierarchy``: one entry per operational circuit in
      ``hierarchy`` with a truthy ``third_party_id``. Each entry carries the
      monitored, operational services found beneath it in the hierarchy,
      de-duplicated by display name, plus the union of their contacts.
    * ``interface_services``: ``site -> equipment -> port -> [services]``
      listing the operational related/interface services that have a
      ``third_party_id``.

    :param interface_services: dict whose values are lists of service dicts
    :param hierarchy: dict mapping circuit id -> circuit dict
    :param circuit_ids_to_monitor: collection of circuit ids eligible to be
        reported as related services
    :param r: redis connection the document is written to
    """
    cache_key = "mic:impact:third-party-data"
    third_party_data = defaultdict(lambda: defaultdict(dict))
    third_party_interface_data = defaultdict(lambda: defaultdict(dict))

    def get_related_services_ids(base_id):
        # Collect the ids of all 'service' circuits at or below base_id.
        # Ids missing from the hierarchy are skipped and each circuit is
        # visited once, so dangling references or cycles can't crash this.
        service_ids = set()
        visited = set()
        pending = [base_id]
        while pending:
            circuit_id = pending.pop()
            if circuit_id in visited:
                continue
            visited.add(circuit_id)
            circuit = hierarchy.get(circuit_id)
            if circuit is None:
                continue
            if circuit['circuit-type'] == 'service':
                service_ids.add(circuit['id'])
            pending.extend(circuit['sub-circuits'] or [])
        return service_ids

    def get_formatted_third_party_rs(_circuit_data):
        # only operational, monitored 'service' circuits are reported
        if _circuit_data and _circuit_data['status'] == 'operational' and \
                _circuit_data['id'] in circuit_ids_to_monitor and _circuit_data['circuit-type'] == 'service':
            return {
                'id': _circuit_data['id'],
                'name': _circuit_data['name'] + ' (' + _circuit_data['sid'] + ')',
                'service_type': _circuit_data['product'],
                'contacts': _circuit_data['contacts'],
                'planned_work_contacts': _circuit_data['planned_work_contacts'],
            }
        return None

    def get_related_services_for_third_party_type_circuit(_circuit_data):
        related_services_info = {
            'related_services': [],
            'contacts': set(),
            'planned_work_contacts': set()
        }
        seen_names = set()  # display names already added to related_services
        for rs_id in get_related_services_ids(_circuit_data['id']):
            rs_info = hierarchy.get(rs_id)
            if rs_info is None:
                continue
            formatted_rs = get_formatted_third_party_rs(rs_info)
            if formatted_rs is None:
                continue
            name = formatted_rs['name']
            if name not in seen_names:
                seen_names.add(name)
                related_services_info['related_services'].append({
                    'id': formatted_rs['id'],
                    'name': name,
                    'service_type': formatted_rs['service_type'],
                })
            # contacts are merged from every matching service, even those
            # whose name was already listed
            related_services_info['contacts'].update(formatted_rs['contacts'])
            related_services_info['planned_work_contacts'].update(
                formatted_rs['planned_work_contacts'])
        return related_services_info

    def _get_formatted_third_party_data(_circuit_data):
        # returns None for non-operational circuits or those without a third_party_id
        if _circuit_data['status'] == 'operational' and _circuit_data.get('third_party_id'):
            related_services_info = get_related_services_for_third_party_type_circuit(_circuit_data)
            return {
                'id': _circuit_data['id'],
                'status': _circuit_data['status'],
                'name': _circuit_data['name'],
                'sid': _circuit_data.get('sid', ''),
                'service': _circuit_data['circuit-type'],
                'service_type': 'circuit_hierarchy',
                'contacts': list(related_services_info['contacts']),
                'planned_work_contacts': list(related_services_info['planned_work_contacts']),
                'third_party_id': _circuit_data['third_party_id'],
                'related_services': related_services_info['related_services']
            }
        return None

    def _get_formatted_related_service(_d):
        for rs in _d['related-services']:
            if rs['status'] == 'operational' and rs.get('third_party_id'):
                yield {
                    'id': rs['id'],
                    'sid': rs.get('sid', ''),
                    'status': rs['status'],
                    'name': rs['name'],
                    'service_type': rs['service_type'],
                    'contacts': rs['contacts'],
                    'planned_work_contacts': rs['planned_work_contacts'],
                    'third_party_id': rs.get('third_party_id', '')
                }

    def _get_formatted_interface_service(_is):
        if _is['status'] == 'operational' and _is.get('third_party_id'):
            return {
                'id': _is['id'],
                'sid': _is.get('sid', ''),
                'status': _is['status'],
                'name': _is['name'],
                'service_type': 'circuit_type',
                'contacts': _is['contacts'],
                'planned_work_contacts': _is['planned_work_contacts'],
                'third_party_id': _is.get('third_party_id', '')
            }
        return None

    # format every operational hierarchy circuit that has a third party id
    # (each circuit is formatted exactly once)
    third_party_circuit_data = []
    for circuit in hierarchy.values():
        if not circuit or 'third_party_id' not in circuit:
            continue
        formatted = _get_formatted_third_party_data(circuit)
        if formatted:
            third_party_circuit_data.append(formatted)
    third_party_data['circuit_hierarchy'] = third_party_circuit_data

    for services in interface_services.values():
        if services:
            current_interface_services = []
            # related services are de-duplicated by id; the interface-level
            # service uses '<id>circuit' so it can't collide with those ids
            seen_ids = set()
            for d in services:
                if d.get('related-services'):
                    for rs in _get_formatted_related_service(d):
                        if rs['id'] not in seen_ids:
                            current_interface_services.append(rs)
                            seen_ids.add(rs['id'])
                    if (str(d.get('id')) + 'circuit') not in seen_ids:
                        _is = _get_formatted_interface_service(d)
                        if _is:
                            current_interface_services.append(_is)
                        seen_ids.add(str(d.get('id')) + 'circuit')

            if current_interface_services:
                site = f'{services[0]["pop_name"]} ' \
                       f'({services[0]["pop_abbreviation"]})'
                eq_name = services[0]['equipment']
                if_name = services[0]['port']
                third_party_interface_data[site][eq_name][if_name] = current_interface_services
    third_party_data['interface_services'] = third_party_interface_data

    result = json.dumps(third_party_data)
    r.set(cache_key, result.encode('utf-8'))


def populate_mic_cache(interface_services, r):
    """
    Build the MIC "all-data" impact document and store it in redis.

    For each non-empty list of services in ``interface_services`` this
    collects the operational entries from each service's
    'related-services', de-duplicated by id. It also collects the
    operational service itself, tagged with service_type 'circuit_type'.
    Only services that have related services are considered. The result is
    nested as ``"<pop_name> (<pop_abbreviation>)" -> equipment -> port``,
    taken from the first service in the list, and written as UTF-8 encoded
    JSON under ``mic:impact:all-data``.

    :param interface_services: dict whose values are lists of service dicts
    :param r: redis connection the document is written to
    """
    cache_key = "mic:impact:all-data"
    impact = defaultdict(lambda: defaultdict(dict))

    def _summarise(item, service_type=None):
        # common summary shape for related services and interface services
        return {
            'id': item['id'],
            'sid': item.get('sid', ''),
            'status': item['status'],
            'name': item['name'],
            'service_type': item['service_type'] if service_type is None else service_type,
            'contacts': item['contacts'],
            'planned_work_contacts': item['planned_work_contacts'],
            'third_party_id': item.get('third_party_id', '')
        }

    for services in filter(None, interface_services.values()):
        collected = []
        seen = set()
        for svc in services:
            if not svc.get('related-services'):
                continue

            for related in svc['related-services']:
                if related['status'] != 'operational':
                    continue
                summary = _summarise(related)
                if summary['id'] in seen:
                    continue
                collected.append(summary)
                seen.add(summary['id'])

            # the suffix keeps interface-level entries apart from related-service ids
            circuit_key = str(svc.get('id')) + 'circuit'
            if circuit_key in seen:
                continue
            if svc['status'] == 'operational':
                collected.append(_summarise(svc, 'circuit_type'))
            seen.add(circuit_key)

        if not collected:
            continue

        head = services[0]
        site = f'{head["pop_name"]} ({head["pop_abbreviation"]})'
        equipment = head['equipment']
        port = head['port']
        impact[site][equipment][port] = collected

    r.set(cache_key, json.dumps(impact).encode('utf-8'))


def _populate_equipment_vendors():
    """Run populate_equipment_vendors against the 'next' redis database."""
    populate_equipment_vendors(get_next_redis(InventoryTask.config))


def populate_equipment_vendors(r):
    """
    Write one 'state-checker:equipment-vendors:<name>' document per device.

    The device list is the JSON list stored under 'ims:cache:equipment_details'.
    It is extended with one entry per router in the JSON mapping stored under
    'netdash' (router name -> vendor); those entries get model 'UNKNOWN'.

    :param r: redis connection to read from and write to
    """
    equipment_details = json.loads(r.get('ims:cache:equipment_details').decode('utf-8'))
    router_vendors = json.loads(r.get('netdash').decode('utf-8'))
    # use a loop variable that doesn't shadow the redis connection ``r``
    equipment_details.extend({
        'name': router,
        'model': 'UNKNOWN',
        'vendor': vendor} for router, vendor in router_vendors.items())

    # batch the writes into a single round trip
    rp = r.pipeline()
    for ed in equipment_details:
        rp.set(f'state-checker:equipment-vendors:{ed["name"]}', json.dumps(ed))
    rp.execute()


@app.task(base=InventoryTask, bind=True, name='final_task')
@log_task_entry_and_exit
def final_task(self):
    """
    Last step of an update run.

    Aborts if the latch in the current redis db records a sub-task failure.
    Otherwise it runs the post-processing steps in order and then latches
    the current/next dbs.

    :raises InventoryTaskError: if the latch's 'failure' flag is set
    """
    current_db = get_current_redis(InventoryTask.config)
    if get_latch(current_db)['failure']:
        raise InventoryTaskError('Sub-task failed - check logs for details')

    _populate_equipment_vendors()

    # db builders report progress via update_callback
    for build_db in (_build_subnet_db,
                     _build_snmp_peering_db,
                     _build_router_peering_db):
        build_db(update_callback=self.log_info)

    # cache builders report recoverable problems via warning_callback
    for build_cache in (populate_poller_interfaces_cache,
                        populate_error_report_interfaces_cache,
                        collate_netconf_interfaces_all_cache):
        build_cache(warning_callback=self.log_warning)

    latch_db(InventoryTask.config)
    self.log_info('latched current/next dbs')


@log_task_entry_and_exit
def populate_poller_interfaces_cache(warning_callback=lambda s: None):
    """
    Cache the interfaces to poll in the 'next' redis db.

    Two keys are written:

    * ``classifier-cache:poller-interfaces:all``: every interface returned
      by ``load_interfaces_to_poll``
    * ``classifier-cache:poller-interfaces:all:no_lab``: the same list minus
      interfaces on routers that have a ``lab:netconf-interfaces-hosts:<router>``
      key

    If either list cannot be built (an exception, or an empty result), it is
    copied from the current redis db instead. If that also fails, JSON
    ``null`` is written.

    :param warning_callback: called with a message on every fallback/failure
    """
    no_lab_cache_key = 'classifier-cache:poller-interfaces:all:no_lab'
    all_cache_key = 'classifier-cache:poller-interfaces:all'
    lab_hosts_prefix = 'lab:netconf-interfaces-hosts:'
    non_lab_populated_interfaces = None
    all_populated_interfaces = None

    r = get_next_redis(InventoryTask.config)

    try:
        # SCAN rather than KEYS: KEYS blocks the redis server while it walks
        # the whole keyspace.  A set gives O(1) membership tests below and
        # absorbs any duplicates SCAN may return.
        lab_equipment = {
            k.decode('utf-8')[len(lab_hosts_prefix):]
            for k in r.scan_iter(f'{lab_hosts_prefix}*', count=1000)
        }

        all_populated_interfaces = list(
            load_interfaces_to_poll(InventoryTask.config, use_next_redis=True))
        non_lab_populated_interfaces = [x for x in all_populated_interfaces
                                        if x['router'] not in lab_equipment]

    except Exception as e:
        warning_callback(f"Failed to retrieve all required data {e}")
        logger.exception(
            "Failed to retrieve all required data, logging exception")

    if not non_lab_populated_interfaces or not all_populated_interfaces:
        previous_r = get_current_redis(InventoryTask.config)

        def _load_previous(key):
            # returns None (written below as JSON null) if the previous
            # value can't be loaded
            try:
                warning_callback(f"populating {key} "
                                 "from previously cached data")
                return json.loads(previous_r.get(key))
            except Exception as e:
                warning_callback(f"Failed to load {key} "
                                 f"from previously cached data: {e}")
                return None

        if not non_lab_populated_interfaces:
            non_lab_populated_interfaces = _load_previous(no_lab_cache_key)

        if not all_populated_interfaces:
            all_populated_interfaces = _load_previous(all_cache_key)

    r.set(no_lab_cache_key, json.dumps(non_lab_populated_interfaces))
    r.set(all_cache_key, json.dumps(all_populated_interfaces))


@log_task_entry_and_exit
def populate_error_report_interfaces_cache(warning_callback=lambda s: None):
    """
    Cache the error-report interfaces in the 'next' redis db.

    The full list is written to ``classifier-cache:error-report-interfaces:all``.
    The same interfaces, grouped by router, go to
    ``classifier-cache:error-report-interfaces:<router>``. If the list cannot
    be built (an exception, or an empty result), it is loaded from the
    current redis db and sorted by (router, name). If that also fails,
    nothing is written.

    :param warning_callback: called with a message on every fallback/failure
    """
    cache_ns = 'classifier-cache:error-report-interfaces:'
    all_cache_key = cache_ns + 'all'
    all_populated_interfaces = None

    r = get_next_redis(InventoryTask.config)

    try:
        # materialise the result: it is tested for emptiness, iterated and
        # serialised below, which a one-shot iterator would not survive
        all_populated_interfaces = list(load_error_report_interfaces(
            InventoryTask.config, use_next_redis=True
        ))

    except Exception as e:
        warning_callback(f"Failed to retrieve all required data {e}")
        logger.exception(
            "Failed to retrieve all required data, logging exception")

    if not all_populated_interfaces:
        previous_r = get_current_redis(InventoryTask.config)

        try:
            warning_callback(f"populating {all_cache_key} from previously cached data")
            previous = json.loads(previous_r.get(all_cache_key))
            all_populated_interfaces = sorted(
                previous, key=lambda i: (i["router"], i["name"])
            )
        except Exception as e:
            warning_callback(
                f"Failed to load {all_cache_key} from previously cached data: {e}"
            )
            return

    router_interfaces = defaultdict(list)
    for ifc in all_populated_interfaces:
        router_interfaces[ifc['router']].append(ifc)

    # batch all writes into a single round trip
    rp = r.pipeline()
    for router, ifcs in router_interfaces.items():
        rp.set(cache_ns + router, json.dumps(ifcs))
    rp.set(all_cache_key, json.dumps(all_populated_interfaces))
    rp.execute()


@log_task_entry_and_exit
def collate_netconf_interfaces_all_cache(warning_callback=lambda s: None):
    """
    Fetch all existing netconf-interface redis entries and assemble them into collated documents.
    Used for fetching speed data more efficiently in /poller/speeds

    Each ``netconf-interfaces:<hostname>...`` document (and each
    ``lab:netconf-interfaces:<hostname>...`` document) in the 'next' redis
    db is loaded and tagged with a ``hostname`` field. The docs are written
    as JSON lists to ``netconf-interfaces:all`` and
    ``lab:netconf-interfaces:all`` respectively. If collation fails, JSON
    ``null`` is written to both keys.

    :param warning_callback: called with a message if collation fails
    :return: None
    """

    def _fetch_docs_for_key_prefix(r, key_prefix, exclude_key):
        for k in r.scan_iter(f'{key_prefix}*', count=1000):
            key = k.decode('utf-8')
            if key == exclude_key:
                # the collated document written by this function also
                # matches the pattern; it is not a per-host document
                continue
            doc_bytes = r.get(key)
            if doc_bytes is None:
                # key was removed between SCAN and GET
                continue
            doc = json.loads(doc_bytes.decode('utf-8'))
            # hostname is the first ':'-separated field after the prefix
            # (this also works for the 'lab:' prefix, which has an extra ':')
            doc['hostname'] = key[len(key_prefix):].split(':')[0]
            yield doc

    netconf_all_key = 'netconf-interfaces:all'
    lab_netconf_all_key = 'lab:netconf-interfaces:all'
    netconf_interface_docs = None
    lab_netconf_interface_docs = None

    r = get_next_redis(InventoryTask.config)

    try:
        netconf_interface_docs = list(_fetch_docs_for_key_prefix(
            r, 'netconf-interfaces:', netconf_all_key))
        lab_netconf_interface_docs = list(_fetch_docs_for_key_prefix(
            r, 'lab:netconf-interfaces:', lab_netconf_all_key))
    except Exception as e:
        warning_callback(f"Failed to collate netconf-interfaces {e}")
        logger.exception(
            "Failed to collate netconf-interfaces, logging exception")

    r.set(netconf_all_key, json.dumps(netconf_interface_docs))
    r.set(lab_netconf_all_key, json.dumps(lab_netconf_interface_docs))