# common.py
from collections import OrderedDict
import functools
from http import HTTPStatus
import json
import logging
import queue
import random
import threading

from lxml import etree
import requests
from flask import request, Response, current_app, g

from inventory_provider.tasks import common as tasks_common

logger = logging.getLogger(__name__)
# decoding types accepted by the ``doc_type`` parameter of
# _redis_client_proc / _load_redis_docs
_DECODE_TYPE_XML = 'xml'
_DECODE_TYPE_JSON = 'json'


def ims_hostname_decorator(field):
    """
    Decorator to convert host names to various formats to try to match what is
    found in IMS before executing the decorated function.

    The decorated function is called once per candidate spelling of the
    hostname, in this order (duplicates removed):

      1. the name with the '.geant.net' suffix stripped, upper-cased
         (only if the name ends with that suffix)
      2. the name upper-cased
      3. the name as given
      4. the name lower-cased

    The first result whose ``status_code`` is not 404 is returned; if every
    candidate gives 404, the last (404) result is returned.

    :param field: name of the keyword argument containing the hostname
        (it is read from ``kwargs``, so it must be passed by keyword)
    :return: a decorator; the decorated function must return an object
        with a ``status_code`` attribute
    """

    suffix = '.geant.net'

    def wrapper(func):
        # preserve func's __name__/__doc__: if this is applied beneath a
        # Flask route decorator, Flask derives the default endpoint name
        # from __name__, and every decorated view would otherwise be "inner"
        @functools.wraps(func)
        def inner(*args, **kwargs):
            orig_val = kwargs[field]
            values_to_try = []
            if orig_val.endswith(suffix):
                values_to_try.append(orig_val[:-len(suffix)].upper())
            values_to_try.append(orig_val.upper())
            values_to_try.append(orig_val)
            values_to_try.append(orig_val.lower())

            # OrderedDict.fromkeys drops duplicates while keeping try-order
            for val in OrderedDict.fromkeys(values_to_try):
                kwargs[field] = val
                res = func(*args, **kwargs)
                if res.status_code != HTTPStatus.NOT_FOUND:
                    return res
            return res
        return inner
    return wrapper


def get_current_redis():
    """
    Return the "current" redis db connection for this request context.

    A connection already cached on flask ``g`` is reused as long as its
    latch still reports it as the current db. Otherwise a fresh one is
    obtained from ``tasks_common.get_current_redis`` using the app's
    INVENTORY_PROVIDER_CONFIG, cached on ``g``, and returned.
    """
    if 'current_redis_db' in g:
        cached_db = g.current_redis_db
        latch = tasks_common.get_latch(cached_db)
        still_current = bool(latch) and latch['current'] == latch['this']
        if still_current:
            return cached_db
        logger.warning('switching to current redis db')

    app_config = current_app.config['INVENTORY_PROVIDER_CONFIG']
    fresh_db = tasks_common.get_current_redis(app_config)
    g.current_redis_db = fresh_db
    return fresh_db


def get_next_redis():
    """
    Return the "next" redis db connection for this request context.

    A connection already cached on flask ``g`` is reused as long as its
    latch still reports it as the next db. Otherwise a fresh one is
    obtained from ``tasks_common.get_next_redis`` using the app's
    INVENTORY_PROVIDER_CONFIG, cached on ``g``, and returned.
    """
    if 'next_redis_db' in g:
        cached_db = g.next_redis_db
        latch = tasks_common.get_latch(cached_db)
        still_next = bool(latch) and latch['next'] == latch['this']
        if still_next:
            return cached_db
        logger.warning('switching to next redis db')

    app_config = current_app.config['INVENTORY_PROVIDER_CONFIG']
    fresh_db = tasks_common.get_next_redis(app_config)
    g.next_redis_db = fresh_db
    return fresh_db


def require_accepts_json(f):
    """
    route handler decorator: answer with a 406 error response unless the
    request's accepted mimetypes allow "application/json"

    :param f: the view function to wrap
    :return: the wrapped view function
    """
    @functools.wraps(f)
    def _json_only(*args, **kwargs):
        # TODO: use best_match to disallow */* ...?
        if request.accept_mimetypes.accept_json:
            return f(*args, **kwargs)
        return Response(
            response="response will be json",
            status=406,
            mimetype="text/html")
    return _json_only


def after_request(response):
    """
    generic function to do additional logging of requests & responses

    Any response whose status code is not exactly 200 is logged at WARNING
    level, together with the request method, path and the decoded
    response body.

    :param response: response object for the current request (presumably
        registered as a Flask after-request hook -- confirm at the call site)
    :return: ``response``, unmodified
    """
    if response.status_code == 200:
        return response

    try:
        data = response.data.decode('utf-8')
    except Exception:
        # never expected to happen, but we don't want any failures here;
        # log through the module logger (not the root logger) so it is
        # handled like every other message from this module
        logger.exception('INTERNAL DECODING ERROR')
        data = 'decoding error (see logs)'

    # lazy %-args: formatting only happens if WARNING is enabled
    logger.warning(
        '"%s %s" "%s" %s',
        request.method,
        request.path,
        data,
        response.status_code)
    return response


def _redis_client_proc(key_queue, value_queue, config_params, doc_type):
    """
    create a local redis connection with the current db index,
    lookup the values of the keys that come from key_queue
    and put them on value_queue

    i/o contract:
        None arriving on key_queue means no more keys are coming
        put None in value_queue means we are finished

    entries that can't be used are logged and skipped rather than stopping
    the worker, since stopping would silently drop every key still queued
    for it:
        - keys for which redis returns no value (e.g. deleted since
          they were scanned)
        - values that fail to decode as utf-8 / json / xml
    any other error still ends the worker, but the final None is always
    put on value_queue.

    :param key_queue: queue of key strings, terminated by None
    :param value_queue: queue receiving {'key': ..., 'value': ...} dicts,
        followed by a final None
    :param config_params: app config
    :param doc_type: decoding type to do (xml or json)
    :return: nothing
    """
    assert doc_type in (_DECODE_TYPE_JSON, _DECODE_TYPE_XML)

    def _decode(bv):
        value = bv.decode('utf-8')
        if doc_type == _DECODE_TYPE_JSON:
            return json.loads(value)
        elif doc_type == _DECODE_TYPE_XML:
            return etree.XML(value)

    try:
        r = tasks_common.get_current_redis(config_params)
        while True:
            key = key_queue.get()

            # contract is that None means no more requests
            # (an empty-string key is a valid redis key, so test identity)
            if key is None:
                break

            raw = r.get(key)
            if raw is None:
                logger.warning('no value found for %s, skipping', key)
                continue

            try:
                value = _decode(raw)
            except (ValueError, etree.XMLSyntaxError):
                # ValueError covers json.JSONDecodeError and
                # UnicodeDecodeError; keep going with the remaining keys
                logger.exception('error decoding entry for %s', key)
                continue

            value_queue.put({
                'key': key,
                'value': value
            })

    finally:
        # contract is to return None when finished
        value_queue.put(None)


def _load_redis_docs(
        config_params,
        key_pattern,
        num_threads=10,
        doc_type=_DECODE_TYPE_JSON):
    """
    load all docs from redis and decode as `doc_type`

    the loading is done with multiple connections in parallel, since this
    method is called from an api handler and when the client is far from
    the redis master the cumulative latency causes nginx/gunicorn timeouts

    :param config_params: app config
    :param key_pattern: key pattern to load
    :param num_threads: number of client threads to create
    :param doc_type: decoding type to do (xml or json)
    :return: yields dicts like {'key': str, 'value': dict or xml doc}
    """
    assert doc_type in (_DECODE_TYPE_XML, _DECODE_TYPE_JSON)
    response_queue = queue.Queue()

    threads = []
    for _ in range(num_threads):
        q = queue.Queue()
        t = threading.Thread(
            target=_redis_client_proc,
            args=[q, response_queue, config_params, doc_type])
        t.start()
        threads.append({'thread': t, 'queue': q})

    try:
        r = tasks_common.get_current_redis(config_params)
        # scan with bigger batches, to mitigate network latency effects
        for k in r.scan_iter(key_pattern, count=1000):
            t = random.choice(threads)
            t['queue'].put(k.decode('utf-8'))
    finally:
        # tell all threads there are no more keys coming -- also when the
        # connection or scan failed: the workers block on their key queues
        # until they see None, so skipping this would leave them hanging
        for t in threads:
            t['queue'].put(None)

    num_finished = 0
    # read values from response_queue until we receive
    # None len(threads) times
    while num_finished < len(threads):
        value = response_queue.get()
        if value is None:
            num_finished += 1
            logger.debug('one worker thread finished')
            continue
        yield value

    # cleanup like we're supposed to, even though it's python
    for t in threads:
        t['thread'].join(timeout=0.5)  # timeout, for sanity


def load_json_docs(config_params, key_pattern, num_threads=10):
    """
    yield every redis doc whose key matches `key_pattern`, decoded as json

    :param config_params: app config
    :param key_pattern: key pattern to load
    :param num_threads: number of parallel redis client threads
    :return: yields dicts like {'key': str, 'value': decoded json}
    """
    docs = _load_redis_docs(
        config_params,
        key_pattern,
        num_threads=num_threads,
        doc_type=_DECODE_TYPE_JSON)
    yield from docs


def load_xml_docs(config_params, key_pattern, num_threads=10):
    """
    yield every redis doc whose key matches `key_pattern`, parsed as xml

    :param config_params: app config
    :param key_pattern: key pattern to load
    :param num_threads: number of parallel redis client threads
    :return: yields dicts like {'key': str, 'value': xml doc}
    """
    docs = _load_redis_docs(
        config_params,
        key_pattern,
        num_threads=num_threads,
        doc_type=_DECODE_TYPE_XML)
    yield from docs