Skip to content
Snippets Groups Projects
common.py 11.34 KiB
from collections import OrderedDict
import functools
import json
import logging
import queue
import random
import re
import threading

from distutils.util import strtobool
from lxml import etree
import requests
from flask import request, Response, current_app, g
from inventory_provider.tasks import common as tasks_common

# module-level logger, named after this module
logger = logging.getLogger(__name__)
# document decoding types accepted by the redis doc loaders in this module
_DECODE_TYPE_XML = 'xml'
_DECODE_TYPE_JSON = 'json'


def get_bool_request_arg(name, default=False):
    """
    Read a query-string argument of the current request as a boolean.

    The raw string is interpreted with `strtobool`; when the argument is
    missing, or its value is not recognised by `strtobool`, `default`
    is returned.

    :param name: name of the query-string argument
    :param default: value to use if the argument is absent or unparseable
    :return: bool
    """
    assert isinstance(default, bool)  # sanity, otherwise caller error
    raw = request.args.get(name, default=str(default), type=str)
    try:
        return bool(strtobool(raw))
    except ValueError:
        # not something strtobool understands
        return default


def _ignore_cache_or_retrieve(request_, cache_key, r):
    """
    Fetch a cached value, unless the request asked to bypass the cache.

    :param request_: unused by this function
    :param cache_key: key to look up
    :param r: object with a `get(key)` method returning bytes
        (presumably a redis connection)
    :return: False if 'ignore-cache' was requested, the utf-8 decoded
        value if one was found, otherwise whatever falsy value `r.get`
        returned
    """
    if get_bool_request_arg('ignore-cache', default=False):
        logger.debug('ignoring cache')
        return False

    cached = r.get(cache_key)
    return cached.decode('utf-8') if cached else cached


def ims_hostname_decorator(field):
    """
    Decorator to convert host names to various formats to try to match what is
    found in IMS before executing the decorated function.

    The decorated function is called once for each candidate spelling of
    the hostname, in this order and with duplicates removed:

      - upper case with the '.geant.net' suffix stripped (only if present)
      - upper case
      - as given
      - lower case

    The first response whose status code is not 404 is returned.  If every
    candidate gives 404, the response of the last attempt is returned.

    :param field: name of the field containing hostname; the decorated
        function must receive it as a keyword argument
    :return: result of decorated function
    """

    suffix = '.geant.net'

    def wrapper(func):
        # preserve the name/docstring of the decorated function, otherwise
        # every decorated function is called 'inner' (which matters to
        # anything keyed on __name__, e.g. default flask endpoint names)
        @functools.wraps(func)
        def inner(*args, **kwargs):
            orig_val = kwargs[field]
            values_to_try = []
            if orig_val.endswith(suffix):
                values_to_try.append(orig_val[:-len(suffix)].upper())
            values_to_try.extend(
                [orig_val.upper(), orig_val, orig_val.lower()])

            # fromkeys removes duplicates while keeping the order above
            for val in OrderedDict.fromkeys(values_to_try):
                kwargs[field] = val
                res = func(*args, **kwargs)
                if res.status_code != requests.codes.not_found:
                    return res
            # nothing matched: hand back the last (not found) response
            return res
        return inner
    return wrapper


def get_current_redis():
    """
    Return the 'current' redis connection for this request context.

    A connection stored on flask's `g` is reused as long as its latch
    reports that its own db ('this') is still the 'current' one.
    Otherwise a new connection is obtained from the tasks module, stored
    on `g` and returned.

    :return: the connection stored in `g.current_redis_db`
    """
    if 'current_redis_db' in g:
        latch = tasks_common.get_latch(g.current_redis_db)
        still_current = latch and latch['current'] == latch['this']
        if still_current:
            return g.current_redis_db
        logger.warning('switching to current redis db')

    g.current_redis_db = tasks_common.get_current_redis(
        current_app.config['INVENTORY_PROVIDER_CONFIG'])
    return g.current_redis_db


def get_next_redis():
    """
    Return the 'next' redis connection for this request context.

    A connection stored on flask's `g` is reused as long as its latch
    reports that its own db ('this') is still the 'next' one.
    Otherwise a new connection is obtained from the tasks module, stored
    on `g` and returned.

    :return: the connection stored in `g.next_redis_db`
    """
    if 'next_redis_db' in g:
        latch = tasks_common.get_latch(g.next_redis_db)
        still_next = latch and latch['next'] == latch['this']
        if still_next:
            return g.next_redis_db
        logger.warning('switching to next redis db')

    g.next_redis_db = tasks_common.get_next_redis(
        current_app.config['INVENTORY_PROVIDER_CONFIG'])
    return g.next_redis_db


def require_accepts_json(f):
    """
    Route handler decorator: respond with 406 unless the request's
    Accept header allows a response of type "application/json".

    :param f: the function to be decorated
    :return: the decorated function
    """
    @functools.wraps(f)
    def _json_only(*args, **kwargs):
        # TODO: use best_match to disallow */* ...?
        if request.accept_mimetypes.accept_json:
            return f(*args, **kwargs)

        # client won't take json: refuse without calling the handler
        return Response(
            response="response will be json",
            status=406,
            mimetype="text/html")
    return _json_only


def after_request(response):
    """
    Generic function to do additional logging of requests & responses.

    Every response whose status code is not 200 is logged as a warning,
    together with the request method, request path and the response
    body decoded as utf-8.  The response itself is never modified.

    :param response: the response object about to be sent
    :return: the same response object
    """
    if response.status_code != 200:

        try:
            data = response.data.decode('utf-8')
        except Exception:
            # never expected to happen, but we don't want any failures here
            # (use the module logger, like the rest of this file, rather
            # than the root logger)
            logger.exception('INTERNAL DECODING ERROR')
            data = 'decoding error (see logs)'

        # pass the values as logging args so formatting is done lazily
        logger.warning(
            '"%s %s" "%s" %s',
            request.method,
            request.path,
            data,
            response.status_code)
    return response


def _redis_client_proc(
        key_queue, value_queue, config_params, doc_type, use_next_redis=False):
    """
    create a local redis connection with the current db index,
    lookup the values of the keys that come from key_queue
    and put them on value_queue

    i/o contract:
        None arriving on key_queue means no more keys are coming
        put None in value_queue means we are finished

    keys with no (or an empty) value are skipped.  an entry that can't be
    decoded is logged and skipped, so that one bad document doesn't cause
    all the remaining keys queued for this worker to be dropped.

    :param key_queue: queue of keys to look up
    :param value_queue: queue for results, as {'key': ..., 'value': ...}
    :param config_params: app config
    :param doc_type: decoding type to do (xml or json)
    :param use_next_redis: connect to the 'next' db instead of 'current'
    :return: nothing
    """
    assert doc_type in (_DECODE_TYPE_JSON, _DECODE_TYPE_XML)

    def _decode(bv):
        if not bv:
            return None
        value = bv.decode('utf-8')
        if doc_type == _DECODE_TYPE_JSON:
            return json.loads(value)
        elif doc_type == _DECODE_TYPE_XML:
            return etree.XML(value)
        return None

    try:
        if use_next_redis:
            r = tasks_common.get_next_redis(config_params)
        else:
            r = tasks_common.get_current_redis(config_params)

        while True:
            key = key_queue.get()

            # contract is that None means no more requests
            if key is None:
                break

            try:
                v = _decode(r.get(key))
            except (json.JSONDecodeError, etree.XMLSyntaxError):
                # skip just this entry, the other keys are still wanted
                logger.exception(f'error decoding entry for {key}')
                continue

            if v is not None:
                value_queue.put({
                    'key': key,
                    'value': v
                })

    finally:
        # contract is to return None when finished
        # (also when failing, otherwise the reader waits forever)
        value_queue.put(None)


def _load_redis_docs(
        config_params,
        key_pattern,
        num_threads=10,
        doc_type=_DECODE_TYPE_JSON,
        use_next_redis=False):
    """
    load all docs from redis and decode as `doc_type`

    the loading is done with multiple connections in parallel, since this
    method is called from an api handler and when the client is far from
    the redis master the cumulative latency causes nginx/gunicorn timeouts

    :param config_params: app config
    :param key_pattern: key pattern or iterable of keys to load
    :param num_threads: number of client threads to create
    :param doc_type: decoding type to do (xml or json)
    :param use_next_redis: read from the 'next' db instead of 'current'
    :return: yields dicts like {'key': str, 'value': dict or xml doc}
    """
    assert doc_type in (_DECODE_TYPE_XML, _DECODE_TYPE_JSON)
    response_queue = queue.Queue()

    threads = []
    for _ in range(num_threads):
        q = queue.Queue()
        t = threading.Thread(
            target=_redis_client_proc,
            args=[q, response_queue, config_params, doc_type, use_next_redis])
        t.start()
        threads.append({'thread': t, 'queue': q})

    try:
        if use_next_redis:
            r = tasks_common.get_next_redis(config_params)
        else:
            r = tasks_common.get_current_redis(config_params)

        if isinstance(key_pattern, str):
            # scan with bigger batches, to mitigate network latency effects
            for k in r.scan_iter(key_pattern, count=1000):
                random.choice(threads)['queue'].put(k.decode('utf-8'))
        else:
            for k in key_pattern:
                random.choice(threads)['queue'].put(k)
    finally:
        # tell all threads there are no more keys coming; this is done
        # even if connecting/scanning failed, since otherwise the worker
        # threads would block on their queues forever
        for t in threads:
            t['queue'].put(None)

    num_finished = 0
    # read values from response_queue until we receive
    # None len(threads) times
    while num_finished < len(threads):
        value = response_queue.get()
        if value is None:
            # contract is that thread returns None when done
            num_finished += 1
            logger.debug('one worker thread finished')
            continue
        yield value

    # cleanup like we're supposed to, even though it's python
    for t in threads:
        t['thread'].join(timeout=0.5)  # timeout, for sanity


def load_json_docs(
        config_params, key_pattern, num_threads=10, use_next_redis=False):
    """
    Load docs from redis and decode them as json.

    :param config_params: app config
    :param key_pattern: key pattern or iterable of keys to load
    :param num_threads: number of client threads to create
    :param use_next_redis: read from the 'next' db instead of 'current'
    :return: yields dicts like {'key': str, 'value': decoded json}
    """
    yield from _load_redis_docs(
        config_params=config_params,
        key_pattern=key_pattern,
        num_threads=num_threads,
        doc_type=_DECODE_TYPE_JSON,
        use_next_redis=use_next_redis)


def load_xml_docs(
        config_params, key_pattern, num_threads=10, use_next_redis=False):
    """
    Load docs from redis and decode them as xml.

    :param config_params: app config
    :param key_pattern: key pattern or iterable of keys to load
    :param num_threads: number of client threads to create
    :param use_next_redis: read from the 'next' db instead of 'current'
    :return: yields dicts like {'key': str, 'value': xml doc}
    """
    yield from _load_redis_docs(
        config_params=config_params,
        key_pattern=key_pattern,
        num_threads=num_threads,
        doc_type=_DECODE_TYPE_XML,
        use_next_redis=use_next_redis)


def load_snmp_indexes(config, hostname=None, use_next_redis=False):
    """
    Load the 'snmp-interfaces:*' docs and index their entries by name.

    :param config: app config
    :param hostname: if given, only keys starting with
        'snmp-interfaces:<hostname>' are loaded
    :param use_next_redis: read from the 'next' db instead of 'current'
    :return: dict mapping the part of each key after 'snmp-interfaces:'
        to a dict of that doc's elements, keyed by their 'name'
    """
    if hostname:
        key_pattern = f'snmp-interfaces:{hostname}*'
    else:
        key_pattern = 'snmp-interfaces:*'

    prefix_len = len('snmp-interfaces:')
    docs = load_json_docs(
        config_params=config,
        key_pattern=key_pattern,
        use_next_redis=use_next_redis)

    return {
        doc['key'][prefix_len:]: {e['name']: e for e in doc['value']}
        for doc in docs
    }


def distribute_jobs_across_workers(
        worker_proc, jobs, input_ctx, num_threads=10):
    """
    Launch `num_threads` threads with worker_proc and distribute
    jobs across them.  Then return the results from all workers.

    (generic version of _load_redis_docs)

    worker_proc should be a function that takes args:
      - input queue (items from jobs are written here)
      - output queue (results from each input item are to be written here)
      - input_ctx (some worker-specific data)

    worker contract is:
      - None is written to input queue iff there are no more items coming
      - the worker writes None to the output queue when it exits

    Only None is treated as the end-of-worker marker: any other result,
    including falsy ones like 0, '' or an empty list, is yielded.

    :param worker_proc: worker proc, as above
    :param jobs: an iterable of things to put in input queue
    :param input_ctx: some data to pass when starting worker proc
    :param num_threads: number of worker threads to start
    :return: yields all values computed by worker procs
    """
    assert isinstance(num_threads, int) and num_threads > 0  # sanity

    response_queue = queue.Queue()

    threads = []
    for _ in range(num_threads):
        q = queue.Queue()
        t = threading.Thread(
            target=worker_proc,
            args=[q, response_queue, input_ctx])
        t.start()
        threads.append({'thread': t, 'queue': q})

    try:
        for job_data in jobs:
            random.choice(threads)['queue'].put(job_data)
    finally:
        # tell all threads there are no more jobs coming; this is done
        # even if iterating over jobs failed, since otherwise the worker
        # threads would block on their queues forever
        for t in threads:
            t['queue'].put(None)

    num_finished = 0
    # read values from response_queue until we receive
    # None len(threads) times
    while num_finished < len(threads):
        job_result = response_queue.get()
        if job_result is None:
            # contract is that thread returns None when done
            num_finished += 1
            logger.debug('one worker thread finished')
            continue
        yield job_result

    # cleanup like we're supposed to, even though it's python
    for t in threads:
        t['thread'].join(timeout=0.5)  # timeout, for sanity


def ims_equipment_to_hostname(equipment):
    """
    Convert an IMS equipment name to a hostname,
    e.g. MX1.AMS.NL becomes mx1.ams.nl.geant.net

    Names containing whitespace (e.g. the CPE name 'INTERXION Z-END')
    are not treated as hostnames and are returned unchanged.  Names that
    already end with .geant.net or .geant.org are only lower-cased.

    :param equipment: the IMS equipment name string
    :return: hostname, or the input string if it doesn't look like a host
    """
    contains_whitespace = re.match(r'.*\s.*', equipment) is not None
    if contains_whitespace:
        # doesn't look like a hostname
        return equipment

    lowered = equipment.lower()
    already_qualified = re.match(r'.*\.geant\.(net|org)$', lowered)
    return lowered if already_qualified else f'{lowered}.geant.net'