import json
import logging
import os
import time

from celery import Task, states
from celery.result import AsyncResult
from collections import defaultdict
from lxml import etree
import jsonschema

from inventory_provider.tasks.app import app
from inventory_provider.tasks.common \
    import get_next_redis, latch_db, get_latch, set_latch, update_latch_status
from inventory_provider.tasks import data
from inventory_provider import config
from inventory_provider import environment
from inventory_provider.db import db, opsdb
from inventory_provider import snmp
from inventory_provider import juniper

FINALIZER_POLLING_FREQUENCY_S = 2.5
FINALIZER_TIMEOUT_S = 300

# TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks)  # noqa: E501

environment.setup_logging()

logger = logging.getLogger(__name__)


def log_task_entry_and_exit(f):
    # cf. https://stackoverflow.com/a/47663642
    def _w(*args, **kwargs):
        logger.debug(f'>>> {f.__name__}{args}')
        try:
            return f(*args, **kwargs)
        finally:
            logger.debug(f'<<< {f.__name__}{args}')
    return _w


class InventoryTaskError(Exception):
    pass


class InventoryTask(Task):

    config = None

    def __init__(self):
        if InventoryTask.config:
            return
        assert os.path.isfile(
            os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']), (
                'config file %r not found' %
                os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'])
        with open(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) as f:
            logging.info(
                "Initializing worker with config from: %r" %
                os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'])
            InventoryTask.config = config.load(f)
        logging.debug("loaded config: %r" % InventoryTask.config)

    def update_state(self, **kwargs):
        logger.debug(json.dumps(
            {'state': kwargs['state'], 'meta': str(kwargs['meta'])}
        ))
        super().update_state(**kwargs)

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        logger.exception(exc)
        super().on_failure(exc, task_id, args, kwargs, einfo)


@app.task(base=InventoryTask, bind=True, name='snmp_refresh_interfaces')
@log_task_entry_and_exit
def snmp_refresh_interfaces(self, hostname, community):
    value = list(snmp.get_router_snmp_indexes(hostname, community))
    r = get_next_redis(InventoryTask.config)
    r.set('snmp-interfaces:' + hostname, json.dumps(value))


@app.task(base=InventoryTask, bind=True, name='netconf_refresh_config')
@log_task_entry_and_exit
def netconf_refresh_config(self, hostname):
    netconf_doc = juniper.load_config(hostname, InventoryTask.config["ssh"])
    netconf_str = etree.tostring(netconf_doc, encoding='unicode')
    r = get_next_redis(InventoryTask.config)
    r.set('netconf:' + hostname, netconf_str)


@app.task(base=InventoryTask, bind=True, name='update_interfaces_to_services')
@log_task_entry_and_exit
def update_interfaces_to_services(self):
    interface_services = defaultdict(list)
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for service in opsdb.get_circuits(cx):
            equipment_interface = '%s:%s' % (
                service['equipment'], service['interface_name'])
            interface_services[equipment_interface].append(service)

    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for key in r.scan_iter('opsdb:interface_services:*'):
        rp.delete(key)
    rp.execute()
    rp = r.pipeline()
    for equipment_interface, services in interface_services.items():
        rp.set(
            f'opsdb:interface_services:{equipment_interface}',
            json.dumps(services))
    rp.execute()

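
# Illustrative sketch (not called anywhere in this module): several of
# the opsdb tasks above and below follow the same "replace everything
# under a key prefix" pattern — delete all keys matching a prefix in one
# pipeline, then repopulate them in a second.  The helper name and
# signature here are hypothetical, shown only to make the pattern explicit.
def _example_replace_key_prefix(r, prefix, items):
    """Delete all keys under `prefix` and re-set them from `items`.

    :param r: a redis client (e.g. from get_next_redis)
    :param prefix: key prefix, without the trailing ':'
    :param items: dict mapping key suffix -> json-serializable value
    """
    rp = r.pipeline()
    for key in r.scan_iter(prefix + ':*'):
        rp.delete(key)
    rp.execute()
    rp = r.pipeline()
    for suffix, value in items.items():
        rp.set(f'{prefix}:{suffix}', json.dumps(value))
    rp.execute()
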

@app.task(base=InventoryTask, bind=True, name='import_unmanaged_interfaces')
@log_task_entry_and_exit
def import_unmanaged_interfaces(self):

    def _convert(d):
        # the config file keys are more readable than
        # the keys used in redis
        return {
            'name': d['address'],
            'interface address': d['network'],
            'interface name': d['interface'].lower(),
            'router': d['router'].lower()
        }

    interfaces = [
        _convert(ifc)
        for ifc in InventoryTask.config.get('unmanaged-interfaces', [])
    ]

    if interfaces:
        r = get_next_redis(InventoryTask.config)
        rp = r.pipeline()
        for ifc in interfaces:
            rp.set(
                f'reverse_interface_addresses:{ifc["name"]}',
                json.dumps(ifc))
            rp.set(
                f'subnets:{ifc["interface address"]}',
                json.dumps([ifc]))
        rp.execute()


@app.task(base=InventoryTask, bind=True, name='update_access_services')
@log_task_entry_and_exit
def update_access_services(self):
    access_services = {}
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        for service in opsdb.get_access_services(cx):
            if service['name'] in access_services:
                logger.warning(
                    'got multiple access services '
                    f'with name "{service["name"]}"')
            access_services[service['name']] = service

    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for key in r.scan_iter('opsdb:access_services:*'):
        rp.delete(key)
    rp.execute()
    rp = r.pipeline()
    for name, service in access_services.items():
        rp.set(
            f'opsdb:access_services:{name}',
            json.dumps(service))
    rp.execute()


@app.task(base=InventoryTask, bind=True, name='update_lg_routers')
@log_task_entry_and_exit
def update_lg_routers(self):
    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for k in r.scan_iter('opsdb:lg:*'):
        rp.delete(k)
    rp.execute()
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        rp = r.pipeline()
        for router in opsdb.lookup_lg_routers(cx):
            rp.set(f'opsdb:lg:{router["equipment name"]}', json.dumps(router))
        rp.execute()


@app.task(base=InventoryTask, bind=True, name='update_equipment_locations')
@log_task_entry_and_exit
def update_equipment_locations(self):
    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for k in r.scan_iter('opsdb:location:*'):
        rp.delete(k)
    rp.execute()
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        rp = r.pipeline()
        for h in data.derive_router_hostnames(InventoryTask.config):
            # lookup_pop_info returns a list of locations
            # (there can sometimes be more than one match)
            locations = list(opsdb.lookup_pop_info(cx, h))
            rp.set('opsdb:location:%s' % h, json.dumps(locations))
        rp.execute()


@app.task(base=InventoryTask, bind=True, name='update_circuit_hierarchy')
@log_task_entry_and_exit
def update_circuit_hierarchy(self):
    # TODO: integers are not JSON keys
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        child_to_parents = defaultdict(list)
        parent_to_children = defaultdict(list)
        for relation in opsdb.get_circuit_hierarchy(cx):
            parent_id = relation["parent_circuit_id"]
            child_id = relation["child_circuit_id"]
            parent_to_children[parent_id].append(relation)
            child_to_parents[child_id].append(relation)

        r = get_next_redis(InventoryTask.config)
        rp = r.pipeline()
        for key in r.scan_iter('opsdb:services:parents:*'):
            rp.delete(key)
        for key in r.scan_iter('opsdb:services:children:*'):
            rp.delete(key)
        rp.execute()

        rp = r.pipeline()
        for cid, parents in parent_to_children.items():
            rp.set('opsdb:services:parents:%d' % cid, json.dumps(parents))
        for cid, children in child_to_parents.items():
            rp.set('opsdb:services:children:%d' % cid, json.dumps(children))
        rp.execute()

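
# Illustrative sketch (hypothetical, not part of the worker): a client
# could walk upward through the circuit hierarchy stored above.  This
# assumes 'opsdb:services:children:<id>' holds the relations in which
# <id> appears as the child, each carrying a 'parent_circuit_id' field
# as used in update_circuit_hierarchy.
def _example_ancestor_circuit_ids(r, circuit_id):
    """Return the set of ancestor circuit ids of `circuit_id`."""
    queue = [circuit_id]
    seen = set()
    while queue:
        cid = queue.pop(0)
        if cid in seen:
            continue
        seen.add(cid)
        value = r.get('opsdb:services:children:%d' % cid)
        if not value:
            continue
        for relation in json.loads(value):
            queue.append(relation['parent_circuit_id'])
    seen.discard(circuit_id)
    return seen
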

@app.task(base=InventoryTask, bind=True, name='update_geant_lambdas')
@log_task_entry_and_exit
def update_geant_lambdas(self):
    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for key in r.scan_iter('opsdb:geant_lambdas:*'):
        rp.delete(key)
    rp.execute()
    with db.connection(InventoryTask.config["ops-db"]) as cx:
        rp = r.pipeline()
        for ld in opsdb.get_geant_lambdas(cx):
            rp.set(
                'opsdb:geant_lambdas:%s' % ld['name'].lower(),
                json.dumps(ld))
        rp.execute()


@app.task(base=InventoryTask, bind=True,
          name='update_neteng_managed_device_list')
@log_task_entry_and_exit
def update_neteng_managed_device_list(self):
    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'update_neteng_managed_device_list',
            'message': 'querying netdash for managed routers'
        })

    routers = list(juniper.load_routers_from_netdash(
        InventoryTask.config['managed-routers']))

    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'update_neteng_managed_device_list',
            'message': f'found {len(routers)} routers, saving details'
        })

    r = get_next_redis(InventoryTask.config)
    r.set('netdash', json.dumps(routers).encode('utf-8'))

    return {
        'task': 'update_neteng_managed_device_list',
        'message': 'saved %d managed routers' % len(routers)
    }


def load_netconf_data(hostname):
    """
    this method should only be called from a task

    :param hostname:
    :return:
    """
    r = get_next_redis(InventoryTask.config)
    netconf = r.get('netconf:' + hostname)
    if not netconf:
        raise InventoryTaskError('no netconf data found for %r' % hostname)
    return etree.fromstring(netconf.decode('utf-8'))


def clear_cached_classifier_responses(hostname=None):
    if hostname:
        logger.debug(
            'removing cached classifier responses for %r' % hostname)
    else:
        logger.debug('removing all cached classifier responses')

    r = get_next_redis(InventoryTask.config)

    def _hostname_keys():
        for k in r.keys('classifier-cache:juniper:%s:*' % hostname):
            yield k

        # TODO: very inefficient ... but logically simplest at this point
        for k in r.keys('classifier-cache:peer:*'):
            value = r.get(k.decode('utf-8'))
            if not value:
                # deleted in another thread
                continue
            value = json.loads(value.decode('utf-8'))
            interfaces = value.get('interfaces', [])
            if hostname in [i['interface']['router'] for i in interfaces]:
                yield k

    def _all_keys():
        return r.keys('classifier-cache:*')

    keys_to_delete = _hostname_keys() if hostname else _all_keys()
    rp = r.pipeline()
    for k in keys_to_delete:
        rp.delete(k)
    rp.execute()

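
# Illustrative sketch, prompted by the TODO above: the blocking KEYS
# command could be replaced with the incremental scan_iter that the
# other tasks in this module already use.  This is an untested
# alternative with a hypothetical name, not the current implementation.
def _example_peer_cache_keys(r, hostname):
    """Yield 'classifier-cache:peer:*' keys that reference `hostname`."""
    for k in r.scan_iter('classifier-cache:peer:*'):
        value = r.get(k)
        if not value:
            # deleted in another thread
            continue
        interfaces = json.loads(value.decode('utf-8')).get('interfaces', [])
        if hostname in [i['interface']['router'] for i in interfaces]:
            yield k
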

def _refresh_peers(hostname, key_base, peers):
    logger.debug(
        'removing cached %s for %r' % (key_base, hostname))
    r = get_next_redis(InventoryTask.config)
    # WARNING (optimization): this is an expensive query if
    # the redis connection is slow, and we currently only
    # call this method during a full refresh
    # for k in r.scan_iter(key_base + ':*'):
    #     # potential race condition: another proc could have
    #     # deleted this element between the time we read the
    #     # keys and the next statement ... check for None below
    #     value = r.get(k.decode('utf-8'))
    #     if value:
    #         value = json.loads(value.decode('utf-8'))
    #         if value['router'] == hostname:
    #             r.delete(k)

    rp = r.pipeline()
    for peer in peers:
        peer['router'] = hostname
        rp.set(
            '%s:%s' % (key_base, peer['name']),
            json.dumps(peer))
    rp.execute()


def refresh_ix_public_peers(hostname, netconf):
    _refresh_peers(
        hostname,
        'ix_public_peer',
        juniper.ix_public_peers(netconf))


def refresh_vpn_rr_peers(hostname, netconf):
    _refresh_peers(
        hostname,
        'vpn_rr_peer',
        juniper.vpn_rr_peers(netconf))


def refresh_interface_address_lookups(hostname, netconf):
    _refresh_peers(
        hostname,
        'reverse_interface_addresses',
        juniper.interface_addresses(netconf))


def refresh_juniper_interface_list(hostname, netconf):
    logger.debug(
        'removing cached netconf-interfaces for %r' % hostname)

    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for k in r.scan_iter('netconf-interfaces:%s:*' % hostname):
        rp.delete(k)
    for k in r.keys('netconf-interface-bundles:%s:*' % hostname):
        rp.delete(k)
    rp.execute()

    all_bundles = defaultdict(list)
    rp = r.pipeline()
    for ifc in juniper.list_interfaces(netconf):
        # default to an empty list so a missing 'bundle' key
        # doesn't break the iteration below
        bundles = ifc.get('bundle', [])
        for bundle in bundles:
            if bundle:
                all_bundles[bundle].append(ifc['name'])
        rp.set(
            'netconf-interfaces:%s:%s' % (hostname, ifc['name']),
            json.dumps(ifc))
    for k, v in all_bundles.items():
        rp.set(
            'netconf-interface-bundles:%s:%s' % (hostname, k),
            json.dumps(v))
    rp.execute()

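
# Illustrative sketch (hypothetical helper, not used elsewhere in this
# module): reading back the bundle membership mapping written by
# refresh_juniper_interface_list.
def _example_bundle_members(r, hostname, bundle):
    """Return the member interface names of `bundle` on `hostname`."""
    value = r.get('netconf-interface-bundles:%s:%s' % (hostname, bundle))
    return json.loads(value) if value else []
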

@app.task(base=InventoryTask, bind=True, name='reload_router_config')
@log_task_entry_and_exit
def reload_router_config(self, hostname):
    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'reload_router_config',
            'hostname': hostname,
            'message': 'loading router netconf data'
        })

    # get the timestamp for the current netconf data
    current_netconf_timestamp = None
    try:
        netconf_doc = load_netconf_data(hostname)
        current_netconf_timestamp \
            = juniper.netconf_changed_timestamp(netconf_doc)
        logger.debug(
            'current netconf timestamp: %r' % current_netconf_timestamp)
    except InventoryTaskError:
        pass  # ok at this point if not found

    # load new netconf data
    netconf_refresh_config.apply(args=[hostname])

    netconf_doc = load_netconf_data(hostname)

    # return if the new timestamp is the same as the original timestamp
    new_netconf_timestamp = juniper.netconf_changed_timestamp(netconf_doc)
    assert new_netconf_timestamp, \
        'no timestamp available for new netconf data'
    if new_netconf_timestamp == current_netconf_timestamp:
        logger.debug('no netconf timestamp change, aborting')
        return {
            'task': 'reload_router_config',
            'hostname': hostname,
            'message': 'OK (no change)'
        }

    # clear cached classifier responses for this router, and
    # refresh peering data
    self.update_state(
        state=states.STARTED,
        meta={
            'task': 'reload_router_config',
            'hostname': hostname,
            'message': 'refreshing peers & clearing cache'
        })
    refresh_ix_public_peers(hostname, netconf_doc)
    refresh_vpn_rr_peers(hostname, netconf_doc)
    refresh_interface_address_lookups(hostname, netconf_doc)
    refresh_juniper_interface_list(hostname, netconf_doc)
    # clear_cached_classifier_responses(hostname)

    # load snmp indexes
    community = juniper.snmp_community_string(netconf_doc)
    if not community:
        raise InventoryTaskError(
            'error extracting community string for %r' % hostname)
    else:
        self.update_state(
            state=states.STARTED,
            meta={
                'task': 'reload_router_config',
                'hostname': hostname,
                'message': 'refreshing snmp interface indexes'
            })
        snmp_refresh_interfaces.apply(args=[hostname, community])

    clear_cached_classifier_responses(None)

    return {
        'task': 'reload_router_config',
        'hostname': hostname,
        'message': 'OK'
    }


def _erase_next_db(config):
    """
    flush the next db, but first save the latch and restore it afterwards

    TODO: handle the no-latch scenario nicely
    :param config:
    :return:
    """
    r = get_next_redis(config)
    saved_latch = get_latch(r)
    r.flushdb()
    if saved_latch:
        set_latch(
            config,
            new_current=saved_latch['current'],
            new_next=saved_latch['next'])


def launch_refresh_cache_all(config):
    """
    utility function intended to be called outside of the worker process

    :param config: config structure as defined in config.py
    :return:
    """
    _erase_next_db(config)
    update_latch_status(config, pending=True)

    # first batch of subtasks: refresh cached opsdb data
    subtasks = [
        update_neteng_managed_device_list.apply_async(),
        update_interfaces_to_services.apply_async(),
        update_geant_lambdas.apply_async(),
        update_circuit_hierarchy.apply_async()
    ]
    [x.get() for x in subtasks]

    # second batch of subtasks:
    #   alarms db status cache
    #   juniper netconf & snmp data
    subtasks = [
        update_equipment_locations.apply_async(),
        update_lg_routers.apply_async(),
        update_access_services.apply_async(),
        import_unmanaged_interfaces.apply_async()
    ]
    for hostname in data.derive_router_hostnames(config):
        logger.debug('queueing router refresh jobs for %r' % hostname)
        subtasks.append(reload_router_config.apply_async(args=[hostname]))

    pending_task_ids = [x.id for x in subtasks]

    t = refresh_finalizer.apply_async(args=[json.dumps(pending_task_ids)])
    pending_task_ids.append(t.id)
    return pending_task_ids


def _wait_for_tasks(task_ids, update_callback=lambda s: None):

    all_successful = True

    start_time = time.time()
    while task_ids and time.time() - start_time < FINALIZER_TIMEOUT_S:
        update_callback('waiting for tasks to complete: %r' % task_ids)
        time.sleep(FINALIZER_POLLING_FREQUENCY_S)

        def _is_error(id):
            status = check_task_status(id)
            return status['ready'] and not status['success']

        if any([_is_error(id) for id in task_ids]):
            all_successful = False

        task_ids = [
            id for id in task_ids
            if not check_task_status(id)['ready']
        ]

    if task_ids:
        raise InventoryTaskError(
            'timeout waiting for pending tasks to complete')
    if not all_successful:
        raise InventoryTaskError(
            'some tasks finished with an error')

    update_callback('pending tasks completed in {} seconds'.format(
        time.time() - start_time))


@app.task(base=InventoryTask, bind=True, name='refresh_finalizer')
@log_task_entry_and_exit
def refresh_finalizer(self, pending_task_ids_json):

    input_schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "array",
        "items": {"type": "string"}
    }

    def _update(s):
        logger.debug(s)
        self.update_state(
            state=states.STARTED,
            meta={
                'task': 'refresh_finalizer',
                'message': s
            })

    try:
        task_ids = json.loads(pending_task_ids_json)
        logger.debug('task_ids: %r' % task_ids)
        jsonschema.validate(task_ids, input_schema)

        _wait_for_tasks(task_ids, update_callback=_update)
        _build_subnet_db(update_callback=_update)
        _build_service_category_interface_list(update_callback=_update)

    except (jsonschema.ValidationError,
            json.JSONDecodeError,
            InventoryTaskError) as e:
        update_latch_status(InventoryTask.config, failure=True)
        raise e

    latch_db(InventoryTask.config)
    _update('latched current/next dbs')

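
# Illustrative sketch (hypothetical, not part of the module): how an
# external process might combine launch_refresh_cache_all with
# check_task_status to run a full refresh and block until every task,
# including the finalizer, is ready.  The poll interval is arbitrary.
def _example_run_full_refresh(config, poll_interval_s=5):
    """Launch a full cache refresh and wait for all tasks to finish."""
    task_ids = launch_refresh_cache_all(config)
    while task_ids:
        time.sleep(poll_interval_s)
        task_ids = [
            id for id in task_ids
            if not check_task_status(id)['ready']
        ]
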

def _build_service_category_interface_list(update_callback=lambda s: None):

    def _classify(ifc):
        if ifc['description'].startswith('SRV_MDVPN'):
            return 'mdvpn'
        if 'LHCONE' in ifc['description']:
            return 'lhcone'
        return None

    update_callback('loading all known interfaces')
    interfaces = data.build_service_interface_user_list(InventoryTask.config)
    interfaces = list(interfaces)
    update_callback(f'loaded {len(interfaces)} interfaces, '
                    'saving by service category')
    r = get_next_redis(InventoryTask.config)
    rp = r.pipeline()
    for ifc in interfaces:
        service_type = _classify(ifc)
        if not service_type:
            continue
        rp.set(
            f'interface-services:{service_type}'
            f':{ifc["router"]}:{ifc["interface"]}',
            json.dumps(ifc))
    rp.execute()


def _build_subnet_db(update_callback=lambda s: None):

    r = get_next_redis(InventoryTask.config)

    update_callback('loading all network addresses')
    subnets = {}
    for k in r.scan_iter('reverse_interface_addresses:*'):
        info = r.get(k.decode('utf-8')).decode('utf-8')
        info = json.loads(info)
        entry = subnets.setdefault(info['interface address'], [])
        entry.append(info)

    update_callback('saving {} subnets'.format(len(subnets)))
    rp = r.pipeline()
    for k, v in subnets.items():
        rp.set('subnets:' + k, json.dumps(v))
    rp.execute()


def check_task_status(task_id):
    r = AsyncResult(task_id, app=app)
    result = {
        'id': r.id,
        'status': r.status,
        'exception': r.status in states.EXCEPTION_STATES,
        'ready': r.status in states.READY_STATES,
        'success': r.status == states.SUCCESS,
    }
    if r.result:
        # TODO: only discovered this case by testing, is this the only one?
        #       ... otherwise need to pre-test json serialization
        if isinstance(r.result, Exception):
            result['result'] = {
                'error type': type(r.result).__name__,
                'message': str(r.result)
            }
        else:
            result['result'] = r.result
    return result

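
# Illustrative sketch (hypothetical helper): how a client might resolve
# the interface records sharing a subnet, using the 'subnets:<network>'
# keys written by _build_subnet_db and import_unmanaged_interfaces.
def _example_interfaces_on_subnet(r, network):
    """Return the interface records stored for `network`
    (e.g. a CIDR string), or an empty list if the subnet is unknown."""
    value = r.get('subnets:' + network)
    return json.loads(value) if value else []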