Skip to content
Snippets Groups Projects
worker.py 24.49 KiB
import json
import logging
import os
import re
import time

from celery import Task, states
from celery.result import AsyncResult

from collections import defaultdict
from lxml import etree
import jsonschema

from inventory_provider.tasks.app import app
from inventory_provider.tasks.common \
    import get_next_redis, latch_db, get_latch, set_latch, update_latch_status
from inventory_provider import config
from inventory_provider import environment
from inventory_provider.db import db, opsdb
from inventory_provider import snmp
from inventory_provider import juniper

# seconds to sleep between status polls while waiting for pending tasks
FINALIZER_POLLING_FREQUENCY_S = 2.5
# maximum number of seconds to keep polling pending tasks before giving up
FINALIZER_TIMEOUT_S = 300

# TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks)  # noqa: E501

# logging is set up at import time, before the module logger is created
environment.setup_logging()

# module-level logger used by the tasks and helpers in this file
logger = logging.getLogger(__name__)


class InventoryTaskError(Exception):
    """Error raised by the inventory tasks and their helper functions."""


class InventoryTask(Task):
    """
    base class for the celery tasks defined in this module

    The first instance that is created loads the application
    configuration from the file named by the
    INVENTORY_PROVIDER_CONFIG_FILENAME environment variable and stores
    it in the class attribute `config`, which the tasks read.
    """

    # parsed configuration shared by all tasks; None until first loaded
    config = None

    def __init__(self):
        """
        load the shared configuration, unless it was already loaded

        :raises KeyError: INVENTORY_PROVIDER_CONFIG_FILENAME is not set
        :raises AssertionError: the named config file does not exist
        """
        if InventoryTask.config:
            return

        config_filename = os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']
        assert os.path.isfile(config_filename), (
            'config file %r not found' % config_filename)

        with open(config_filename) as f:
            # use the module logger (not the root logger) so these
            # messages are attributed to this module like all the others
            logger.info(
                "Initializing worker with config from: %r", config_filename)
            InventoryTask.config = config.load(f)
            logger.debug("loaded config: %r", InventoryTask.config)

    def update_state(self, **kwargs):
        """
        log the new state and metadata, then delegate to the base class

        :param kwargs: keyword arguments passed unchanged to the
            base class; 'state' and 'meta' (if given) are logged
        """
        # .get(): logging must not fail when a caller omits state or meta
        logger.debug(json.dumps(
            {'state': kwargs.get('state'), 'meta': str(kwargs.get('meta'))}
        ))
        super().update_state(**kwargs)

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """
        log the exception, then delegate to the base class handler

        All parameters are passed through unchanged to the base class.
        """
        logger.exception(exc)
        super().on_failure(exc, task_id, args, kwargs, einfo)


@app.task(base=InventoryTask, bind=True)
def snmp_refresh_interfaces(self, hostname, community):
    """
    query the snmp interface indexes of a router and cache them

    The result is stored as a json list under the redis key
    'snmp-interfaces:<hostname>'.

    :param hostname: router hostname
    :param community: snmp community string passed to the query
    """
    call_description = (
        'snmp_refresh_interfaces(%r, %r)' % (hostname, community))
    logger.debug('>>> ' + call_description)

    indexes = list(snmp.get_router_snmp_indexes(hostname, community))

    redis_next = get_next_redis(InventoryTask.config)
    redis_next.set('snmp-interfaces:' + hostname, json.dumps(indexes))

    logger.debug('<<< ' + call_description)


@app.task(base=InventoryTask, bind=True)
def netconf_refresh_config(self, hostname):
    """
    load the netconf config of a router and cache it

    The document is loaded using the 'ssh' section of the configuration,
    serialized to a string and stored under the redis
    key 'netconf:<hostname>'.

    :param hostname: router hostname
    """
    logger.debug('>>> netconf_refresh_config(%r)' % hostname)

    ssh_params = InventoryTask.config["ssh"]
    document = juniper.load_config(hostname, ssh_params)
    serialized = etree.tostring(document, encoding='unicode')

    redis_next = get_next_redis(InventoryTask.config)
    redis_next.set('netconf:' + hostname, serialized)

    logger.debug('<<< netconf_refresh_config(%r)' % hostname)


@app.task(base=InventoryTask, bind=True)
def update_interfaces_to_services(self):
    """
    rebuild the cached opsdb circuits, grouped by equipment & interface

    All existing 'opsdb:interface_services:*' keys are deleted, and
    then one key 'opsdb:interface_services:<equipment>:<interface_name>'
    is written per group, holding a json list of the circuits.
    """
    logger.debug('>>> update_interfaces_to_services')

    circuits_by_interface = defaultdict(list)
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for circuit in opsdb.get_circuits(cx):
            group = f'{circuit["equipment"]}:{circuit["interface_name"]}'
            circuits_by_interface[group].append(circuit)

    redis_next = get_next_redis(InventoryTask.config)
    for stale_key in redis_next.scan_iter('opsdb:interface_services:*'):
        redis_next.delete(stale_key)

    # all new values are sent in a single batch
    pipeline = redis_next.pipeline()
    for group, circuits in circuits_by_interface.items():
        pipeline.set(
            'opsdb:interface_services:' + group,
            json.dumps(circuits))
    pipeline.execute()

    logger.debug('<<< update_interfaces_to_services')


@app.task(base=InventoryTask, bind=True)
def update_access_services(self):
    """
    rebuild the cached opsdb access services, one per equipment

    If opsdb returns several access services for the same equipment
    a warning is logged and the last one wins.  All existing
    'opsdb:access_services:*' keys are deleted before the new
    values are written.
    """
    logger.debug('>>> update_access_services')

    service_by_equipment = {}
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for service in opsdb.get_access_services(cx):
            equipment = service['equipment']
            if equipment in service_by_equipment:
                logger.warning(
                    f'got multiple access services for {equipment}')
            service_by_equipment[equipment] = service

    redis_next = get_next_redis(InventoryTask.config)
    for stale_key in redis_next.scan_iter('opsdb:access_services:*'):
        redis_next.delete(stale_key)

    # all new values are sent in a single batch
    pipeline = redis_next.pipeline()
    for equipment, service in service_by_equipment.items():
        pipeline.set(
            f'opsdb:access_services:{equipment}', json.dumps(service))
    pipeline.execute()

    logger.debug('<<< update_access_services')


@app.task(base=InventoryTask, bind=True)
def update_lg_routers(self):
    """
    rebuild the cached opsdb looking-glass router records

    Each router is stored as json under 'opsdb:lg:<equipment name>'.
    """
    logger.debug('>>> update_lg_routers')

    # query opsdb before touching the cache (as the other opsdb tasks
    # do), so a failed query doesn't leave the cache already emptied
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        routers = list(opsdb.lookup_lg_routers(cx))

    r = get_next_redis(InventoryTask.config)
    for k in r.scan_iter('opsdb:lg:*'):
        r.delete(k)

    # send all new values in a single batch, like the sibling tasks
    rp = r.pipeline()
    for router in routers:
        rp.set(f'opsdb:lg:{router["equipment name"]}', json.dumps(router))
    rp.execute()

    logger.debug('<<< update_lg_routers')


@app.task(base=InventoryTask, bind=True)
def update_equipment_locations(self):
    """
    rebuild the cached opsdb location info of all known routers

    The routers are those returned by _derive_router_hostnames.  For
    each of them a json list of locations is stored under the
    key 'opsdb:location:<hostname>'.
    """
    logger.debug('>>> update_equipment_locations')

    redis_next = get_next_redis(InventoryTask.config)
    for stale_key in redis_next.scan_iter('opsdb:location:*'):
        redis_next.delete(stale_key)

    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for hostname in _derive_router_hostnames(InventoryTask.config):
            # a hostname can sometimes match more than one location,
            # so the complete list of matches is cached
            pop_info = list(opsdb.lookup_pop_info(cx, hostname))
            redis_next.set(
                'opsdb:location:' + hostname, json.dumps(pop_info))

    logger.debug('<<< update_equipment_locations')


@app.task(base=InventoryTask, bind=True)
def update_circuit_hierarchy(self):
    """
    rebuild the cached opsdb circuit parent/child relations

    All existing 'opsdb:services:parents:*' and
    'opsdb:services:children:*' keys are deleted before the new
    values are written.
    """
    logger.debug('>>> update_circuit_hierarchy')

    # TODO: integers are not JSON keys
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        relations_by_parent = defaultdict(list)
        relations_by_child = defaultdict(list)
        for relation in opsdb.get_circuit_hierarchy(cx):
            parent_id = relation["parent_circuit_id"]
            child_id = relation["child_circuit_id"]
            relations_by_parent[parent_id].append(relation)
            relations_by_child[child_id].append(relation)

        redis_next = get_next_redis(InventoryTask.config)
        for pattern in (
                'opsdb:services:parents:*',
                'opsdb:services:children:*'):
            for stale_key in redis_next.scan_iter(pattern):
                redis_next.delete(stale_key)

        # NOTE(review): 'parents:<id>' holds the relations in which <id>
        #   is the parent circuit, and 'children:<id>' those in which
        #   <id> is the child - confirm this is what the readers expect
        pipeline = redis_next.pipeline()
        for circuit_id, relations in relations_by_parent.items():
            pipeline.set(
                'opsdb:services:parents:%d' % circuit_id,
                json.dumps(relations))
        for circuit_id, relations in relations_by_child.items():
            pipeline.set(
                'opsdb:services:children:%d' % circuit_id,
                json.dumps(relations))
        pipeline.execute()

    logger.debug('<<< update_circuit_hierarchy')


@app.task(base=InventoryTask, bind=True)
def update_geant_lambdas(self):
    """
    rebuild the cached opsdb geant lambda records

    Each record is stored as json under
    'opsdb:geant_lambdas:<lower-cased name>'.
    """
    logger.debug('>>> update_geant_lambdas')

    r = get_next_redis(InventoryTask.config)
    for key in r.scan_iter('opsdb:geant_lambdas:*'):
        r.delete(key)
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        # all new values are sent in a single batch
        rp = r.pipeline()
        for ld in opsdb.get_geant_lambdas(cx):
            rp.set(
                'opsdb:geant_lambdas:%s' % ld['name'].lower(),
                json.dumps(ld))
        rp.execute()

    # exit message now carries the full task name, matching the entry
    # message above (and the convention of every other task here)
    logger.debug('<<< update_geant_lambdas')


@app.task(base=InventoryTask, bind=True)
def update_junosspace_device_list(self):
    """
    rebuild the cached list of routers managed by junosspace

    Each router is stored as utf-8 encoded json under the
    key 'junosspace:<hostname>'.

    :return: dict with the task name and a summary message
    """
    task_name = 'update_junosspace_device_list'
    logger.debug('>>> update_junosspace_device_list')

    def _progress(message):
        # publish an intermediate status message for this task
        self.update_state(
            state=states.STARTED,
            meta={'task': task_name, 'message': message})

    _progress('querying junosspace for managed routers')

    redis_next = get_next_redis(InventoryTask.config)

    junosspace_params = InventoryTask.config['junosspace']
    routers = {
        'junosspace:' + device['hostname']:
            json.dumps(device).encode('utf-8')
        for device in juniper.load_routers_from_junosspace(
            junosspace_params)
    }

    _progress('found %d routers, saving details' % len(routers))

    for stale_key in redis_next.scan_iter('junosspace:*'):
        redis_next.delete(stale_key)

    # all new values are sent in a single batch
    pipeline = redis_next.pipeline()
    for key, value in routers.items():
        pipeline.set(key, value)
    pipeline.execute()

    logger.debug('<<< update_junosspace_device_list')

    return {
        'task': task_name,
        'message': 'saved %d managed routers' % len(routers)
    }


def load_netconf_data(hostname):
    """
    read the cached netconf document of a router and parse it

    this method should only be called from a task

    :param hostname: router hostname
    :return: the parsed netconf document
    :raises InventoryTaskError: if nothing is cached for hostname
    """
    redis_next = get_next_redis(InventoryTask.config)
    cached = redis_next.get('netconf:' + hostname)
    if cached:
        return etree.fromstring(cached.decode('utf-8'))
    raise InventoryTaskError('no netconf data found for %r' % hostname)


def clear_cached_classifier_responses(hostname=None):
    """
    delete cached classifier responses

    :param hostname: if given, only 'classifier-cache:juniper:<hostname>:*'
        keys and those 'classifier-cache:peer:*' entries that refer
        to an interface on this router are deleted; otherwise every
        'classifier-cache:*' key is deleted
    """
    if not hostname:
        logger.debug('removing all cached classifier responses')
    else:
        logger.debug(
            'removing cached classifier responses for %r' % hostname)

    redis_next = get_next_redis(InventoryTask.config)

    def _keys_for_hostname():
        # responses cached directly for this router
        yield from redis_next.keys(
            'classifier-cache:juniper:%s:*' % hostname)

        # TODO: very inefficient ... but logically simplest at this point
        for peer_key in redis_next.keys('classifier-cache:peer:*'):
            cached = redis_next.get(peer_key.decode('utf-8'))
            if not cached:
                # deleted in another thread
                continue
            response = json.loads(cached.decode('utf-8'))
            routers = [
                ifc['interface']['router']
                for ifc in response.get('interfaces', [])]
            if hostname in routers:
                yield peer_key

    if hostname:
        doomed_keys = _keys_for_hostname()
    else:
        doomed_keys = redis_next.keys('classifier-cache:*')

    for key in doomed_keys:
        redis_next.delete(key)


def _refresh_peers(hostname, key_base, peers):
    """
    cache a collection of peer records found on a router

    Each record gets a 'router' element set to hostname (the record
    is modified in place) and is stored as json under the
    key '<key_base>:<peer name>'.

    :param hostname: router hostname
    :param key_base: prefix of the redis keys to write
    :param peers: iterable of dicts, each with a 'name' element
    """
    logger.debug(
        'removing cached %s for %r' % (key_base, hostname))
    redis_next = get_next_redis(InventoryTask.config)
    # WARNING (optimization): this is an expensive query if
    #       the redis connection is slow, and we currently only
    #       call this method during a full refresh
    # for k in r.scan_iter(key_base + ':*'):
    #     # potential race condition: another proc could have
    #     # delete this element between the time we read the
    #     # keys and the next statement ... check for None below
    #     value = r.get(k.decode('utf-8'))
    #     if value:
    #         value = json.loads(value.decode('utf-8'))
    #         if value['router'] == hostname:
    #             r.delete(k)

    # all new values are sent in a single batch
    pipeline = redis_next.pipeline()
    for record in peers:
        record['router'] = hostname
        redis_key = '%s:%s' % (key_base, record['name'])
        pipeline.set(redis_key, json.dumps(record))
    pipeline.execute()


def refresh_ix_public_peers(hostname, netconf):
    """
    cache the ix public peers found in a router's netconf document

    :param hostname: router hostname
    :param netconf: parsed netconf document of the router
    """
    peers = juniper.ix_public_peers(netconf)
    _refresh_peers(hostname, 'ix_public_peer', peers)


def refresh_vpn_rr_peers(hostname, netconf):
    """
    cache the vpn rr peers found in a router's netconf document

    :param hostname: router hostname
    :param netconf: parsed netconf document of the router
    """
    peers = juniper.vpn_rr_peers(netconf)
    _refresh_peers(hostname, 'vpn_rr_peer', peers)


def refresh_interface_address_lookups(hostname, netconf):
    """
    cache the interface addresses found in a router's netconf document

    The records are stored with the key prefix
    'reverse_interface_addresses'.

    :param hostname: router hostname
    :param netconf: parsed netconf document of the router
    """
    addresses = juniper.interface_addresses(netconf)
    _refresh_peers(hostname, 'reverse_interface_addresses', addresses)


def refresh_juniper_interface_list(hostname, netconf):
    """
    rebuild the cached interface and bundle lists of a router

    Every interface is stored as json under
    'netconf-interfaces:<hostname>:<interface name>', and for every
    bundle a json list of its member interface names is stored under
    'netconf-interface-bundles:<hostname>:<bundle name>'.

    :param hostname: router hostname
    :param netconf: parsed netconf document of the router
    """
    logger.debug(
        'removing cached netconf-interfaces for %r' % hostname)

    r = get_next_redis(InventoryTask.config)
    for k in r.scan_iter('netconf-interfaces:%s:*' % hostname):
        r.delete(k)
    for k in r.scan_iter('netconf-interface-bundles:%s:*' % hostname):
        r.delete(k)

    all_bundles = defaultdict(list)

    rp = r.pipeline()
    for ifc in juniper.list_interfaces(netconf):
        # 'bundle' can be missing (or None): treat that as "no bundles"
        # instead of failing when trying to iterate over None
        for bundle in ifc.get('bundle') or []:
            if bundle:
                all_bundles[bundle].append(ifc['name'])
        rp.set(
            'netconf-interfaces:%s:%s' % (hostname, ifc['name']),
            json.dumps(ifc))
    for k, v in all_bundles.items():
        rp.set(
            'netconf-interface-bundles:%s:%s' % (hostname, k),
            json.dumps(v))
    rp.execute()


@app.task(base=InventoryTask, bind=True)
def reload_router_config(self, hostname):
    """
    reload the netconf data of a router and refresh what depends on it

    If the change timestamp of the newly loaded netconf data is the
    same as that of the previously cached data, nothing else is done.
    Otherwise the cached peers, interface addresses, interface list
    and snmp interface indexes of the router are refreshed, and all
    cached classifier responses are removed.

    :param hostname: router hostname
    :return: dict with the task name, hostname and a result message
    :raises InventoryTaskError: if no netconf data is available after
        the reload, or no snmp community string can be extracted
    """
    logger.debug('>>> reload_router_config')

    def _progress(message):
        # publish an intermediate status message for this task
        self.update_state(
            state=states.STARTED,
            meta={
                'task': 'reload_router_config',
                'hostname': hostname,
                'message': message
            })

    def _result(message):
        return {
            'task': 'reload_router_config',
            'hostname': hostname,
            'message': message
        }

    _progress('loading router netconf data')

    # timestamp of the currently cached netconf data (if there is any)
    previous_timestamp = None
    try:
        cached_doc = load_netconf_data(hostname)
        previous_timestamp = juniper.netconf_changed_timestamp(cached_doc)
        logger.debug(
            'current netconf timestamp: %r' % previous_timestamp)
    except InventoryTaskError:
        pass  # ok at this point if not found

    # load new netconf data (runs the task in this process)
    netconf_refresh_config.apply(args=[hostname])

    netconf_doc = load_netconf_data(hostname)

    # nothing more to do if the timestamp hasn't changed
    new_timestamp = juniper.netconf_changed_timestamp(netconf_doc)
    assert new_timestamp, \
        'no timestamp available for new netconf data'
    if new_timestamp == previous_timestamp:
        logger.debug('no netconf change timestamp change, aborting')
        logger.debug('<<< reload_router_config')
        return _result('OK (no change)')

    # refresh the data derived from the netconf document
    # (clearing the classifier cache for just this router is
    # disabled: the whole classifier cache is cleared further down)
    _progress('refreshing peers & clearing cache')
    refresh_ix_public_peers(hostname, netconf_doc)
    refresh_vpn_rr_peers(hostname, netconf_doc)
    refresh_interface_address_lookups(hostname, netconf_doc)
    refresh_juniper_interface_list(hostname, netconf_doc)
    # clear_cached_classifier_responses(hostname)

    # load snmp indexes
    community = juniper.snmp_community_string(netconf_doc)
    if not community:
        raise InventoryTaskError(
            'error extracting community string for %r' % hostname)

    _progress('refreshing snmp interface indexes')
    snmp_refresh_interfaces.apply(args=[hostname, community])

    clear_cached_classifier_responses(None)

    logger.debug('<<< reload_router_config')

    return _result('OK')


def _derive_router_hostnames(config):
    """
    compute the router hostnames known to both junosspace and opsdb

    The names are taken from the keys cached in redis: 'junosspace:*'
    and 'opsdb:interface_services:*'.

    :param config: config structure, passed to get_next_redis
    :return: set of hostnames found in both sets of keys
    """
    redis_next = get_next_redis(config)

    in_junosspace = set()
    for raw_key in redis_next.keys('junosspace:*'):
        match = re.match('^junosspace:(.*)$', raw_key.decode('utf-8'))
        assert match
        in_junosspace.add(match.group(1))

    in_opsdb = set()
    for raw_key in redis_next.scan_iter('opsdb:interface_services:*'):
        key = raw_key.decode('utf-8')
        match = re.match('opsdb:interface_services:([^:]+):.*$', key)
        if not match:
            logger.info("Unable to derive router name from %s" % key)
            continue
        in_opsdb.add(match.group(1))

    return in_junosspace.intersection(in_opsdb)


def _erase_next_db(config):
    """
    flush the next db, preserving the latch

    The latch is read before the flush and, if there was one,
    written back afterwards.

    TODO: handle the no latch scenario nicely
    :param config: config structure, passed to get_next_redis & set_latch
    :return:
    """
    redis_next = get_next_redis(config)
    latch = get_latch(redis_next)
    redis_next.flushdb()
    if not latch:
        return
    set_latch(
        config,
        new_current=latch['current'],
        new_next=latch['next'])


def launch_refresh_cache_all(config):
    """
    utility function intended to be called outside of the worker process

    Erases the next db, marks the latch as pending and launches the
    refresh tasks in two batches.  The first batch is waited for;
    the second batch and the finalizer task are left running.

    :param config: config structure as defined in config.py
    :return: list of ids of the second batch tasks, followed by
        the id of the finalizer task
    """
    _erase_next_db(config)

    update_latch_status(config, pending=True)

    # first batch of subtasks: refresh cached opsdb data
    first_batch_tasks = (
        update_junosspace_device_list,
        update_interfaces_to_services,
        update_geant_lambdas,
        update_circuit_hierarchy,
    )
    first_batch = [task.apply_async() for task in first_batch_tasks]
    # block until every task of the first batch has finished
    for async_result in first_batch:
        async_result.get()

    # second batch of subtasks:
    #   alarms db status cache
    #   juniper netconf & snmp data
    second_batch_tasks = (
        update_equipment_locations,
        update_lg_routers,
        update_access_services,
    )
    second_batch = [task.apply_async() for task in second_batch_tasks]
    for hostname in _derive_router_hostnames(config):
        logger.debug('queueing router refresh jobs for %r' % hostname)
        second_batch.append(
            reload_router_config.apply_async(args=[hostname]))

    pending_task_ids = [async_result.id for async_result in second_batch]

    # the finalizer waits for the second batch
    finalizer = refresh_finalizer.apply_async(
        args=[json.dumps(pending_task_ids)])
    pending_task_ids.append(finalizer.id)
    return pending_task_ids


def _wait_for_tasks(task_ids, update_callback=lambda s: None):

    all_successful = True

    start_time = time.time()
    while task_ids and time.time() - start_time < FINALIZER_TIMEOUT_S:
        update_callback('waiting for tasks to complete: %r' % task_ids)
        time.sleep(FINALIZER_POLLING_FREQUENCY_S)

        def _is_error(id):
            status = check_task_status(id)
            return status['ready'] and not status['success']

        if any([_is_error(id) for id in task_ids]):
            all_successful = False

        task_ids = [
            id for id in task_ids
            if not check_task_status(id)['ready']
        ]

    if task_ids:
        raise InventoryTaskError(
            'timeout waiting for pending tasks to complete')
    if not all_successful:
        raise InventoryTaskError(
            'some tasks finished with an error')

    update_callback('pending taskscompleted in {} seconds'.format(
            time.time() - start_time))


@app.task(base=InventoryTask, bind=True)
def refresh_finalizer(self, pending_task_ids_json):
    """
    wait for the refresh tasks, build the derived data and latch the dbs

    If the input is invalid, or an InventoryTaskError is raised
    while waiting or building, the latch is flagged as failed
    and the exception is re-raised.

    :param pending_task_ids_json: json-formatted list of task id strings
    """
    logger.debug('>>> refresh_finalizer')
    logger.debug('task_ids: %r' % pending_task_ids_json)

    task_id_list_schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "array",
        "items": {"type": "string"}
    }

    def _report(message):
        # log the message and publish it as this task's status
        logger.debug(message)
        self.update_state(
            state=states.STARTED,
            meta={
                'task': 'refresh_finalizer',
                'message': message
            })

    handled_errors = (
        jsonschema.ValidationError,
        json.JSONDecodeError,
        InventoryTaskError)

    try:
        task_ids = json.loads(pending_task_ids_json)
        logger.debug('task_ids: %r' % task_ids)
        jsonschema.validate(task_ids, task_id_list_schema)

        _wait_for_tasks(task_ids, update_callback=_report)
        _build_subnet_db(update_callback=_report)
        _build_service_category_interface_list(update_callback=_report)

    except handled_errors as e:
        update_latch_status(InventoryTask.config, failure=True)
        raise e

    latch_db(InventoryTask.config)
    _report('latched current/next dbs')

    logger.debug('<<< refresh_finalizer')


def _build_service_interface_user_list():
    """
    generator yielding all cached netconf interfaces, with their users

    The users of an interface are those that opsdb returns for the
    ids of the services cached for that interface.

    :return: yields dicts with the elements 'router', 'interface',
        'description' and 'users' (a list)
    """

    def _interfaces():
        """
        yields interface info from netconf
        :return:
        """
        r = get_next_redis(InventoryTask.config)
        for k in r.scan_iter('netconf-interfaces:*'):
            k = k.decode('utf-8')
            # limit the split: interface names can contain ':'
            # (e.g. channelized interfaces) and must be kept whole
            (_, router_name, ifc_name) = k.split(':', 2)

            info = r.get(k).decode('utf-8')
            info = json.loads(info)

            assert ifc_name == info['name']
            yield {
                'router': router_name,
                'interface': info['name'],
                'description': info['description']
            }

    def _lookup_interface_services(wanted_interfaces):
        """
        yields interface info from opsdb (with service id)
        ... only interfaces in wanted_interfaces
        :param wanted_interfaces:
        :return:
        """
        r = get_next_redis(InventoryTask.config)
        for k in r.scan_iter('opsdb:interface_services:*'):
            k = k.decode('utf-8')
            # limit the split, so that an interface name
            # containing ':' isn't truncated
            fields = k.split(':', 3)
            if len(fields) < 4:
                # there are some strange records
                # e.g. TS1.*, ts1.*, dp1.*, dtn*, ...
                continue
            router = fields[2]
            ifc_name = fields[3]

            router_interface_key = f'{router}:{ifc_name}'
            if router_interface_key not in wanted_interfaces:
                continue

            info = r.get(k).decode('utf-8')
            info = json.loads(info)

            yield {
                'router': router,
                'interface': ifc_name,
                'service_ids': {service['id'] for service in info}
            }

    # dict: 'router:interface' -> {'router', 'interface', 'description'}
    netconf_interface_map = {
        f'{i["router"]}:{i["interface"]}': i for i in _interfaces()}

    # dict: 'router:interface' -> {'router', 'interface', set([service_ids])}
    opsdb_interface_map = {
        f'{i["router"]}:{i["interface"]}': i
        for i in _lookup_interface_services(netconf_interface_map.keys())}

    all_service_ids = set()
    for ifc_services in opsdb_interface_map.values():
        all_service_ids |= ifc_services['service_ids']
    all_service_ids = list(all_service_ids)

    # dict: service_id[int] -> [list of users]
    service_user_map = dict()
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        service_users = list(opsdb.get_service_users(cx, all_service_ids))
        for user in service_users:
            service_user_map.setdefault(
                user['service_id'], []).append(user['user'])

    def _users(ifc_key):
        """
        :param ifc_key: 'router:ifc_name'
        :return: list of users (without duplicates)
        """
        users = set()
        if ifc_key not in opsdb_interface_map:
            return []
        service_id_list = opsdb_interface_map[ifc_key].get('service_ids', [])
        for service_id in service_id_list:
            users |= set(service_user_map.get(service_id, []))
        return list(users)

    for k, v in netconf_interface_map.items():
        v['users'] = _users(k)
        yield v


def _build_service_category_interface_list(update_callback=lambda s: None):
    """
    cache the interfaces that belong to a known service category

    An interface whose description starts with 'SRV_MDVPN' is
    in category 'mdvpn', otherwise one whose description contains
    'LHCONE' is in category 'lhcone'.  Such interfaces are stored as
    json under 'interface-services:<category>:<router>:<interface>';
    all other interfaces are skipped.

    :param update_callback: called with progress messages (strings)
    """
    logger.debug('>>> _build_interface_services')

    def _category(interface):
        description = interface['description']
        if description.startswith('SRV_MDVPN'):
            return 'mdvpn'
        return 'lhcone' if 'LHCONE' in description else None

    update_callback('loading all known interfaces')
    interfaces = list(_build_service_interface_user_list())
    update_callback(f'loaded {len(interfaces)} interfaces, '
                    'saving by service category')

    redis_next = get_next_redis(InventoryTask.config)
    # all new values are sent in a single batch
    pipeline = redis_next.pipeline()

    for interface in interfaces:
        category = _category(interface)
        if category:
            redis_key = (
                f'interface-services:{category}'
                f':{interface["router"]}:{interface["interface"]}')
            pipeline.set(redis_key, json.dumps(interface))

    pipeline.execute()
    logger.debug('<<< _build_interface_services')


def _build_subnet_db(update_callback=lambda s: None):
    """
    group the cached interface address records by interface address

    All 'reverse_interface_addresses:*' records are read, and for
    every distinct 'interface address' value a json list of the
    records having it is stored under 'subnets:<interface address>'.

    :param update_callback: called with progress messages (strings)
    """
    redis_next = get_next_redis(InventoryTask.config)

    update_callback('loading all network addresses')
    records_by_address = {}
    for raw_key in redis_next.scan_iter('reverse_interface_addresses:*'):
        cached = redis_next.get(raw_key.decode('utf-8'))
        record = json.loads(cached.decode('utf-8'))
        records_by_address.setdefault(
            record['interface address'], []).append(record)

    update_callback('saving {} subnets'.format(len(records_by_address)))

    # all new values are sent in a single batch
    pipeline = redis_next.pipeline()
    for address, records in records_by_address.items():
        pipeline.set('subnets:' + address, json.dumps(records))
    pipeline.execute()


def check_task_status(task_id):
    """
    summarize the current state of a task

    :param task_id: celery task id
    :return: dict with the elements 'id', 'status', 'exception',
        'ready' and 'success', plus 'result' if the task has a result
        (an exception result is given as a dict containing its
        type name and message)
    """
    r = AsyncResult(task_id, app=app)
    # read status and result only once each: the task can change state
    # at any time, and separate reads could produce flags that
    # contradict each other (e.g. success without ready)
    status = r.status
    task_result = r.result
    result = {
        'id': r.id,
        'status': status,
        'exception': status in states.EXCEPTION_STATES,
        'ready': status in states.READY_STATES,
        'success': status == states.SUCCESS,
    }
    if task_result:
        # TODO: only discovered this case by testing, is this the only one?
        #       ... otherwise need to pre-test json serialization
        if isinstance(task_result, Exception):
            result['result'] = {
                'error type': type(task_result).__name__,
                'message': str(task_result)
            }
        else:
            result['result'] = task_result
    return result