Skip to content
Snippets Groups Projects
worker.py 12.77 KiB
import json
import logging
import re

from celery import bootsteps, Task, group, states
from celery.result import AsyncResult

from collections import defaultdict
from lxml import etree

from inventory_provider.tasks.app import app
from inventory_provider.tasks.common import get_redis
from inventory_provider import config
from inventory_provider import constants
from inventory_provider import environment
from inventory_provider.db import db, opsdb, alarmsdb
from inventory_provider import snmp
from inventory_provider import juniper

# TODO: add an error callback (errback) for failed tasks
#       (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks)  # noqa: E501

# Runs once, at import time, before any task or bootstep below is defined.
# Presumably installs the logging configuration used by the task logger
# (constants.TASK_LOGGER_NAME) -- see inventory_provider.environment.
environment.setup_logging()


class InventoryTaskError(Exception):
    """Raised by a task when the data it needs to continue is unavailable."""


class InventoryTask(Task):
    """
    Base class for the bound tasks in this module.

    ``config`` is shared at class level: WorkerArgs loads it once per worker
    process and the task functions read it as ``InventoryTask.config``.
    Every state update is echoed as JSON to the task logger before being
    recorded by celery.
    """

    # worker configuration, loaded by WorkerArgs from --config_filename
    config = None

    def __init__(self):
        # nothing to initialise per instance; configuration lives on the class
        pass

    def update_state(self, task_id=None, state=None, meta=None, **kwargs):
        """
        Log the new state and meta, then delegate to ``Task.update_state``.

        Mirrors celery's ``Task.update_state`` signature, so positional
        calls and calls that omit ``meta`` work (previously they raised
        ``TypeError``/``KeyError``).

        :param task_id: id of the task to update (celery defaults to the
            current task when None)
        :param state: new task state, e.g. ``states.STARTED``
        :param meta: state metadata; must be JSON-serialisable for logging
        """
        task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
        task_logger.debug(json.dumps({'state': state, 'meta': meta}))
        super().update_state(
            task_id=task_id, state=state, meta=meta, **kwargs)


def _save_value(key, value):
    assert isinstance(value, str), \
        "sanity failure: expected string data as value"
    r = get_redis(InventoryTask.config)
    r.set(name=key, value=value)
    # InventoryTask.logger.debug("saved %s" % key)
    return "OK"


def _save_value_json(key, data_obj):
    """
    Serialise ``data_obj`` with ``json.dumps`` and store the text under
    ``key`` (via _save_value).
    """
    serialized = json.dumps(data_obj)
    _save_value(key, serialized)


def _save_value_etree(key, xml_doc):
    """
    Serialise an lxml document to a unicode string and store it under
    ``key`` (via _save_value).
    """
    xml_text = etree.tostring(xml_doc, encoding='unicode')
    _save_value(key, xml_text)


class WorkerArgs(bootsteps.Step):
    """
    Worker bootstep that loads the configuration file named by the
    ``--config_filename`` option (registered by ``worker_args``) into
    ``InventoryTask.config``.
    """

    def __init__(self, worker, config_filename, **options):
        # forward the remaining options to the base Step, following
        # celery's documented pattern for bootsteps with custom options
        super().__init__(worker, **options)
        task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
        # log before opening, so a bad path is visible in the logs
        task_logger.info(
            "Initializing worker with config from: %r" % config_filename)
        with open(config_filename) as f:
            InventoryTask.config = config.load(f)


def worker_args(parser):
    """
    Add the ``--config_filename`` option to the worker command line.

    :param parser: argparse-style parser supplied by celery
    """
    option_spec = dict(
        dest="config_filename",
        action='store',
        type=str,
        help="Configuration filename",
    )
    parser.add_argument("--config_filename", **option_spec)


# Register the --config_filename worker option and the WorkerArgs bootstep
# that receives it (as its config_filename argument) and loads the config.
app.user_options['worker'].add(worker_args)
app.steps['worker'].add(WorkerArgs)


@app.task
def snmp_refresh_interfaces(hostname, community):
    """
    Fetch the router's interfaces over SNMP and cache them in redis, as a
    JSON list, under ``snmp-interfaces:<hostname>``.

    :param hostname: router hostname
    :param community: SNMP community string
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    call_args = (hostname, community)
    task_logger.debug('>>> snmp_refresh_interfaces(%r, %r)' % call_args)

    cache_key = 'snmp-interfaces:' + hostname
    interfaces = list(snmp.get_router_interfaces(
        hostname, community, InventoryTask.config))
    _save_value_json(cache_key, interfaces)

    task_logger.debug('<<< snmp_refresh_interfaces(%r, %r)' % call_args)


@app.task
def netconf_refresh_config(hostname):
    """
    Load the router configuration with ``juniper.load_config`` (using the
    ``ssh`` config section) and cache the serialised XML in redis under
    ``netconf:<hostname>``.

    :param hostname: router hostname
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    task_logger.debug('>>> netconf_refresh_config(%r)' % hostname)

    cache_key = 'netconf:' + hostname
    ssh_params = InventoryTask.config["ssh"]
    router_config = juniper.load_config(hostname, ssh_params)
    _save_value_etree(cache_key, router_config)

    task_logger.debug('<<< netconf_refresh_config(%r)' % hostname)


@app.task
def update_interfaces_to_services():
    """
    Rebuild the ``opsdb:interface_services:<equipment>:<interface_name>``
    keys. Each holds the JSON list of opsdb circuits on that interface.
    Existing keys are removed first.
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    task_logger.debug('>>> update_interfaces_to_services')

    services_by_interface = defaultdict(list)
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for circuit in opsdb.get_circuits(cx):
            interface_key = '%s:%s' % (
                circuit['equipment'], circuit['interface_name'])
            services_by_interface[interface_key].append(circuit)

    redis_client = get_redis(InventoryTask.config)
    # drop the previous snapshot before writing the new one
    for stale_key in redis_client.scan_iter('opsdb:interface_services:*'):
        redis_client.delete(stale_key)
    for interface_key in services_by_interface:
        redis_client.set(
            'opsdb:interface_services:' + interface_key,
            json.dumps(services_by_interface[interface_key]))

    task_logger.debug('<<< update_interfaces_to_services')


@app.task
def update_equipment_locations():
    """
    Rebuild the ``opsdb:location:<equipment_name>`` keys. Each holds one
    JSON-encoded row from ``opsdb.get_equipment_location_data``.
    Existing keys are removed first.
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    task_logger.debug('>>> update_equipment_locations')

    redis_client = get_redis(InventoryTask.config)
    for stale_key in redis_client.scan_iter('opsdb:location:*'):
        redis_client.delete(stale_key)

    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for location in opsdb.get_equipment_location_data(cx):
            location_key = 'opsdb:location:%s' % location['equipment_name']
            redis_client.set(location_key, json.dumps(location))

    task_logger.debug('<<< update_equipment_locations')


@app.task
def update_circuit_hierarchy():
    """
    Rebuild the circuit-hierarchy caches from opsdb.

    For the relations returned by ``opsdb.get_circuit_hierarchy``:
      - ``opsdb:services:parents:<id>`` holds the JSON list of relations in
        which circuit ``<id>`` is the child, and
      - ``opsdb:services:children:<id>`` holds the JSON list of relations
        in which circuit ``<id>`` is the parent.
    Existing keys of both kinds are removed before writing.
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    task_logger.debug('>>> update_circuit_hierarchy')

    child_to_parents = defaultdict(list)
    parent_to_children = defaultdict(list)
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for relation in opsdb.get_circuit_hierarchy(cx):
            parent_id = relation["parent_circuit_id"]
            child_id = relation["child_circuit_id"]
            parent_to_children[parent_id].append(relation)
            child_to_parents[child_id].append(relation)

    r = get_redis(InventoryTask.config)
    # ids are formatted with %d below, i.e. they are expected to be integers
    for key in r.scan_iter('opsdb:services:parents:*'):
        r.delete(key)
    for cid, parents in child_to_parents.items():
        r.set('opsdb:services:parents:%d' % cid, json.dumps(parents))

    for key in r.scan_iter('opsdb:services:children:*'):
        r.delete(key)
    # children keys are indexed by the *parent* circuit id
    for cid, children in parent_to_children.items():
        r.set('opsdb:services:children:%d' % cid, json.dumps(children))

    task_logger.debug('<<< update_circuit_hierarchy')


@app.task
def update_interface_statuses():
    """
    For every opsdb circuit, cache the last known status of its interface
    (read from the alarms db) under
    ``alarmsdb:interface_status:<equipment>:<interface_name>``.
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    task_logger.debug('>>> update_interface_statuses')

    with db.connection(InventoryTask.config["ops-db"]) as cx:
        services = opsdb.get_circuits(cx)
    # NOTE(review): services is iterated after the ops-db connection has
    # closed; this assumes get_circuits returns a materialised list -- confirm

    with db.connection(InventoryTask.config["alarms-db"]) as cx, \
            db.cursor(cx) as csr:
        for service in services:
            equipment = service['equipment']
            interface_name = service['interface_name']
            status_key = 'alarmsdb:interface_status:%s:%s' \
                % (equipment, interface_name)
            status = alarmsdb.get_last_known_interface_status(
                csr, equipment, interface_name)
            _save_value(status_key, status)

    task_logger.debug('<<< update_interface_statuses')


@app.task(base=InventoryTask, bind=True)
def update_junosspace_device_list(self):
    """
    Replace the cached ``junosspace:<hostname>`` keys with the routers
    currently reported by ``juniper.load_routers_from_junosspace``.

    Progress is reported through ``update_state``.

    :return: summary dict with ``task`` and ``message`` keys
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    task_logger.debug('>>> update_junosspace_device_list')

    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'update_junosspace_device_list',
            'message': 'querying junosspace for managed routers'
        })

    r = get_redis(InventoryTask.config)

    # fetch everything before touching redis, so a junosspace failure
    # leaves the previously cached router list in place
    routers = {}
    for d in juniper.load_routers_from_junosspace(
            InventoryTask.config['junosspace']):
        routers['junosspace:' + d['hostname']] = json.dumps(d).encode('utf-8')

    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'update_junosspace_device_list',
            'message': 'found %d routers, saving details' % len(routers)
        })

    # SCAN rather than KEYS (as the other refresh tasks do): KEYS blocks
    # the redis server while it walks the whole keyspace
    for k in r.scan_iter('junosspace:*'):
        r.delete(k)
    for k, v in routers.items():
        r.set(k, v)

    task_logger.debug('<<< update_junosspace_device_list')

    return {
        'task': 'update_junosspace_device_list',
        'message': 'saved %d managed routers' % len(routers)
    }


def load_netconf_data(hostname):
    """
    Return the cached netconf configuration for ``hostname`` as a parsed
    lxml element, or None when nothing (or an empty value) is cached.

    This method should only be called from a task, since it relies on
    ``InventoryTask.config`` having been loaded.

    :param hostname: router hostname
    :return: parsed xml document, or None
    """
    cached = get_redis(InventoryTask.config).get('netconf:' + hostname)
    if cached:
        return etree.fromstring(cached.decode('utf-8'))
    return None


def clear_cached_classifier_responses(hostname):
    """
    Delete every cached classifier response for ``hostname``, i.e. all keys
    matching ``classifier:cache:<hostname>:*``.

    :param hostname: router hostname
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    task_logger.debug(
        'removing cached classifier responses for %r' % hostname)
    r = get_redis(InventoryTask.config)
    # SCAN rather than KEYS, consistent with the refresh tasks, so the
    # redis server isn't blocked while matching a large keyspace
    for k in r.scan_iter('classifier:cache:%s:*' % hostname):
        r.delete(k)


def _refresh_peers(hostname, key_base, peers):
    """
    Replace the cached peers of ``hostname`` stored under ``<key_base>:*``.

    First deletes every ``<key_base>:*`` entry whose JSON ``router`` field
    equals ``hostname``. Then tags each peer with ``router = hostname`` and
    stores it as JSON under ``<key_base>:<peer name>``.

    :param hostname: router the peers were read from
    :param key_base: redis key prefix, e.g. 'ix_public_peer'
    :param peers: iterable of dict-like peers, each with a 'name' key
        (each peer is mutated to add the 'router' key)
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    task_logger.debug(
        'removing cached %s for %r' % (key_base, hostname))
    r = get_redis(InventoryTask.config)
    for k in r.scan_iter(key_base + ':*'):
        # potential race condition: another proc could have deleted
        # this key between the scan and the get ... hence the None check
        raw = r.get(k.decode('utf-8'))
        if not raw:
            continue
        cached_peer = json.loads(raw.decode('utf-8'))
        # .get: an entry without 'router' must not abort the refresh
        if cached_peer.get('router') == hostname:
            r.delete(k)

    for peer in peers:
        peer['router'] = hostname
        r.set(
            '%s:%s' % (key_base, peer['name']),
            json.dumps(peer))


def refresh_ix_public_peers(hostname, netconf):
    """
    Re-cache the public IX peers that ``juniper.ix_public_peers`` finds in
    ``netconf`` for ``hostname``, under the ``ix_public_peer:*`` keys.
    """
    peers = juniper.ix_public_peers(netconf)
    _refresh_peers(hostname, 'ix_public_peer', peers)


def refresh_vpn_rr_peers(hostname, netconf):
    """
    Re-cache the VPN route-reflector peers that ``juniper.vpn_rr_peers``
    finds in ``netconf`` for ``hostname``, under the ``vpn_rr_peer:*`` keys.
    """
    peers = juniper.vpn_rr_peers(netconf)
    _refresh_peers(hostname, 'vpn_rr_peer', peers)


@app.task(base=InventoryTask, bind=True)
def reload_router_config(self, hostname):
    """
    Refresh all cached data for one router.

    Steps:
      1. reload its netconf config into redis,
      2. refresh its IX public and VPN RR peers from that config,
      3. refresh its SNMP interfaces, using the community string extracted
         from the config,
      4. clear its cached classifier responses.

    :param hostname: router hostname
    :raises InventoryTaskError: if no netconf data is cached after step 1,
        or no SNMP community string can be extracted from it
    :return: summary dict with ``task``, ``hostname`` and ``message`` keys
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)
    task_logger.debug('>>> reload_router_config')

    def report_progress(message):
        # every progress update shares the same task/hostname metadata
        self.update_state(
            state=states.STARTED,
            meta={
                'task': 'reload_router_config',
                'hostname': hostname,
                'message': message
            })

    report_progress('loading router netconf data')
    # NOTE(review): apply() runs the subtask locally and, unless
    # task_eager_propagates is set, does not re-raise its exception -- a
    # failed refresh falls through to whatever netconf data is already
    # cached. Confirm this is intended.
    netconf_refresh_config.apply(args=[hostname])

    netconf_doc = load_netconf_data(hostname)
    if netconf_doc is None:
        raise InventoryTaskError(
            'no netconf data available for %r' % hostname)

    report_progress('refreshing peers')
    refresh_ix_public_peers(hostname, netconf_doc)
    refresh_vpn_rr_peers(hostname, netconf_doc)

    community = juniper.snmp_community_string(netconf_doc)
    if not community:
        raise InventoryTaskError(
            'error extracting community string for %r' % hostname)

    report_progress('refreshing snmp interface indexes')
    snmp_refresh_interfaces.apply(args=[hostname, community])

    # TODO: clear even if the netconf/community steps above fail?
    clear_cached_classifier_responses(hostname)

    task_logger.debug('<<< reload_router_config')

    return {
        'task': 'reload_router_config',
        'hostname': hostname,
        'message': 'OK'
    }


def _derive_router_hostnames(config):
    """
    Return the hostnames of routers known both to Junos Space and to opsdb.

    Junos Space routers come from the ``junosspace:<hostname>`` keys.
    opsdb equipment comes from the
    ``opsdb:interface_services:<equipment>:<interface>`` keys. opsdb keys
    that don't have that shape are ignored rather than crashing.

    :param config: config structure as defined in config.py
    :return: set of hostname strings
    """
    r = get_redis(config)

    junosspace_prefix = 'junosspace:'
    # every key returned by this pattern starts with the prefix
    junosspace_equipment = {
        k.decode('utf-8')[len(junosspace_prefix):]
        for k in r.scan_iter(junosspace_prefix + '*')
    }

    opsdb_key_re = re.compile(r'opsdb:interface_services:([^:]+):.*$')
    opsdb_equipment = set()
    for k in r.scan_iter('opsdb:interface_services:*'):
        m = opsdb_key_re.match(k.decode('utf-8'))
        if m:
            opsdb_equipment.add(m.group(1))

    return junosspace_equipment & opsdb_equipment


def launch_refresh_cache_all(config):
    """
    Utility function intended to be called outside of the worker process.

    Runs the junosspace/opsdb refresh tasks and waits for them to finish.
    Then it queues, without waiting:
      - the equipment-location refresh,
      - the interface-status refresh,
      - one reload_router_config task per router known to both junosspace
        and opsdb.

    :param config: config structure as defined in config.py
    :return: list of the ids of the queued second-batch tasks
    """
    task_logger = logging.getLogger(constants.TASK_LOGGER_NAME)

    # first batch: refresh cached junosspace & opsdb data; the router list
    # for the second batch is derived from the keys these tasks write, so
    # wait for them to finish
    first_batch = group([
        update_junosspace_device_list.s(),
        update_interfaces_to_services.s(),
        update_circuit_hierarchy.s(),
    ])
    first_batch.apply_async().join()

    # second batch: equipment locations, alarms db status cache,
    # and per-router juniper netconf & snmp data
    second_batch = [
        update_equipment_locations.s(),
        update_interface_statuses.s(),
    ]
    for hostname in _derive_router_hostnames(config):
        task_logger.debug(
            'queueing router refresh jobs for %r' % hostname)
        second_batch.append(reload_router_config.s(hostname))

    pending = group(second_batch).apply_async()
    return [task_result.id for task_result in pending]


def check_task_status(task_id):
    """
    Summarise the current state of the celery task ``task_id``.

    :param task_id: celery task id
    :return: dict with keys ``id``, ``status``, ``exception``, ``ready`` and
        ``success``, plus ``result`` when the task's result/info is truthy
    """
    r = AsyncResult(task_id, app=app)
    # Read the status once: for unfinished tasks each access may query the
    # result backend again, and separate reads could observe different
    # states (e.g. ready=False together with success=True).
    status = r.status
    result = {
        'id': task_id,
        'status': status,
        'exception': status in states.EXCEPTION_STATES,
        'ready': status in states.READY_STATES,
        'success': status == states.SUCCESS,
    }
    info = r.result
    if info:
        result['result'] = info
    return result