Skip to content
Snippets Groups Projects
worker.py 64.49 KiB
import concurrent.futures
import functools
import json
import logging
import os
import re
import threading
import time
from typing import List

from redis.exceptions import RedisError
from kombu.exceptions import KombuError

from celery import Task, states, chord
from celery.result import AsyncResult

from collections import defaultdict
from lxml import etree
import jsonschema

from inventory_provider.db import ims_data
from inventory_provider.db.ims import IMS
from inventory_provider.routes.classifier import get_ims_interface, \
    get_ims_equipment_name
from inventory_provider.routes.common import load_json_docs, load_snmp_indexes
from inventory_provider.routes.poller import _load_interfaces, \
    _load_interface_bundles, _get_dashboard_data, _get_dashboards, \
    _load_services
from inventory_provider.tasks.app import app
from inventory_provider.tasks.common \
    import get_next_redis, get_current_redis, \
    latch_db, get_latch, set_latch, update_latch_status, \
    ims_sorted_service_type_key, set_single_latch
from inventory_provider.tasks import monitor
from inventory_provider import config
from inventory_provider import environment
from inventory_provider import snmp
from inventory_provider import juniper

# seconds to sleep between each poll of the pending refresh subtasks
# (cf. _wait_for_tasks)
FINALIZER_POLLING_FREQUENCY_S = 2.5
# overall time limit (seconds) for the pending refresh subtasks to finish,
# after which _wait_for_tasks raises InventoryTaskError
FINALIZER_TIMEOUT_S = 300

# TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks)  # noqa: E501

environment.setup_logging()

# module logger, used by all tasks & helpers below
logger = logging.getLogger(__name__)


def log_task_entry_and_exit(f):
    """
    decorator: write a debug log line on entry to and exit from ``f``

    Both lines include the function name and its positional args.  The
    exit line is written even if ``f`` raises.

    cf. https://stackoverflow.com/a/47663642
    cf. https://stackoverflow.com/a/59098957/6708581

    :param f: the function to wrap
    :return: the wrapper (with ``f``'s metadata, via functools.wraps)
    """

    @functools.wraps(f)
    def _wrapper(*args, **kwargs):
        logger.debug(f'>>> {f.__name__}{args}')
        try:
            result = f(*args, **kwargs)
        finally:
            logger.debug(f'<<< {f.__name__}{args}')
        return result

    return _wrapper


class InventoryTaskError(Exception):
    """raised by the tasks & helpers in this module when they can't
    load or find the data they need"""


class InventoryTask(Task):
    """
    base class for the inventory provider celery tasks

    The parsed configuration is loaded on first instantiation from the
    file named by the INVENTORY_PROVIDER_CONFIG_FILENAME environment
    variable, and then shared through the ``InventoryTask.config``
    class attribute.
    """

    # parsed config, shared by all tasks; None until first loaded
    config = None

    def __init__(self):

        self.args = []

        if InventoryTask.config:
            # already loaded by a previous instance
            return

        config_filename = os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']
        assert os.path.isfile(config_filename), (
            'config file %r not found' % config_filename)

        with open(config_filename) as f:
            logger.info(
                'Initializing worker with config from: %r', config_filename)
            InventoryTask.config = config.load(f)
            # NOTE: don't log the config contents: they include credentials
            #       (e.g. the 'ims' username/password and 'ssh' settings)
            logger.debug(
                'loaded config with top-level keys: %r',
                list(InventoryTask.config))

    def log_info(self, message):
        """
        log ``message`` (at debug level) and send it as a 'task-info' event
        """
        logger.debug(message)
        self.send_event('task-info', message=message)

    def log_warning(self, message):
        """log ``message`` as a warning and send a 'task-warning' event"""
        logger.warning(message)
        self.send_event('task-warning', message=message)

    def log_error(self, message):
        """log ``message`` as an error and send a 'task-error' event"""
        logger.error(message)
        self.send_event('task-error', message=message)


@app.task(base=InventoryTask, bind=True, name='snmp_refresh_peerings')
@log_task_entry_and_exit
def snmp_refresh_peerings(self, hostname, community, logical_systems):
    """
    load bgp peer state info from ``hostname`` via snmp and save it in the
    next redis db, as a json list under 'snmp-peerings:hosts:<hostname>'

    If the snmp query raises ConnectionError, the value cached in the
    current redis db is copied over instead.

    :param self: the bound task
    :param hostname: router hostname
    :param community: snmp community string
    :param logical_systems: logical systems, passed to the snmp query
    :raises InventoryTaskError: if snmp fails and nothing is cached
    """
    redis_key = f'snmp-peerings:hosts:{hostname}'
    try:
        peerings = list(
            snmp.get_peer_state_info(hostname, community, logical_systems))
    except ConnectionError:
        msg = f'error loading snmp peering data from {hostname}'
        logger.exception(msg)
        self.log_warning(msg)
        r = get_current_redis(InventoryTask.config)
        cached = r.get(redis_key)
        if cached is None:
            # report the failing host (the cached value is None here)
            raise InventoryTaskError(
                f'snmp error with {hostname}'
                f' and no cached peering data found')
        # unnecessary json encode/decode here ... could be optimized
        peerings = json.loads(cached.decode('utf-8'))
        self.log_warning(f'using cached snmp peering data for {hostname}')

    r = get_next_redis(InventoryTask.config)
    r.set(redis_key, json.dumps(peerings))

    self.log_info(f'snmp peering info loaded from {hostname}')


@app.task(base=InventoryTask, bind=True, name='snmp_refresh_interfaces')
@log_task_entry_and_exit
def snmp_refresh_interfaces(self, hostname, community, logical_systems):
    """
    load snmp interface indexes from ``hostname`` and save them in the
    next redis db

    Keys written:
      - 'snmp-interfaces:<hostname>': the full list, as loaded
      - 'snmp-interfaces-single:<hostname>:<interface name>': one entry
        per interface, with an extra 'hostname' field (DBOARD3-372)

    If the snmp query raises ConnectionError, the list cached in the
    current redis db is reused instead.

    :param self: the bound task
    :param hostname: router hostname
    :param community: snmp community string
    :param logical_systems: logical systems, passed to the snmp query
    :raises InventoryTaskError: if snmp fails and nothing is cached
    """
    try:
        ifc_list = list(
            snmp.get_router_snmp_indexes(hostname, community, logical_systems))
    except ConnectionError:
        err_msg = f'error loading snmp interface data from {hostname}'
        logger.exception(err_msg)
        self.log_warning(err_msg)
        cached = get_current_redis(InventoryTask.config).get(
            f'snmp-interfaces:{hostname}')
        if not cached:
            raise InventoryTaskError(
                f'snmp error with {hostname}'
                f' and no cached snmp interface data found')
        # unnecessary json encode/decode here ... could be optimized
        ifc_list = json.loads(cached.decode('utf-8'))
        self.log_warning(f'using cached snmp interface data for {hostname}')

    # serialize the full list *before* the per-interface dicts
    # are annotated with 'hostname' below
    full_list_json = json.dumps(ifc_list)

    pipe = get_next_redis(InventoryTask.config).pipeline()
    pipe.set(f'snmp-interfaces:{hostname}', full_list_json)

    # optimization for DBOARD3-372
    # each element is a dict like: {'name': str, 'index': int}
    for ifc in ifc_list:
        ifc['hostname'] = hostname
        pipe.set(
            f'snmp-interfaces-single:{hostname}:{ifc["name"]}',
            json.dumps(ifc))

    pipe.execute()

    self.log_info(f'snmp interface info loaded from {hostname}')


@app.task(base=InventoryTask, bind=True, name='netconf_refresh_config')
@log_task_entry_and_exit
def netconf_refresh_config(self, hostname, lab=False):
    """
    load the netconf config from ``hostname`` and save it, serialized to
    a string, in the next redis db

    The key is 'netconf:<hostname>', prefixed with 'lab:' if ``lab`` is
    true.  If loading raises ConnectionError or NetconfHandlingError, the
    value cached in the current redis db is copied over instead.

    :param self: the bound task
    :param hostname: router hostname
    :param lab: True to save under 'lab:...'
    :raises InventoryTaskError: if loading fails and nothing is cached
    """
    key = f'lab:netconf:{hostname}' if lab else f'netconf:{hostname}'

    try:
        doc = juniper.load_config(hostname, InventoryTask.config["ssh"])
        netconf_str = etree.tostring(doc, encoding='unicode')
    except (ConnectionError, juniper.NetconfHandlingError):
        warning = f'error loading netconf data from {hostname}'
        logger.exception(warning)
        self.log_warning(warning)
        netconf_str = get_current_redis(InventoryTask.config).get(key)
        if not netconf_str:
            raise InventoryTaskError(
                f'netconf error with {hostname}'
                f' and no cached netconf data found')
        self.log_warning(f'using cached netconf data for {hostname}')

    get_next_redis(InventoryTask.config).set(key, netconf_str)
    self.log_info(f'netconf info loaded from {hostname}')


def _unmanaged_interfaces():
    """
    yield the config's 'unmanaged-interfaces' entries, converted to the
    key names used in redis

    :return: generator of dicts with keys 'name', 'interface address',
        'interface name' (lower case) and 'router' (lower case)
    """
    # the config file keys are more readable than
    # the keys used in redis
    for entry in InventoryTask.config.get('unmanaged-interfaces', []):
        yield {
            'name': entry['address'],
            'interface address': entry['network'],
            'interface name': entry['interface'].lower(),
            'router': entry['router'].lower()
        }


@app.task(base=InventoryTask, bind=True,
          name='update_neteng_managed_device_list')
@log_task_entry_and_exit
def update_neteng_managed_device_list(self):
    """
    fetch the managed routers from netdash and save them in the next
    redis db, as a utf-8 encoded json list under the 'netdash' key

    :param self: the bound task
    """
    self.log_info('querying netdash for managed routers')

    netdash_config = InventoryTask.config['managed-routers']
    routers = list(juniper.load_routers_from_netdash(netdash_config))
    count = len(routers)

    self.log_info(f'found {count} routers, saving details')

    payload = json.dumps(routers).encode('utf-8')
    get_next_redis(InventoryTask.config).set('netdash', payload)
    self.log_info(f'saved {count} managed routers')


def load_netconf_data(hostname, lab=False):
    """
    load and parse the cached netconf doc for ``hostname`` from the next
    redis db

    this method should only be called from a task

    :param hostname: router hostname
    :param lab: True to look up under 'lab:...'
    :return: the parsed netconf document (lxml element)
    :raises InventoryTaskError: if no netconf data is cached
    """
    prefix = 'lab:' if lab else ''
    raw = get_next_redis(InventoryTask.config).get(
        f'{prefix}netconf:{hostname}')
    if raw:
        return etree.fromstring(raw.decode('utf-8'))
    raise InventoryTaskError(f'no netconf data found for {hostname!r}')


def clear_cached_classifier_responses(hostname=None):
    """
    delete cached classifier responses from the next redis db

    If ``hostname`` is given, deletes the
    'classifier-cache:juniper:<hostname>:*' keys and any
    'classifier-cache:peer:*' entries with an interface on ``hostname``.
    Otherwise all 'classifier-cache:*' keys are deleted.

    :param hostname: router hostname, or None to clear everything
    """
    if hostname:
        logger.debug(
            'removing cached classifier responses for %r', hostname)
    else:
        logger.debug('removing all cached classifier responses')

    r = get_next_redis(InventoryTask.config)

    def _scan(pattern):
        # SCAN (as elsewhere in this module) rather than KEYS, which
        # blocks the redis server while walking the whole keyspace;
        # bigger batches mitigate network latency effects
        return r.scan_iter(pattern, count=1000)

    def _hostname_keys():
        yield from _scan(f'classifier-cache:juniper:{hostname}:*')

        # TODO: very inefficient ... but logically simplest at this point
        for k in _scan('classifier-cache:peer:*'):
            value = r.get(k.decode('utf-8'))
            if not value:
                # deleted in another thread
                continue
            value = json.loads(value.decode('utf-8'))
            interfaces = value.get('interfaces', [])
            if hostname in [i['interface']['router'] for i in interfaces]:
                yield k

    if hostname:
        keys_to_delete = _hostname_keys()
    else:
        keys_to_delete = _scan('classifier-cache:*')

    rp = r.pipeline()
    for k in keys_to_delete:
        rp.delete(k)
    rp.execute()


@log_task_entry_and_exit
def refresh_juniper_bgp_peers(hostname, netconf):
    """
    save all bgp peers found in ``netconf`` to the next redis db, as a
    json list under 'juniper-peerings:hosts:<hostname>'

    :param hostname: router hostname
    :param netconf: parsed netconf document for ``hostname``
    """
    peers_json = json.dumps(list(juniper.all_bgp_peers(netconf)))
    get_next_redis(InventoryTask.config).set(
        f'juniper-peerings:hosts:{hostname}', peers_json)


@log_task_entry_and_exit
def refresh_juniper_interface_list(hostname, netconf, lab=False):
    """
    load all interfaces from the netconf doc and save them in the next
    redis db, replacing any entries previously cached for ``hostname``

    Keys written (all prefixed with 'lab:' if ``lab`` is true):
      - netconf-interfaces-hosts:<hostname>: json list built from
        juniper.interface_addresses
      - netconf-interfaces:<hostname>:<interface name>: one per interface
      - netconf-interface-bundles:<hostname>:<bundle>: names of the
        interfaces that list <bundle> in their 'bundle' field

    :param hostname: router hostname
    :param netconf: parsed netconf document for ``hostname``
    :param lab: True to save under 'lab:...'
    """
    logger.debug(
        'removing cached netconf-interfaces for %r', hostname)

    r = get_next_redis(InventoryTask.config)

    prefix = 'lab:' if lab else ''
    interfaces_keybase = f'{prefix}netconf-interfaces:{hostname}'
    bundles_keybase = f'{prefix}netconf-interface-bundles:{hostname}'
    interfaces_all_key = f'{prefix}netconf-interfaces-hosts:{hostname}'

    rp = r.pipeline()
    rp.delete(interfaces_all_key)
    # scan with bigger batches, to mitigate network latency effects
    for k in r.scan_iter(f'{interfaces_keybase}:*', count=1000):
        rp.delete(k)
    for k in r.scan_iter(f'{bundles_keybase}:*', count=1000):
        rp.delete(k)
    rp.execute()

    all_bundles = defaultdict(list)

    rp = r.pipeline()

    rp.set(
        interfaces_all_key,
        json.dumps(list(juniper.interface_addresses(netconf))))

    for ifc in juniper.list_interfaces(netconf):
        # a missing (or null) 'bundle' entry means "no bundles", rather
        # than a TypeError from iterating over None
        for bundle in ifc.get('bundle') or []:
            if bundle:
                all_bundles[bundle].append(ifc['name'])
        rp.set(
            f'{interfaces_keybase}:{ifc["name"]}',
            json.dumps(ifc))

    for k, v in all_bundles.items():
        rp.set(f'{bundles_keybase}:{k}', json.dumps(v))

    rp.execute()


@app.task(base=InventoryTask, bind=True, name='reload_lab_router_config')
@log_task_entry_and_exit
def reload_lab_router_config(self, hostname):
    """
    refresh the cached netconf, interface and snmp index data for a lab
    router

    The netconf and interface data are saved under 'lab:...' keys.  The
    subtasks are run in this worker, via ``Task.apply``.

    :param self: the bound task
    :param hostname: lab router hostname
    :raises InventoryTaskError: if no snmp community string can be
        extracted from the router's netconf config
    """
    self.log_info(f'loading netconf data for lab {hostname}')

    # load new netconf data, in this thread
    netconf_refresh_config.apply(args=[hostname, True])
    doc = load_netconf_data(hostname, lab=True)
    refresh_juniper_interface_list(hostname, doc, lab=True)

    # load snmp indexes
    community = juniper.snmp_community_string(doc)
    if not community:
        raise InventoryTaskError(
            f'error extracting community string for {hostname}')

    self.log_info(f'refreshing snmp interface indexes for {hostname}')
    # load snmp data, in this thread
    snmp_refresh_interfaces.apply(
        args=[hostname, community, juniper.logical_systems(doc)])

    self.log_info(f'updated configuration for lab {hostname}')


@app.task(base=InventoryTask, bind=True, name='reload_router_config')
@log_task_entry_and_exit
def reload_router_config(self, hostname):
    """
    refresh all cached data for a managed router

    Reloads the netconf config.  If its change timestamp differs from the
    one already in the next redis db, the bgp peers, interface list and
    snmp interface & peering data are refreshed, and then all cached
    classifier responses are cleared.  The subtasks are run in this
    worker, via ``Task.apply``.

    :param self: the bound task
    :param hostname: router hostname
    :raises InventoryTaskError: if the new netconf data has no change
        timestamp, or no snmp community string can be extracted from it
    """
    self.log_info(f'loading netconf data for {hostname}')

    # get the timestamp of the netconf data already in the next db, if any
    current_netconf_timestamp = None
    try:
        netconf_doc = load_netconf_data(hostname)
        current_netconf_timestamp \
            = juniper.netconf_changed_timestamp(netconf_doc)
        logger.debug(
            'current netconf timestamp: %r', current_netconf_timestamp)
    except InventoryTaskError:
        # NOTE: expected during a full refresh, since
        #       _erase_next_db flushes the next db before starting
        pass  # ok at this point if not found

    # load new netconf data, in this thread
    netconf_refresh_config.apply(args=[hostname, False])

    netconf_doc = load_netconf_data(hostname)

    # return if new timestamp is the same as the original timestamp
    new_netconf_timestamp = juniper.netconf_changed_timestamp(netconf_doc)
    if not new_netconf_timestamp:
        # explicit check rather than `assert` (which is stripped under -O):
        # a missing timestamp would otherwise compare equal to a missing
        # current timestamp and the refresh would be silently skipped
        raise InventoryTaskError(
            f'no timestamp available for new netconf data for {hostname}')
    if new_netconf_timestamp == current_netconf_timestamp:
        # log_info also writes the message to the debug log
        self.log_info(f'no timestamp change for {hostname} netconf data')
        return

    # refresh peering & interface data for this router
    # (cached classifier responses are cleared at the end)
    self.log_info(f'refreshing peers & clearing cache for {hostname}')
    refresh_juniper_bgp_peers(hostname, netconf_doc)
    refresh_juniper_interface_list(hostname, netconf_doc)

    # load snmp indexes
    community = juniper.snmp_community_string(netconf_doc)
    if not community:
        raise InventoryTaskError(
            f'error extracting community string for {hostname}')

    self.log_info(f'refreshing snmp interface indexes for {hostname}')
    logical_systems = juniper.logical_systems(netconf_doc)

    # load snmp data, in this thread
    snmp_refresh_interfaces.apply(
        args=[hostname, community, logical_systems])
    snmp_refresh_peerings.apply(
        args=[hostname, community, logical_systems])

    # NOTE: this clears *all* cached classifier responses,
    #       not only those for this router
    clear_cached_classifier_responses(None)
    self.log_info(f'updated configuration for {hostname}')


def _erase_next_db(config):
    """
    flush the next redis db, preserving its latch

    The latch, if present, is read before the flush and written back
    afterwards (with timestamp 0 if the saved latch had none).

    TODO: handle the no latch scenario nicely
    :param config: config structure as defined in config.py
    """
    r = get_next_redis(config)
    latch = get_latch(r)
    r.flushdb()
    if not latch:
        return
    set_latch(
        config,
        new_current=latch['current'],
        new_next=latch['next'],
        timestamp=latch.get('timestamp', 0))


@log_task_entry_and_exit
def launch_refresh_cache_all(config):
    """
    utility function intended to be called outside of the worker process

    Steps:
      - erase the next redis db and mark the latch as pending
      - clear the job log
      - run the first batch of refresh tasks and wait for them to finish
      - launch ``internal_refresh_phase_2``

    :param config: config structure as defined in config.py
    :return: the id of the launched internal_refresh_phase_2 task
    :raises RedisError, KombuError: after marking the latch as failed
    """

    try:
        _erase_next_db(config)
        update_latch_status(config, pending=True)

        monitor.clear_joblog(get_current_redis(config))

        # first batch of subtasks: refresh cached IMS location data
        subtasks = [
            update_neteng_managed_device_list.apply_async(),
            update_equipment_locations.apply_async(),
            update_lg_routers.apply_async(),
        ]
        # block until the first batch is done: phase 2 reads data
        # saved by it (e.g. the 'netdash' router list)
        for result in subtasks:
            result.get()

        # now launch the task whose only purpose is to
        # act as a convenient parent for all of the remaining tasks
        t = internal_refresh_phase_2.apply_async()
        return t.id

    except (RedisError, KombuError):
        update_latch_status(config, pending=False, failure=True)
        logger.exception('error launching refresh subtasks')
        raise


@app.task(base=InventoryTask, bind=True, name='internal_refresh_phase_2')
@log_task_entry_and_exit
def internal_refresh_phase_2(self):
    """
    launch the second batch of refresh subtasks, then the finalizer

    Launches:
      - the ims circuit hierarchy update
      - one reload_router_config task per router in the 'netdash' list
        (saved in the next redis db by update_neteng_managed_device_list)
      - one reload_lab_router_config task per 'lab-routers' config entry
    All of their ids are then passed to refresh_finalizer.

    :param self: the bound task
    :raises InventoryTaskError: if there's no 'netdash' list in the next db
    :raises KombuError: after marking the latch as failed
    """
    # second batch of subtasks:
    #   ims circuit information
    try:

        subtasks = [
            update_circuit_hierarchy_and_port_id_services.apply_async()
        ]

        r = get_next_redis(InventoryTask.config)
        routers = r.get('netdash')
        if not routers:
            # explicit check: an `assert` would be stripped under -O
            raise InventoryTaskError(
                'no netdash router list found in the next redis db')
        netdash_equipment = json.loads(routers.decode('utf-8'))
        for hostname in netdash_equipment:
            logger.debug(f'queueing router refresh jobs for {hostname}')
            subtasks.append(reload_router_config.apply_async(args=[hostname]))

        lab_routers = InventoryTask.config.get('lab-routers', [])
        for hostname in lab_routers:
            logger.debug('queueing router refresh jobs for lab %r', hostname)
            subtasks.append(
                reload_lab_router_config.apply_async(args=[hostname]))

        pending_task_ids = [x.id for x in subtasks]

        refresh_finalizer.apply_async(args=[json.dumps(pending_task_ids)])

    except KombuError:
        # TODO: possible race condition here
        # e.g. if one of these tasks takes a long time and another
        # update is started, we could end up with strange data
        # NOTE: pass the loaded config dict; a bare `config` name here
        #       is the imported inventory_provider.config module
        update_latch_status(InventoryTask.config, pending=False, failure=True)
        logger.exception('error launching refresh phase 2 subtasks')
        raise


def populate_poller_cache(interface_services, r):
    """
    rebuild the 'poller_cache:<equipment>' entries in redis

    All existing 'poller_cache:*' keys are deleted first.  Each new value
    maps port names on that equipment to a list of service summaries,
    each with keys 'id', 'name', 'type' and 'status'.

    :param interface_services: dict whose values are lists of service
        dicts; only the first service of each list is used to find its
        'equipment' and 'port' (presumably all services in a list share
        them - TODO confirm)
    :param r: redis connection to write to
    """
    host_services = defaultdict(dict)
    for services in interface_services.values():
        if not services:
            continue
        equipment = services[0]['equipment']
        port = services[0]['port']
        host_services[equipment][port] = [{
            'id': s['id'],
            'name': s['name'],
            'type': s['service_type'],
            'status': s['status']
        } for s in services]

    rp = r.pipeline()
    for k in r.scan_iter('poller_cache:*', count=1000):
        rp.delete(k)
    rp.execute()

    rp = r.pipeline()
    # (distinct name: don't shadow the `interface_services` parameter)
    for host, port_services in host_services.items():
        rp.set(f'poller_cache:{host}', json.dumps(port_services))
    rp.execute()


@app.task(
    base=InventoryTask,
    bind=True,
    name='update_circuit_hierarchy_and_port_id_services')
@log_task_entry_and_exit
def update_circuit_hierarchy_and_port_id_services(self, use_current=False):
    """
    load circuit hierarchy, port and service data from IMS and save it
    in redis

    Existing 'ims:circuit_hierarchy:*', 'ims:interface_services:*',
    'ims:access_services:*' and 'ims:gws_indirect:*' keys are deleted,
    then these keys are written:
      - ims:circuit_hierarchy:<circuit id>
      - ims:interface_services:<equipment name>:<interface name>
      - poller_cache:<equipment name> (cf. populate_poller_cache)
      - ims:services:<sorted service type key>:<service name>

    :param self: the bound task
    :param use_current: True to write to the current redis db,
        otherwise the next db is used
    """
    c = InventoryTask.config["ims"]
    # separate IMS clients: ds1, ds2 & ds3 are each used by a
    # different concurrent loader below
    ds1 = IMS(c['api'], c['username'], c['password'])
    ds2 = IMS(c['api'], c['username'], c['password'])
    ds3 = IMS(c['api'], c['username'], c['password'])

    locations = {k: v for k, v in ims_data.get_node_locations(ds1)}
    tls_names = list(ims_data.get_service_types(ds1))
    customer_contacts = \
        {k: v for k, v in ims_data.get_customer_service_emails(ds1)}
    # a set: it's only used for membership tests (ids are also used as
    # dict keys in `hierarchy`, so they're hashable)
    circuit_ids_to_monitor = \
        set(ims_data.get_monitored_circuit_ids(ds1))
    additional_circuit_customer_ids = \
        ims_data.get_circuit_related_customer_ids(ds1)

    hierarchy = None
    port_id_details = defaultdict(list)
    port_id_services = defaultdict(list)
    interface_services = defaultdict(list)
    services_by_type = {}

    # '<digits><letters>', e.g. '10G' -> ('10', 'G')
    # NOTE: [a-zA-Z], not [a-zA-z]: an 'A-z' range also matches the
    #       ascii punctuation between 'Z' and 'a'
    speed_pattern = re.compile(r'^(\d+)([a-zA-Z]+)$')

    def _convert_to_bits(value, unit):
        """
        value * binary multiplier for unit (raises KeyError if unknown)
        """
        unit = unit.lower()
        conversions = {
            'm': 1 << 20,
            'mb': 1 << 20,
            'g': 1 << 30,
            'gbe': 1 << 30,
        }
        return int(value) * conversions[unit]

    def _get_speed(circuit_id):
        """
        speed of an operational circuit, or 0

        Parsed from the circuit's 'speed' field; if that doesn't match
        speed_pattern, 'service' and 'ethernet' circuits get the sum of
        the speeds of their carrier circuits.
        """
        circuit = hierarchy[circuit_id]
        if circuit['status'] != 'operational':
            return 0
        m = speed_pattern.match(circuit['speed'])
        if m:
            try:
                return _convert_to_bits(m[1], m[2])
            except KeyError as e:
                logger.debug(f'Could not find key: {e} '
                             f'for circuit: {circuit_id}')
                return 0
        else:
            if circuit['circuit-type'] == 'service' \
                    or circuit['product'].lower() == 'ethernet':
                return sum(
                    (_get_speed(x) for x in circuit['carrier-circuits'])
                )
            else:
                return 0

    def _get_circuit_contacts(c):
        """contacts of the circuit's customer and its related customers"""
        customer_ids = {c['customerid']}
        customer_ids.update(additional_circuit_customer_ids.get(c['id'], []))
        return set().union(
            *[customer_contacts.get(cid, []) for cid in customer_ids])

    def _populate_hierarchy():
        nonlocal hierarchy
        hierarchy = {}
        for d in ims_data.get_circuit_hierarchy(ds1):
            hierarchy[d['id']] = d
            d['contacts'] = sorted(list(_get_circuit_contacts(d)))
        logger.debug("hierarchy complete")

    def _populate_port_id_details():
        for x in ims_data.get_port_details(ds2):
            port_id_details[x['port_id']].append(x)
        logger.debug("Port details complete")

    def _populate_circuit_info():
        for x in ims_data.get_port_id_services(ds3):
            port_id_services[x['port_a_id']].append(x)
        logger.debug("port circuits complete")

    # run the three independent IMS loaders concurrently; Future.result()
    # re-raises any loader exception here (with plain threads it was only
    # printed, leaving e.g. `hierarchy` as None and failing later)
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        loaders = [
            executor.submit(_populate_hierarchy),
            executor.submit(_populate_port_id_details),
            executor.submit(_populate_circuit_info),
        ]
        for loader in loaders:
            loader.result()

    def _get_fibre_routes(c_id):
        """yield ids of the 'fibre_route' circuits carrying c_id"""
        _circ = hierarchy.get(c_id, None)
        if _circ is None:
            return
        if _circ['speed'].lower() == 'fibre_route':
            yield _circ['id']
        else:
            for cc in _circ['carrier-circuits']:
                yield from _get_fibre_routes(cc)

    # whilst this is a service type the top level for reporting
    # are the BGP services on top of it
    # NOTE(review): tls_names isn't otherwise used here, and remove()
    #               raises ValueError if this type isn't returned by IMS
    tls_names.remove('IP PEERING - NON R&E (PUBLIC)')

    def _get_related_services(circuit_id: str) -> List[dict]:
        """the circuit (if a service) and all services above it"""
        rs = {}
        circuit = hierarchy.get(circuit_id, None)
        if circuit:

            if circuit['circuit-type'] == 'service':
                rs[circuit['id']] = {
                    'id': circuit['id'],
                    'name': circuit['name'],
                    'circuit_type': circuit['circuit-type'],
                    'service_type': circuit['product'],
                    'project': circuit['project'],
                    'contacts': circuit['contacts']
                }
                if circuit['id'] in circuit_ids_to_monitor:
                    rs[circuit['id']]['status'] = circuit['status']
                else:
                    rs[circuit['id']]['status'] = 'non-monitored'

            if circuit['sub-circuits']:
                for sub in circuit['sub-circuits']:
                    temp_parents = \
                        _get_related_services(sub)
                    rs.update({t['id']: t for t in temp_parents})
        return list(rs.values())

    def _format_service(s):
        """
        add location/equipment/port fields for both ends of service s
        (in place), and remove its port_a_id/port_b_id
        """

        if s['circuit_type'] == 'service' \
                and s['id'] not in circuit_ids_to_monitor:
            s['status'] = 'non-monitored'
        pd_a = port_id_details[s['port_a_id']][0]
        location_a = locations.get(pd_a['equipment_name'], None)
        if location_a:
            loc_a = location_a['pop']
        else:
            loc_a = locations['UNKNOWN_LOC']['pop']
            logger.warning(
                f'Unable to find location for {pd_a["equipment_name"]} - '
                f'Service ID {s["id"]}')
        s['pop_name'] = loc_a['name']
        s['pop_abbreviation'] = loc_a['abbreviation']
        s['equipment'] = pd_a['equipment_name']
        s['card_id'] = ''  # this is redundant I believe
        s['port'] = pd_a['interface_name']
        s['logical_unit'] = ''  # this is redundant I believe
        if 'port_b_id' in s:
            pd_b = port_id_details[s['port_b_id']][0]
            location_b = locations.get(pd_b['equipment_name'], None)
            if location_b:
                loc_b = location_b['pop']
            else:
                loc_b = locations['UNKNOWN_LOC']['pop']
                logger.warning(
                    f'Unable to find location for {pd_b["equipment_name"]} - '
                    f'Service ID {s["id"]}')

            s['other_end_pop_name'] = loc_b['name']
            s['other_end_pop_abbreviation'] = loc_b['abbreviation']
            s['other_end_equipment'] = pd_b['equipment_name']
            s['other_end_port'] = pd_b['interface_name']
        else:
            s['other_end_pop_name'] = ''
            s['other_end_pop_abbreviation'] = ''
            s['other_end_equipment'] = ''
            s['other_end_port'] = ''

        s.pop('port_a_id', None)
        s.pop('port_b_id', None)

    for value in port_id_details.values():
        for details in value:
            k = f"{details['equipment_name']}:" \
                f"{details['interface_name']}"
            circuits = port_id_services.get(details['port_id'], [])

            for circ in circuits:
                contacts = _get_circuit_contacts(circ)
                circ['fibre-routes'] = []
                for x in set(_get_fibre_routes(circ['id'])):
                    fibre_route = {
                        'id': hierarchy[x]['id'],
                        'name': hierarchy[x]['name'],
                        'status': hierarchy[x]['status']
                    }
                    circ['fibre-routes'].append(fibre_route)

                circ['related-services'] = \
                    _get_related_services(circ['id'])

                for tlc in circ['related-services']:
                    contacts.update(tlc.pop('contacts'))
                circ['contacts'] = sorted(list(contacts))

                circ['calculated-speed'] = _get_speed(circ['id'])
                _format_service(circ)

                type_services = services_by_type.setdefault(
                    ims_sorted_service_type_key(circ['service_type']), dict())
                type_services[circ['id']] = circ

            interface_services[k].extend(circuits)

    if use_current:
        r = get_current_redis(InventoryTask.config)
    else:
        r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for k in r.scan_iter('ims:circuit_hierarchy:*', count=1000):
        rp.delete(k)
    rp.execute()
    rp = r.pipeline()
    for k in r.scan_iter('ims:interface_services:*', count=1000):
        rp.delete(k)
    rp.execute()
    rp = r.pipeline()
    for k in r.scan_iter('ims:access_services:*', count=1000):
        rp.delete(k)
    for k in r.scan_iter('ims:gws_indirect:*', count=1000):
        rp.delete(k)
    rp.execute()
    rp = r.pipeline()
    for circ in hierarchy.values():
        rp.set(f'ims:circuit_hierarchy:{circ["id"]}', json.dumps([circ]))
    rp.execute()
    rp = r.pipeline()
    for k, v in interface_services.items():
        rp.set(
            f'ims:interface_services:{k}',
            json.dumps(v))
    rp.execute()

    populate_poller_cache(interface_services, r)

    rp = r.pipeline()
    for service_type, services in services_by_type.items():
        for v in services.values():
            rp.set(
                f'ims:services:{service_type}:{v["name"]}',
                json.dumps({
                    'id': v['id'],
                    'name': v['name'],
                    'project': v['project'],
                    'here': {
                        'pop': {
                            'name': v['pop_name'],
                            'abbreviation': v['pop_abbreviation']
                        },
                        'equipment': v['equipment'],
                        'port': v['port'],
                    },
                    'there': {
                        'pop': {
                            'name': v['other_end_pop_name'],
                            'abbreviation': v['other_end_pop_abbreviation']
                        },
                        'equipment': v['other_end_equipment'],
                        'port': v['other_end_port'],
                    },
                    'speed_value': v['calculated-speed'],
                    'speed_unit': 'n/a',
                    'status': v['status'],
                    'type': v['service_type']
                }))

    rp.execute()


@app.task(base=InventoryTask, bind=True, name='update_equipment_locations')
@log_task_entry_and_exit
def update_equipment_locations(self, use_current=False):
    """
    replace the cached IMS node locations in redis

    All 'ims:location:*' keys are deleted, then one
    'ims:location:<hostname>' key is written per location returned by
    IMS; each value is a single-element json list.

    :param self: the bound task
    :param use_current: True to write to the current redis db,
        otherwise the next db is used
    """
    get_redis = get_current_redis if use_current else get_next_redis
    r = get_redis(InventoryTask.config)

    pipe = r.pipeline()
    for key in r.scan_iter('ims:location:*', count=1000):
        pipe.delete(key)
    pipe.execute()

    ims_config = InventoryTask.config["ims"]
    ds = IMS(
        ims_config['api'], ims_config['username'], ims_config['password'])

    seen = set()
    pipe = r.pipeline()
    for hostname, location in ims_data.get_node_locations(ds):
        # put into a list to match non-IMS version
        pipe.set(f'ims:location:{hostname}', json.dumps([location]))
        if hostname in seen:
            logger.debug(f'Multiple entries for {hostname}')
        seen.add(hostname)
    pipe.execute()


@app.task(base=InventoryTask, bind=True, name='update_lg_routers')
@log_task_entry_and_exit
def update_lg_routers(self, use_current=False):
    """
    replace the cached looking glass routers in redis

    All 'ims:lg:*' keys are deleted, then one
    'ims:lg:<equipment name>' key is written per router returned by
    ims_data.lookup_lg_routers.

    :param self: the bound task
    :param use_current: True to write to the current redis db,
        otherwise the next db is used
    """
    if use_current:
        r = get_current_redis(InventoryTask.config)
    else:
        r = get_next_redis(InventoryTask.config)

    # delete once, for either db; batch the scan & pipeline the deletes
    # (like the other update_* tasks) to mitigate network latency effects
    rp = r.pipeline()
    for k in r.scan_iter('ims:lg:*', count=1000):
        rp.delete(k)
    rp.execute()

    c = InventoryTask.config["ims"]
    ds = IMS(c['api'], c['username'], c['password'])

    rp = r.pipeline()
    for router in ims_data.lookup_lg_routers(ds):
        rp.set(f'ims:lg:{router["equipment name"]}', json.dumps(router))
    rp.execute()


def _wait_for_tasks(task_ids, update_callback=lambda s: None):

    all_successful = True

    start_time = time.time()
    while task_ids and time.time() - start_time < FINALIZER_TIMEOUT_S:
        update_callback('waiting for tasks to complete: %r' % task_ids)
        time.sleep(FINALIZER_POLLING_FREQUENCY_S)

        def _task_and_children_result(id):
            tasks = list(check_task_status(id))
            return {
                'error': any([t['ready'] and not t['success'] for t in tasks]),
                'ready': all([t['ready'] for t in tasks])
            }

        results = dict([
            (id, _task_and_children_result(id))
            for id in task_ids])

        if any([t['error'] for t in results.values()]):
            all_successful = False

        task_ids = [
            id for id, status in results.items()
            if not status['ready']]

    if task_ids:
        raise InventoryTaskError(
            'timeout waiting for pending tasks to complete')
    if not all_successful:
        raise InventoryTaskError(
            'some tasks finished with an error')

    update_callback('pending tasks completed in {} seconds'.format(
            time.time() - start_time))


@app.task(base=InventoryTask, bind=True, name='refresh_finalizer')
@log_task_entry_and_exit
def refresh_finalizer(self, pending_task_ids_json):
    """
    wait for the refresh subtasks, build the derived dbs, then latch
    the current/next redis dbs

    :param self: the bound task
    :param pending_task_ids_json: json-encoded list of task id strings
    :raises: jsonschema.ValidationError, json.JSONDecodeError,
        InventoryTaskError, RedisError or KombuError - re-raised after
        update_latch_status(..., failure=True) is called
    """

    # TODO: if more types of errors appear, use a finally block

    input_schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "array",
        "items": {"type": "string"}
    }

    try:
        task_ids = json.loads(pending_task_ids_json)
        logger.debug('task_ids: %r', task_ids)
        jsonschema.validate(task_ids, input_schema)

        _wait_for_tasks(task_ids, update_callback=self.log_info)
        _build_subnet_db(update_callback=self.log_info)
        _build_snmp_peering_db(update_callback=self.log_info)
        _build_juniper_peering_db(update_callback=self.log_info)
        populate_poller_interfaces_cache()

    except (jsonschema.ValidationError,
            json.JSONDecodeError,
            InventoryTaskError,
            RedisError,
            KombuError):
        update_latch_status(InventoryTask.config, failure=True)
        # bare raise: re-raise with the original traceback unchanged
        raise

    latch_db(InventoryTask.config)
    self.log_info('latched current/next dbs')


def _build_subnet_db(update_callback=lambda s: None):
    """
    Index all known interface addresses in the 'next' redis db.

    Every ``netconf-interfaces-hosts:<hostname>`` entry (a json list of
    interface dicts) is read; each interface is tagged with ``router`` and
    grouped by its ``'interface address'`` value. The interfaces returned
    by ``_unmanaged_interfaces()`` are added to the same grouping. Each
    group is then saved as ``subnets:<interface address>``.

    :param update_callback: called with progress messages
    """
    prefix = 'netconf-interfaces-hosts:'
    rdb = get_next_redis(InventoryTask.config)

    update_callback('loading all network addresses')
    subnets = {}

    def _add(interface):
        subnets.setdefault(interface['interface address'], []).append(
            interface)

    # scan with bigger batches, to mitigate network latency effects
    for raw_key in rdb.scan_iter(f'{prefix}*', count=1000):
        key = raw_key.decode('utf-8')
        router = key[len(prefix):]
        for interface in json.loads(rdb.get(key).decode('utf-8')):
            interface['router'] = router
            _add(interface)

    for interface in _unmanaged_interfaces():
        _add(interface)

    update_callback(f'saving {len(subnets)} subnets')

    pipe = rdb.pipeline()
    for address, interfaces in subnets.items():
        pipe.set(f'subnets:{address}', json.dumps(interfaces))
    pipe.execute()


def _build_juniper_peering_db(update_callback=lambda s: None):
    """
    Build pivoted juniper bgp peering lookups in the 'next' redis db.

    Every ``juniper-peerings:hosts:<hostname>`` entry (a json list of
    peering dicts) is read and each peering is tagged with ``hostname``.
    The following keys are then written:

      - ``juniper-peerings:remote:<address>``: peerings per remote address
      - ``juniper-peerings:ix-groups:<keyword>``: sorted list of the remote
        addresses of IX peerings, grouped by the first space-separated word
        of the peering description
      - ``juniper-peerings:peer-asn:<remote-asn>``: peerings per remote asn
      - ``juniper-peerings:logical-system:<name>``: per logical system
      - ``juniper-peerings:group:<name>``: per bgp group
      - ``juniper-peerings:routing-instance:<name>``: per routing instance

    :param update_callback: called with progress messages
    """

    def _is_ix(peering_info):
        """True for 'IAS' instance peerings in a 'GEANT-IX*' group."""
        if peering_info.get('instance', '') != 'IAS':
            return False
        # tolerate an explicit None group (treated as "no group" below too)
        if not (peering_info.get('group') or '').startswith('GEANT-IX'):
            return False

        expected_keys = ('description', 'local-asn', 'remote-asn')
        if any(peering_info.get(x, None) is None for x in expected_keys):
            logger.error('internal data error, looks like ix peering but '
                         f'some expected keys are missing: {peering_info}')
            return False
        return True

    r = get_next_redis(InventoryTask.config)

    update_callback('loading all juniper network peerings')
    peerings_per_address = {}
    ix_peerings = []
    peerings_per_asn = {}
    peerings_per_logical_system = {}
    peerings_per_group = {}
    peerings_per_routing_instance = {}

    # scan with bigger batches, to mitigate network latency effects
    key_prefix = 'juniper-peerings:hosts:'
    for k in r.scan_iter(f'{key_prefix}*', count=1000):
        key_name = k.decode('utf-8')
        hostname = key_name[len(key_prefix):]
        host_peerings = json.loads(r.get(key_name).decode('utf-8'))
        for p in host_peerings:
            p['hostname'] = hostname
            peerings_per_address.setdefault(p['address'], []).append(p)
            if _is_ix(p):
                ix_peerings.append(p)
            asn = p.get('remote-asn', None)
            if asn:
                peerings_per_asn.setdefault(asn, []).append(p)
            logical_system = p.get('logical-system', None)
            if logical_system:
                peerings_per_logical_system.setdefault(
                    logical_system, []).append(p)
            group = p.get('group', None)
            if group:
                peerings_per_group.setdefault(group, []).append(p)
            routing_instance = p.get('instance', None)
            if routing_instance:
                peerings_per_routing_instance.setdefault(
                    routing_instance, []).append(p)

    # group ix peering addresses by the first word of their description
    ix_groups = {}
    for p in ix_peerings:
        keyword = p['description'].split(' ')[0]  # regex needed? (tabs?)
        ix_groups.setdefault(keyword, set()).add(p['address'])

    rp = r.pipeline()

    # create peering entries, keyed by remote addresses
    update_callback(f'saving {len(peerings_per_address)} remote peers')
    for k, v in peerings_per_address.items():
        rp.set(f'juniper-peerings:remote:{k}', json.dumps(v))

    # create pivoted ix group name lists
    update_callback(f'saving {len(ix_groups)} remote ix peering groups')
    for k, v in ix_groups.items():
        # sorted: a set has no stable order, keep the stored list deterministic
        rp.set(f'juniper-peerings:ix-groups:{k}', json.dumps(sorted(v)))

    # create pivoted asn peering lists
    update_callback(f'saving {len(peerings_per_asn)} asn peering lists')
    for k, v in peerings_per_asn.items():
        rp.set(f'juniper-peerings:peer-asn:{k}', json.dumps(v))

    # create pivoted logical-systems peering lists
    update_callback(
        f'saving {len(peerings_per_logical_system)}'
        ' logical-system peering lists')
    for k, v in peerings_per_logical_system.items():
        rp.set(f'juniper-peerings:logical-system:{k}', json.dumps(v))

    # create pivoted group peering lists
    update_callback(
        f'saving {len(peerings_per_group)} group peering lists')
    for k, v in peerings_per_group.items():
        rp.set(f'juniper-peerings:group:{k}', json.dumps(v))

    # create pivoted routing instance peering lists
    update_callback(
        f'saving {len(peerings_per_routing_instance)}'
        ' routing instance peering lists')
    for k, v in peerings_per_routing_instance.items():
        rp.set(f'juniper-peerings:routing-instance:{k}', json.dumps(v))

    rp.execute()


def _build_snmp_peering_db(update_callback=lambda s: None):
    """
    Re-index the per-router snmp peerings in the 'next' db by remote address.

    Every ``snmp-peerings:hosts:<hostname>`` entry (a json list of peering
    dicts) is read, each peering is tagged with ``hostname`` and the
    peerings are regrouped into ``snmp-peerings:remote:<remote>`` entries.

    :param update_callback: called with progress messages
    """
    prefix = 'snmp-peerings:hosts:'
    rdb = get_next_redis(InventoryTask.config)

    update_callback('loading all snmp network peerings')
    by_remote = {}

    # scan with bigger batches, to mitigate network latency effects
    for raw_key in rdb.scan_iter(f'{prefix}*', count=1000):
        key = raw_key.decode('utf-8')
        router = key[len(prefix):]
        for peering in json.loads(rdb.get(key).decode('utf-8')):
            peering['hostname'] = router
            by_remote.setdefault(peering['remote'], []).append(peering)

    update_callback(f'saving {len(by_remote)} remote peers')

    pipe = rdb.pipeline()
    for remote, peerings in by_remote.items():
        pipe.set(f'snmp-peerings:remote:{remote}', json.dumps(peerings))
    pipe.execute()


def check_task_status(task_id, parent=None, forget=False):
    """
    Recursively report the status of a celery task and its children.

    This is a generator. The statuses of all descendant tasks are yielded
    first (depth-first), followed by the status of ``task_id`` itself.

    Each yielded dict has the keys ``id``, ``status``, ``exception``,
    ``ready``, ``success``, ``parent`` and ``result``. An exception result
    is reported as ``{'error type': <class name>, 'message': <str(exc)>}``
    so the dict stays json-serializable.

    :param task_id: celery task id
    :param parent: id of the parent task, or None for the root task
    :param forget: if True, every task (root and descendants) whose status
        is ready is ``forget()``-ed after its status has been read
    """
    r = AsyncResult(task_id, app=app)
    assert r.id == task_id  # sanity

    result = {
        'id': task_id,
        'status': r.status,
        'exception': r.status in states.EXCEPTION_STATES,
        'ready': r.status in states.READY_STATES,
        'success': r.status == states.SUCCESS,
        'parent': parent
    }

    # TODO: only discovered this case by testing, is this the only one?
    #       ... otherwise need to pre-test json serialization
    if isinstance(r.result, Exception):
        result['result'] = {
            'error type': type(r.result).__name__,
            'message': str(r.result)
        }
    else:
        result['result'] = r.result

    def child_taskids(children):
        # reverse-engineered, can't find documentation on this
        for child in children:
            if not child:
                continue
            if isinstance(child, list):
                logger.debug(f'list: {child}')
                yield from child_taskids(child)
                continue
            if isinstance(child, str):
                yield child
                continue
            assert isinstance(child, AsyncResult)
            yield child.id

    for child_id in child_taskids(getattr(r, 'children', []) or []):
        # propagate ``forget``: previously child results were never forgotten
        yield from check_task_status(
            child_id, parent=task_id, forget=forget)

    if forget and result['ready']:
        r.forget()

    yield result

# =================================== chorded - currently only here for testing


# new
@app.task(base=InventoryTask, bind=True, name='update_entry_point')
@log_task_entry_and_exit
def update_entry_point(self):
    """
    Start a full, chord-based refresh of the 'next' redis db.

    Steps:
      1. Erase the 'next' db, preserving its latch.
      2. Mark the update as pending.
      3. Save the managed router list.
      4. Launch a chord whose header is the IMS task plus one nested chord
         for the managed routers and one for the lab routers, and whose
         body is ``final_task``.

    :return: the result object returned by launching the chord
    """
    # erase BEFORE persisting the managed device list: the retrieval below
    # writes the 'netdash' key into the 'next' db, and when a latch exists
    # _erase_next_db_chorded flushes that db - doing it afterwards threw the
    # freshly saved list away (and with it the "use previous list" fallback)
    _erase_next_db_chorded(InventoryTask.config)
    update_latch_status(InventoryTask.config, pending=True)

    routers = retrieve_and_persist_neteng_managed_device_list(
        info_callback=self.log_info,
        warning_callback=self.log_warning
    )
    lab_routers = InventoryTask.config.get('lab-routers', [])

    tasks = chord(
        (
            ims_task.s().on_error(task_error_handler.s()),
            chord(
                (reload_router_config_chorded.s(r) for r in routers),
                empty_task.si('router tasks complete')
            ),
            chord(
                (reload_lab_router_config_chorded.s(r) for r in lab_routers),
                empty_task.si('lab router tasks complete')
            )
        ),
        final_task.si().on_error(task_error_handler.s())
    )()
    return tasks


# new
@app.task
def task_error_handler(request, exc, traceback):
    """
    Celery error callback: mark the update as failed and log the error.

    :param request: request of the failed task (its ``id`` is logged)
    :param exc: the exception raised by the failed task
    :param traceback: traceback of the failure (unused)
    """
    update_latch_status(InventoryTask.config, pending=False, failure=True)
    logger.warning(f'Task {request.id!r} raised error: {exc!r}')


# new
@app.task(base=InventoryTask, bind=True, name='empty_task')
def empty_task(self, message):
    """
    No-op task used as the body of the nested router chords.

    :param message: text to log (at warning level)
    """
    text = f'message from empty task: {message}'
    logger.warning(text)


def retrieve_and_persist_neteng_managed_device_list(
        info_callback=lambda s: None,
        warning_callback=lambda s: None):
    """
    Get the managed router list and save it in the 'next' db under 'netdash'.

    The list is loaded from netdash. If that fails or returns nothing, the
    list saved under 'netdash' in the 'current' db is used instead.

    :param info_callback: called with informational messages
    :param warning_callback: called with warning messages
    :return: the list of managed equipment
    :raises InventoryTaskError: if neither source provides any equipment.
        This error, and any redis error, is raised after the latch is
        updated with ``pending=False, failure=True``.
    """
    netdash_equipment = None
    try:
        info_callback('querying netdash for managed routers')
        netdash_equipment = list(juniper.load_routers_from_netdash(
            InventoryTask.config['managed-routers']))
    except Exception as e:
        # best effort: fall back to the previously saved list below
        warning_callback(f'Error retrieving device list: {e}')

    if netdash_equipment:
        info_callback(f'found {len(netdash_equipment)} routers')
    else:
        warning_callback('No devices retrieved, using previous list')
        try:
            current_r = get_current_redis(InventoryTask.config)
            cached = current_r.get('netdash')
            # a missing key previously surfaced as an opaque AttributeError
            netdash_equipment = \
                json.loads(cached.decode('utf-8')) if cached else None
            if not netdash_equipment:
                raise InventoryTaskError(
                    'No equipment retrieved from previous list')
        except Exception as e:
            warning_callback(str(e))
            # config is required; omitting it raised a TypeError here
            update_latch_status(
                InventoryTask.config, pending=False, failure=True)
            raise

    try:
        next_r = get_next_redis(InventoryTask.config)
        next_r.set('netdash', json.dumps(netdash_equipment))
        info_callback(f'saved {len(netdash_equipment)} managed routers')
    except Exception as e:
        warning_callback(str(e))
        update_latch_status(InventoryTask.config, pending=False, failure=True)
        raise
    return netdash_equipment


# updated with transaction
def _erase_next_db_chorded(config):
    """
    flush next db, but first save latch and then restore afterwards

    If the 'next' db has no latch, nothing is done at all.

    TODO: handle the no latch scenario nicely
    :param config: inventory provider config, passed to the redis helpers
    :return: None
    """
    rdb = get_next_redis(config)
    latch = get_latch(rdb)
    if not latch:
        return

    timestamp = latch.get('timestamp', 0)

    # flush and restore the latch in one MULTI/EXEC transaction, so the
    # latch is never missing from the db being flushed
    txn = rdb.pipeline()
    txn.multi()
    txn.flushdb()
    set_single_latch(
        txn, latch['this'], latch['current'], latch['next'], timestamp)
    txn.execute()

    # ensure latch is consistent in all dbs
    set_latch(
        config,
        new_current=latch['current'],
        new_next=latch['next'],
        timestamp=timestamp)


# updated
@app.task(base=InventoryTask, bind=True, name='reload_lab_router_config')
@log_task_entry_and_exit
def reload_lab_router_config_chorded(self, hostname):
    """
    Refresh netconf, interface list and snmp index data for one lab router.

    Errors are not raised. They are logged with a traceback and the latch
    is updated with ``pending=True, failure=True``, so the task itself
    does not fail.

    :param hostname: lab router hostname
    """
    try:
        self.log_info(f'loading netconf data for lab {hostname} RL')

        # load new netconf data, in this thread
        netconf_str = retrieve_and_persist_netconf_config(
            hostname, lab=True, update_callback=self.log_warning)
        netconf_doc = etree.fromstring(netconf_str)

        refresh_juniper_interface_list(hostname, netconf_doc, lab=True)

        # load snmp indexes
        community = juniper.snmp_community_string(netconf_doc)
        if not community:
            raise InventoryTaskError(
                f'error extracting community string for {hostname}')
        else:
            self.log_info(f'refreshing snmp interface indexes for {hostname}')
            logical_systems = juniper.logical_systems(netconf_doc)

            # load snmp data, in this thread
            snmp_refresh_interfaces(hostname, community, logical_systems)

        self.log_info(f'updated configuration for lab {hostname}')
    except Exception:
        # the error is swallowed, so keep the traceback (logger.error(e)
        # only recorded str(e))
        logger.exception(f'error reloading config for lab {hostname}')
        update_latch_status(InventoryTask.config, pending=True, failure=True)


# updated
@app.task(base=InventoryTask, bind=True, name='reload_router_config')
@log_task_entry_and_exit
def reload_router_config_chorded(self, hostname):
    """
    Refresh netconf, bgp peers, interface list and snmp data for a router.

    Errors are not raised. They are logged with a traceback and the latch
    is updated with ``pending=True, failure=True``, so the task itself
    does not fail.

    :param hostname: managed router hostname
    """
    try:
        self.log_info(f'loading netconf data for {hostname} RL')
        netconf_str = retrieve_and_persist_netconf_config(
            hostname, update_callback=self.log_warning)

        netconf_doc = etree.fromstring(netconf_str)

        # clear cached classifier responses for this router, and
        # refresh peering data
        logger.info(f'refreshing peers & clearing cache for {hostname}')
        refresh_juniper_bgp_peers(hostname, netconf_doc)
        refresh_juniper_interface_list(hostname, netconf_doc)

        # load snmp indexes
        community = juniper.snmp_community_string(netconf_doc)
        if not community:
            raise InventoryTaskError(
                f'error extracting community string for {hostname}')
        else:
            self.log_info(f'refreshing snmp interface indexes for {hostname}')
            logical_systems = juniper.logical_systems(netconf_doc)

            # load snmp data, in this thread
            snmp_refresh_interfaces_chorded(
                hostname, community, logical_systems, self.log_info)
            snmp_refresh_peerings_chorded(hostname, community, logical_systems)

        logger.info(f'updated configuration for {hostname}')
    except Exception:
        # the error is swallowed, so keep the traceback (logger.error(e)
        # only recorded str(e))
        logger.exception(f'error reloading config for {hostname}')
        update_latch_status(InventoryTask.config, pending=True, failure=True)


# new
def retrieve_and_persist_netconf_config(
        hostname, lab=False, update_callback=lambda s: None):
    """
    Load a router's netconf config and store it in the 'next' redis db.

    On a connection/netconf/InventoryTaskError failure, the copy stored
    under the same key in the 'current' db is used instead.

    :param hostname: router hostname
    :param lab: if True the redis key is ``lab:netconf:<hostname>``,
        otherwise ``netconf:<hostname>``
    :param update_callback: called with warning/progress messages
    :return: the config xml. This is a str when freshly loaded, otherwise
        the raw value read from redis (presumably bytes - verify).
    :raises InventoryTaskError: if loading fails and nothing is cached
    """
    redis_key = f'lab:netconf:{hostname}' if lab else f'netconf:{hostname}'

    try:
        doc = juniper.load_config(hostname, InventoryTask.config["ssh"])
        netconf_str = etree.tostring(doc, encoding='unicode')
    except (ConnectionError, juniper.NetconfHandlingError, InventoryTaskError):
        msg = f'error loading netconf data from {hostname}'
        logger.exception(msg)
        update_callback(msg)

        netconf_str = get_current_redis(InventoryTask.config).get(redis_key)
        if not netconf_str:
            update_callback(f'no cached netconf for {redis_key}')
            raise InventoryTaskError(
                f'netconf error with {hostname}'
                f' and no cached netconf data found')
        logger.info(f'Returning cached netconf data for {hostname}')
        update_callback(f'Returning cached netconf data for {hostname}')

    get_next_redis(InventoryTask.config).set(redis_key, netconf_str)
    logger.info(f'netconf info loaded from {hostname}')
    return netconf_str


# updated as is no longer a task
@log_task_entry_and_exit
def snmp_refresh_interfaces_chorded(
        hostname, community, logical_systems, update_callback=lambda s: None):
    """
    Load a router's snmp interface indexes and store them in the 'next' db.

    On ConnectionError, the ``snmp-interfaces:<hostname>`` value from the
    'current' db is used instead.

    Written keys:
      - ``snmp-interfaces:<hostname>``: the whole interface list
      - ``snmp-interfaces-single:<hostname>:<name>``: one entry per
        interface, with ``hostname`` added

    NOTE(review): log_task_entry_and_exit logs the positional args at debug
    level, and these include the snmp community string.

    :raises InventoryTaskError: if snmp fails and nothing is cached
    """
    try:
        interfaces = list(
            snmp.get_router_snmp_indexes(hostname, community, logical_systems))
    except ConnectionError:
        msg = f'error loading snmp interface data from {hostname}'
        logger.exception(msg)
        update_callback(msg)
        cached = get_current_redis(InventoryTask.config).get(
            f'snmp-interfaces:{hostname}')
        if not cached:
            raise InventoryTaskError(
                f'snmp error with {hostname}'
                f' and no cached snmp interface data found')
        # unnecessary json encode/decode here ... could be optimized
        interfaces = json.loads(cached.decode('utf-8'))
        update_callback(f'using cached snmp interface data for {hostname}')

    pipe = get_next_redis(InventoryTask.config).pipeline()
    # the full list is serialized here, before 'hostname' is added below
    pipe.set(f'snmp-interfaces:{hostname}', json.dumps(interfaces))

    # optimization for DBOARD3-372
    # interfaces is a list of dicts like: {'name': str, 'index': int}
    for interface in interfaces:
        interface['hostname'] = hostname
        pipe.set(
            f'snmp-interfaces-single:{hostname}:{interface["name"]}',
            json.dumps(interface))

    pipe.execute()

    update_callback(f'snmp interface info loaded from {hostname}')


# updated as is no longer a task
@log_task_entry_and_exit
def snmp_refresh_peerings_chorded(
        hostname, community, logical_systems, update_callback=lambda s: None):
    """
    Load a router's snmp bgp peer state and store it in the 'next' db.

    The data is written to ``snmp-peerings:hosts:<hostname>``. On
    ConnectionError, the value under the same key in the 'current' db is
    used instead.

    NOTE(review): log_task_entry_and_exit logs the positional args at debug
    level, and these include the snmp community string.

    :raises InventoryTaskError: if snmp fails and nothing is cached
    """
    try:
        peerings = list(
            snmp.get_peer_state_info(hostname, community, logical_systems))
    except ConnectionError:
        msg = f'error loading snmp peering data from {hostname}'
        logger.exception(msg)
        update_callback(msg)
        r = get_current_redis(InventoryTask.config)
        peerings = r.get(f'snmp-peerings:hosts:{hostname}')
        # 'not' (like the interfaces variant) also rejects an empty value,
        # which would otherwise fail json decoding
        if not peerings:
            raise InventoryTaskError(
                f'snmp error with {hostname}'
                f' and no cached peering data found')
        # unnecessary json encode/decode here ... could be optimized
        peerings = json.loads(peerings.decode('utf-8'))
        update_callback(f'using cached snmp peering data for {hostname}')

    r = get_next_redis(InventoryTask.config)
    r.set(f'snmp-peerings:hosts:{hostname}', json.dumps(peerings))

    update_callback(f'snmp peering info loaded from {hostname}')


# new
@app.task(base=InventoryTask, bind=True, name='ims_task')
@log_task_entry_and_exit
def ims_task(self, use_current=False):
    """
    Extract, transform and persist IMS data.

    Errors are not raised. They are logged with a traceback and the latch
    is updated with ``pending=True, failure=True``.

    :param use_current: passed to ``persist_ims_data`` (write to the
        'current' db instead of 'next')
    """
    try:
        extracted_data = extract_ims_data()
        transformed_data = transform_ims_data(extracted_data)
        transformed_data['locations'] = extracted_data['locations']
        transformed_data['lg_routers'] = extracted_data['lg_routers']
        persist_ims_data(transformed_data, use_current)
    except Exception:
        # the error is swallowed, so keep the traceback (logger.error(e)
        # only recorded str(e))
        logger.exception('error updating IMS data')
        update_latch_status(InventoryTask.config, pending=True, failure=True)


# new
def extract_ims_data():
    """
    Fetch the raw IMS data needed for a refresh, using concurrent threads.

    Five IMS sessions are opened, and each concurrent query uses its own
    session (presumably because a session should not be shared between
    threads - confirm). Queries run in two rounds. If any query in a round
    fails, InventoryTaskError is raised; its message is a json object
    mapping query name to error message.

    :return: dict with keys 'locations', 'lg_routers', 'customer_contacts',
        'circuit_ids_to_monitor', 'additional_circuit_customer_ids',
        'hierarchy', 'port_id_details' (defaultdict(list) keyed by port id)
        and 'port_id_services' (defaultdict(list) keyed by port_a_id)
    """
    c = InventoryTask.config["ims"]
    ds1, ds2, ds3, ds4, ds5 = (
        IMS(c['api'], c['username'], c['password']) for _ in range(5))

    def _run_concurrently(jobs):
        """Run each name->callable in a thread pool; return name->result."""
        results = {}
        errors = {}
        with concurrent.futures.ThreadPoolExecutor() as executor:
            pending = {executor.submit(fn): name for name, fn in jobs.items()}
            for done in concurrent.futures.as_completed(pending):
                error = done.exception()
                if error:
                    errors[pending[done]] = str(error)
                else:
                    results[pending[done]] = done.result()
        if errors:
            raise InventoryTaskError(json.dumps(errors, indent=2))
        return results

    first_round = _run_concurrently({
        'locations': lambda: {
            k: v for k, v in ims_data.get_node_locations(ds1)},
        'lg_routers': lambda: list(ims_data.lookup_lg_routers(ds5)),
        'customer_contacts': lambda: {
            k: v for k, v in ims_data.get_customer_service_emails(ds2)},
        'circuit_ids_to_monitor': lambda: list(
            ims_data.get_monitored_circuit_ids(ds3)),
        'additional_circuit_customer_ids': lambda:
            ims_data.get_circuit_related_customer_ids(ds4),
    })

    def _fetch_hierarchy():
        circuits = {d['id']: d for d in ims_data.get_circuit_hierarchy(ds1)}
        logger.debug("hierarchy complete")
        return circuits

    def _fetch_port_id_details():
        details = defaultdict(list)
        for item in ims_data.get_port_details(ds2):
            details[item['port_id']].append(item)
        logger.debug("Port details complete")
        return details

    def _fetch_port_id_services():
        services = defaultdict(list)
        for item in ims_data.get_port_id_services(ds3):
            services[item['port_a_id']].append(item)
        logger.debug("port circuits complete")
        return services

    second_round = _run_concurrently({
        'hierarchy': _fetch_hierarchy,
        'port_id_details': _fetch_port_id_details,
        'circuit_info': _fetch_port_id_services,
    })

    return {
        'locations': first_round['locations'],
        'lg_routers': first_round['lg_routers'],
        'customer_contacts': first_round['customer_contacts'],
        'circuit_ids_to_monitor': first_round['circuit_ids_to_monitor'],
        'additional_circuit_customer_ids':
            first_round['additional_circuit_customer_ids'],
        'hierarchy': second_round['hierarchy'],
        'port_id_details': second_round['port_id_details'],
        'port_id_services': second_round['circuit_info']
    }


# new
def transform_ims_data(data):
    """
    Enrich the raw IMS extract with derived per-circuit information.

    The dicts in ``data`` are mutated in place:
      - every hierarchy entry gets a sorted ``contacts`` list
      - every circuit in ``port_id_services`` whose port id appears in
        ``port_id_details`` gets ``fibre-routes``, ``related-services``,
        ``contacts``, ``calculated-speed`` and location/port fields.
        Its ``port_a_id``/``port_b_id`` keys are removed.

    :param data: dict as returned by ``extract_ims_data``
    :return: dict with keys ``hierarchy``, ``interface_services`` (keyed by
        ``'<equipment_name>:<interface_name>'``) and ``services_by_type``
        (keyed by ``ims_sorted_service_type_key(service_type)``, then by
        circuit id)
    """
    locations = data['locations']
    customer_contacts = data['customer_contacts']
    # set: membership is tested for every service circuit below
    monitored_circuit_ids = set(data['circuit_ids_to_monitor'])
    additional_circuit_customer_ids = data['additional_circuit_customer_ids']
    hierarchy = data['hierarchy']
    port_id_details = data['port_id_details']
    port_id_services = data['port_id_services']

    # '<digits><unit>', e.g. '10G'; the unit must be ASCII letters (the
    # former 'a-zA-z' range also matched '[', '\', ']', '^', '_' and '`')
    speed_pattern = re.compile(r'^(\d+)([a-zA-Z]+)$')
    # multipliers keyed by lower-cased unit
    unit_multipliers = {
        'm': 1 << 20,
        'mb': 1 << 20,
        'g': 1 << 30,
        'gbe': 1 << 30,
    }

    def _get_circuit_contacts(c):
        """Contacts of the circuit's customer and its related customers."""
        customer_ids = {c['customerid']}
        customer_ids.update(additional_circuit_customer_ids.get(c['id'], []))
        return set().union(
            *[customer_contacts.get(cid, []) for cid in customer_ids])

    for d in hierarchy.values():
        d['contacts'] = sorted(_get_circuit_contacts(d))

    def _convert_to_bits(value, unit):
        # raises KeyError for an unknown unit
        return int(value) * unit_multipliers[unit.lower()]

    def _get_speed(circuit_id):
        """
        Speed of an operational circuit, 0 otherwise.

        If the speed string does not parse, service/ethernet circuits
        return the sum of their carrier circuits' speeds.
        """
        c = hierarchy[circuit_id]
        if c['status'] != 'operational':
            return 0
        m = speed_pattern.match(c['speed'])
        if m:
            try:
                return _convert_to_bits(m[1], m[2])
            except KeyError as e:
                logger.debug(f'Could not find key: {e} '
                             f'for circuit: {circuit_id}')
                return 0
        if c['circuit-type'] == 'service' \
                or c['product'].lower() == 'ethernet':
            return sum(_get_speed(x) for x in c['carrier-circuits'])
        return 0

    def _get_fibre_routes(c_id):
        """Yield ids of the 'fibre_route' circuits the circuit rides on."""
        _circ = hierarchy.get(c_id, None)
        if _circ is None:
            return
        if _circ['speed'].lower() == 'fibre_route':
            yield _circ['id']
        else:
            for cc in _circ['carrier-circuits']:
                yield from _get_fibre_routes(cc)

    def _get_related_services(circuit_id: str) -> List[dict]:
        """Service circuits at or above the circuit (via sub-circuits)."""
        rs = {}
        c = hierarchy.get(circuit_id, None)
        if c:

            if c['circuit-type'] == 'service':
                rs[c['id']] = {
                    'id': c['id'],
                    'name': c['name'],
                    'circuit_type': c['circuit-type'],
                    'service_type': c['product'],
                    'project': c['project'],
                    'contacts': c['contacts']
                }
                if c['id'] in monitored_circuit_ids:
                    rs[c['id']]['status'] = c['status']
                else:
                    rs[c['id']]['status'] = 'non-monitored'

            if c['sub-circuits']:
                for sub in c['sub-circuits']:
                    temp_parents = \
                        _get_related_services(sub)
                    rs.update({t['id']: t for t in temp_parents})
        return list(rs.values())

    def _format_service(s):
        """Add pop/equipment/port fields for both ends; drop port ids."""
        if s['circuit_type'] == 'service' \
                and s['id'] not in monitored_circuit_ids:
            s['status'] = 'non-monitored'
        pd_a = port_id_details[s['port_a_id']][0]
        location_a = locations.get(pd_a['equipment_name'], None)
        if location_a:
            loc_a = location_a['pop']
        else:
            loc_a = locations['UNKNOWN_LOC']['pop']
            logger.warning(
                f'Unable to find location for {pd_a["equipment_name"]} - '
                f'Service ID {s["id"]}')
        s['pop_name'] = loc_a['name']
        s['pop_abbreviation'] = loc_a['abbreviation']
        s['equipment'] = pd_a['equipment_name']
        s['card_id'] = ''  # this is redundant I believe
        s['port'] = pd_a['interface_name']
        s['logical_unit'] = ''  # this is redundant I believe
        if 'port_b_id' in s:
            pd_b = port_id_details[s['port_b_id']][0]
            location_b = locations.get(pd_b['equipment_name'], None)
            if location_b:
                loc_b = location_b['pop']
            else:
                loc_b = locations['UNKNOWN_LOC']['pop']
                logger.warning(
                    f'Unable to find location for {pd_b["equipment_name"]} - '
                    f'Service ID {s["id"]}')

            s['other_end_pop_name'] = loc_b['name']
            s['other_end_pop_abbreviation'] = loc_b['abbreviation']
            s['other_end_equipment'] = pd_b['equipment_name']
            s['other_end_port'] = pd_b['interface_name']
        else:
            s['other_end_pop_name'] = ''
            s['other_end_pop_abbreviation'] = ''
            s['other_end_equipment'] = ''
            s['other_end_port'] = ''

        s.pop('port_a_id', None)
        s.pop('port_b_id', None)

    services_by_type = {}
    interface_services = defaultdict(list)

    for value in port_id_details.values():
        for details in value:
            k = f"{details['equipment_name']}:" \
                f"{details['interface_name']}"
            circuits = port_id_services.get(details['port_id'], [])

            for circ in circuits:
                contacts = _get_circuit_contacts(circ)
                circ['fibre-routes'] = []
                for x in set(_get_fibre_routes(circ['id'])):
                    c = {
                        'id': hierarchy[x]['id'],
                        'name': hierarchy[x]['name'],
                        'status': hierarchy[x]['status']
                    }
                    circ['fibre-routes'].append(c)

                circ['related-services'] = \
                    _get_related_services(circ['id'])

                for tlc in circ['related-services']:
                    contacts.update(tlc.pop('contacts'))
                circ['contacts'] = sorted(contacts)

                circ['calculated-speed'] = _get_speed(circ['id'])
                _format_service(circ)

                type_services = services_by_type.setdefault(
                    ims_sorted_service_type_key(circ['service_type']), dict())
                type_services[circ['id']] = circ

            interface_services[k].extend(circuits)

    return {
        'hierarchy': hierarchy,
        'interface_services': interface_services,
        'services_by_type': services_by_type
    }


# new
def persist_ims_data(data, use_current=False):
    """
    Write transformed IMS data to redis.

    Keys written: ``ims:location:<h>``, ``ims:lg:<equipment name>``,
    ``ims:circuit_hierarchy:<id>``, ``ims:interface_services:<k>`` and
    ``ims:services:<service type>:<name>``. It also calls
    ``populate_poller_cache``.

    :param data: dict with keys 'hierarchy', 'locations', 'lg_routers',
        'interface_services' and 'services_by_type'
    :param use_current: if True, write to the 'current' db after first
        deleting its existing ims keys; otherwise write to the 'next' db
    """
    hierarchy = data['hierarchy']
    locations = data['locations']
    lg_routers = data['lg_routers']
    interface_services = data['interface_services']
    services_by_type = data['services_by_type']

    if use_current:
        r = get_current_redis(InventoryTask.config)

        # only need to delete the individual keys if it's just an IMS update
        # rather than a complete update (the db will have been flushed)
        rp = r.pipeline()
        for key_pattern in [
            'ims:location:*',
            'ims:lg:*',
            'ims:circuit_hierarchy:*',
            'ims:interface_services:*',
            'ims:services:*',  # also written below, was never cleared
            'ims:access_services:*',
            'ims:gws_indirect:*'
        ]:
            for k in r.scan_iter(key_pattern, count=1000):
                rp.delete(k)
        # the queued deletes were previously never executed
        rp.execute()
    else:
        r = get_next_redis(InventoryTask.config)

    rp = r.pipeline()
    for h, d in locations.items():
        rp.set(f'ims:location:{h}', json.dumps([d]))
    rp.execute()
    rp = r.pipeline()
    for router in lg_routers:
        rp.set(f'ims:lg:{router["equipment name"]}', json.dumps(router))
    rp.execute()
    rp = r.pipeline()
    for circ in hierarchy.values():
        rp.set(f'ims:circuit_hierarchy:{circ["id"]}', json.dumps([circ]))
    rp.execute()
    rp = r.pipeline()
    for k, v in interface_services.items():
        rp.set(
            f'ims:interface_services:{k}',
            json.dumps(v))
    rp.execute()
    rp = r.pipeline()

    populate_poller_cache(interface_services, r)

    for service_type, services in services_by_type.items():
        for v in services.values():
            rp.set(
                f'ims:services:{service_type}:{v["name"]}',
                json.dumps({
                    'id': v['id'],
                    'name': v['name'],
                    'project': v['project'],
                    'here': {
                        'pop': {
                            'name': v['pop_name'],
                            'abbreviation': v['pop_abbreviation']
                        },
                        'equipment': v['equipment'],
                        'port': v['port'],
                    },
                    'there': {
                        'pop': {
                            'name': v['other_end_pop_name'],
                            'abbreviation': v['other_end_pop_abbreviation']
                        },
                        'equipment': v['other_end_equipment'],
                        'port': v['other_end_port'],
                    },
                    'speed_value': v['calculated-speed'],
                    'speed_unit': 'n/a',
                    'status': v['status'],
                    'type': v['service_type']
                }))

    rp.execute()


# new
@app.task(base=InventoryTask, bind=True, name='final_task')
@log_task_entry_and_exit
def final_task(self):
    """
    Body of the update chord: build the derived lookups and latch the dbs.

    :raises InventoryTaskError: if the latch in the 'current' db has its
        'failure' flag set (some sub-task failed)
    """
    latch = get_latch(get_current_redis(InventoryTask.config))
    if latch['failure']:
        raise InventoryTaskError('Sub-task failed - check logs for details')

    for build in (
            _build_subnet_db,
            _build_snmp_peering_db,
            _build_juniper_peering_db):
        build(update_callback=self.log_info)
    populate_poller_interfaces_cache()

    latch_db(InventoryTask.config)
    self.log_info('latched current/next dbs')


def populate_poller_interfaces_cache():
    """
    Build the poller-interface classifier caches in the "next" redis db.

    Standard (``netconf:*``) and lab (``lab:netconf:*``) interfaces,
    interface bundles, SNMP indexes and IMS services are all read from the
    next redis db. Each interface that has an SNMP index is decorated with:

    * ``snmp-index``
    * ``bundle-parents``, when bundle info exists for its router
    * ``circuits``, when services exist for its router
    * ``dashboards``

    It is then passed through ``_get_dashboard_data``. Interfaces without
    an SNMP index are omitted.

    Two json lists are written to the next redis db:

    * ``classifier-cache:poller-interfaces:all:no-lab``: standard
      interfaces only
    * ``classifier-cache:poller-interfaces:all``: standard + lab interfaces
    """
    config = InventoryTask.config

    base_key_pattern = 'netconf:*'
    standard_interfaces = _load_interfaces(
        config, base_key_pattern, use_next_redis=True)
    # lab interfaces must come from the db being built too, not from the
    # currently latched db
    lab_interfaces = _load_interfaces(
        config, f'lab:{base_key_pattern}', use_next_redis=True)

    # previously the loaded bundles were discarded and ``bundles`` was
    # always empty.
    # NOTE(review): assumes _load_interface_bundles returns
    # {router: {interface: [parents]}}, matching how ``bundles`` is
    # indexed below - confirm
    bundles = {}
    base_bundles_key_pattern = 'netconf-interface-bundles:*'
    for bundles_key_pattern in (base_bundles_key_pattern,
                                f'lab:{base_bundles_key_pattern}'):
        bundles.update(_load_interface_bundles(
            config, bundles_key_pattern, use_next_redis=True))

    snmp_indexes = load_snmp_indexes(use_next_redis=True)

    # loaded once: this reads every service doc and must not be repeated
    # for each interface.
    # NOTE(review): assumes the result is keyed by IMS equipment name,
    # then IMS interface name (cf. the ims:interface_services:{k} keys
    # written earlier in this module) - confirm
    services = _load_services(config, use_next_redis=True)

    def _get_populated_interfaces(interfaces):
        for ifc in interfaces:
            router_snmp = snmp_indexes.get(ifc['router'], None)
            if not router_snmp or ifc['name'] not in router_snmp:
                # no snmp index: not pollable, skip it
                continue

            ifc['snmp-index'] = router_snmp[ifc['name']]['index']

            router_bundle = bundles.get(ifc['router'], None)
            if router_bundle:
                base_ifc = ifc['name'].split('.')[0]
                ifc['bundle-parents'] = router_bundle.get(base_ifc, [])

            # services are keyed per router first, like snmp and bundles
            router_services = services.get(
                get_ims_equipment_name(ifc['router']), None)
            if router_services:
                ifc['circuits'] = router_services.get(
                    get_ims_interface(ifc['name']), [])

            dashboards = _get_dashboards(ifc)
            ifc['dashboards'] = sorted([d.name for d in dashboards])
            yield _get_dashboard_data(ifc)

    r = get_next_redis(config)

    non_lab_populated_interfaces = \
        list(_get_populated_interfaces(standard_interfaces))
    # NOTE(review): confirm readers use this exact no-lab suffix
    no_lab_cache_key = 'classifier-cache:poller-interfaces:all:no-lab'
    r.set(no_lab_cache_key,
          json.dumps(non_lab_populated_interfaces).encode('utf-8'))

    all_populated_interfaces = non_lab_populated_interfaces + \
        list(_get_populated_interfaces(lab_interfaces))
    all_cache_key = 'classifier-cache:poller-interfaces:all'
    r.set(all_cache_key,
          json.dumps(all_populated_interfaces).encode('utf-8'))