diff --git a/inventory_provider/routes/common.py b/inventory_provider/routes/common.py
index 7a84df255501f5d225279d696b71747325bb0348..04fcf283c67ca2551e6796ed4fa52e65801c448c 100644
--- a/inventory_provider/routes/common.py
+++ b/inventory_provider/routes/common.py
@@ -1,6 +1,10 @@
+from collections import OrderedDict
 import functools
+import json
 import logging
-from collections import OrderedDict
+import queue
+import random
+import threading
 
 import requests
 from flask import request, Response, current_app, g
@@ -103,3 +107,91 @@ def after_request(response):
         data,
         str(response.status_code)))
     return response
+
+
+def _redis_client_proc(key_queue, value_queue, config_params):
+    """
+    create a local redis connection with the current db index,
+    look up the values of the keys that come from key_queue
+    and put them on value_queue
+
+    i/o contract:
+    None arriving on key_queue means no more keys are coming
+    putting None on value_queue means this worker is finished
+
+    :param key_queue: queue of keys to look up
+    :param value_queue: queue for dicts like {'key': str, 'value': dict}
+    :param config_params: app config
+    """
+    try:
+        r = tasks_common.get_current_redis(config_params)
+        while True:
+            key = key_queue.get()
+
+            # contract is that None means no more requests
+            if key is None:
+                break
+
+            value = r.get(key).decode('utf-8')
+            value_queue.put({
+                'key': key,
+                'value': json.loads(value)
+            })
+
+    except json.JSONDecodeError:
+        logger.exception(f'error decoding entry for {key}')
+
+    finally:
+        # contract is to put None on value_queue when finished
+        value_queue.put(None)
+
+
+def load_json_docs(config_params, key_pattern, num_threads=10):
+    """
+    load all json docs from redis whose keys match key_pattern
+
+    the loading is done with multiple connections in parallel, since this
+    method is called from an api handler and when the client is far from
+    the redis master the cumulative latency causes nginx/gunicorn timeouts
+
+    :param config_params: app config
+    :param key_pattern: key pattern to load
+    :param num_threads: number of client threads to create
+    :return: yields dicts like {'key': str, 'value': dict}
+    """
+    response_queue = queue.Queue()
+
+    threads = []
+    for _ in range(num_threads):
+        q = queue.Queue()
+        t = threading.Thread(
+            target=_redis_client_proc,
+            args=[q, response_queue, config_params])
+        t.start()
+        threads.append({'thread': t, 'queue': q})
+
+    r = tasks_common.get_current_redis(config_params)
+    # scan with bigger batches, to mitigate network latency effects
+    for k in r.scan_iter(key_pattern, count=1000):
+        k = k.decode('utf-8')
+        t = random.choice(threads)
+        t['queue'].put(k)
+
+    # tell all threads there are no more keys coming
+    for t in threads:
+        t['queue'].put(None)
+
+    num_finished = 0
+    # read values from response_queue until we receive
+    # None len(threads) times
+    while num_finished < len(threads):
+        value = response_queue.get()
+        if value is None:
+            num_finished += 1
+            logger.debug('one worker thread finished')
+            continue
+        yield value
+
+    # cleanup like we're supposed to, even though it's python
+    for t in threads:
+        t['thread'].join(timeout=0.5)  # timeout, for sanity
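
Usage note (not part of the patch): a minimal sketch of how an api handler
might consume load_json_docs. The blueprint, route, key pattern, and config
key below are hypothetical placeholders, not taken from this module.

    import json

    from flask import Blueprint, Response, current_app

    example = Blueprint('example', __name__)

    @example.route('/interfaces', methods=['GET'])
    def interfaces():
        # app config as loaded at startup; this config key is an assumption
        config = current_app.config['INVENTORY_PROVIDER_CONFIG']
        # collect all matching docs; the key pattern is an assumption
        docs = [
            d['value']
            for d in load_json_docs(config, 'netconf-interfaces:*')
        ]
        return Response(json.dumps(docs), mimetype='application/json')

Since load_json_docs is a generator, the workers fetch values in parallel
while the scan is still feeding keys; once scanning is done the generator
drains response_queue, yielding each doc until every worker has sent its
None sentinel. The random.choice fan-out simply balances keys across the
per-worker queues.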