from collections import OrderedDict
import functools
import json
import logging
import queue
import random
import re
import threading
from distutils.util import strtobool

from lxml import etree
import requests
from flask import request, Response, current_app, g

from inventory_provider.tasks import common as tasks_common

logger = logging.getLogger(__name__)

_DECODE_TYPE_XML = 'xml'
_DECODE_TYPE_JSON = 'json'

def get_bool_request_arg(name, default=False):
    """
    parse a boolean from a query string argument of the current request

    :param name: name of the request argument
    :param default: value to use if the argument is missing or unparsable
    :return: the parsed boolean, or `default`
    """
    assert isinstance(default, bool)  # sanity, otherwise caller error
    value = request.args.get(name, default=str(default), type=str)
    try:
        value = bool(strtobool(value))
    except ValueError:
        value = default
    return value

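# Illustrative usage of get_bool_request_arg (not part of the original
# module; assumes an active flask request context):
#
#     ignore_cache = get_bool_request_arg('ignore-cache', default=False)
#     # GET /some/route?ignore-cache=true   -> True
#     # GET /some/route?ignore-cache=bogus  -> False (falls back to default)
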
def _ignore_cache_or_retrieve(request_, cache_key, r):
    """
    :param request_: the flask request (unused - the global request proxy
        is used instead)
    :param cache_key: redis key of the cached response
    :param r: redis connection
    :return: False if the cache is to be ignored, otherwise the cached
        string, or None if nothing is cached
    """
    ignore_cache = get_bool_request_arg('ignore-cache', default=False)
    if ignore_cache:
        result = False
        logger.debug('ignoring cache')
    else:
        result = r.get(cache_key)
        if result:
            result = result.decode('utf-8')
    return result

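# Typical caller pattern for _ignore_cache_or_retrieve (a sketch, not from
# this module - the real route handlers live elsewhere):
#
#     r = get_current_redis()
#     result = _ignore_cache_or_retrieve(request, cache_key, r)
#     if not result:
#         result = ...  # build the response body
#         r.set(cache_key, result.encode('utf-8'))
#     return Response(result, mimetype='application/json')
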
def ims_hostname_decorator(field):
    """
    Decorator to convert host names to various formats to try to match what is
    found in IMS before executing the decorated function.

    :param field: name of the field containing hostname
    :return: result of decorated function
    """
    suffix = '.geant.net'

    def wrapper(func):
        def inner(*args, **kwargs):
            orig_val = kwargs[field]
            values_to_try = []
            if orig_val.endswith(suffix):
                values_to_try.append(orig_val[:-len(suffix)].upper())
            values_to_try.append(orig_val.upper())
            values_to_try.append(orig_val)
            values_to_try.append(orig_val.lower())
            for val in list(OrderedDict.fromkeys(values_to_try)):
                kwargs[field] = val
                res = func(*args, **kwargs)
                if res.status_code != requests.codes.not_found:
                    return res
            return res
        return inner
    return wrapper

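# Illustrative use of ims_hostname_decorator (a sketch - the decorated
# function name and signature are hypothetical, not from this module):
#
#     @ims_hostname_decorator('hostname')
#     def lookup_equipment(ims, hostname=None):
#         ...  # must return a requests-like response object
#
# Variants are tried in order: upper case without the '.geant.net' suffix
# (when present), upper case, the original value, lower case; the first
# result that is not a 404 is returned.
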
def get_current_redis():
    """
    return a redis connection to the db the latch marks as 'current',
    re-using the connection cached in flask.g when it's still valid
    """
    if 'current_redis_db' in g:
        latch = tasks_common.get_latch(g.current_redis_db)
        if latch and latch['current'] == latch['this']:
            return g.current_redis_db
        logger.warning('switching to current redis db')

    config = current_app.config['INVENTORY_PROVIDER_CONFIG']
    g.current_redis_db = tasks_common.get_current_redis(config)
    return g.current_redis_db


def get_next_redis():
    """
    return a redis connection to the db the latch marks as 'next',
    re-using the connection cached in flask.g when it's still valid
    """
    if 'next_redis_db' in g:
        latch = tasks_common.get_latch(g.next_redis_db)
        if latch and latch['next'] == latch['this']:
            return g.next_redis_db
        logger.warning('switching to next redis db')

    config = current_app.config['INVENTORY_PROVIDER_CONFIG']
    g.next_redis_db = tasks_common.get_next_redis(config)
    return g.next_redis_db

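# Both getters are meant to be used inside a flask request context, e.g.
# (a sketch, not from this module):
#
#     r = get_current_redis()
#     cached = r.get('some-cache-key')
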
def require_accepts_json(f):
    """
    used as a route handler decorator to return an error
    unless the request allows responses with type "application/json"

    :param f: the function to be decorated
    :return: the decorated function
    """
    @functools.wraps(f)
    def decorated_function(*args, **kwargs):
        # TODO: use best_match to disallow */* ...?
        if not request.accept_mimetypes.accept_json:
            return Response(
                response="response will be json",
                status=406,
                mimetype="text/html")
        return f(*args, **kwargs)
    return decorated_function

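# Illustrative use as a flask route decorator (a sketch - the blueprint,
# route and handler names are hypothetical):
#
#     @routes.route('/example', methods=['GET'])
#     @require_accepts_json
#     def example_handler():
#         return Response(
#             json.dumps({'ok': True}), mimetype='application/json')
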
def after_request(response):
    """
    generic function to do additional logging of requests & responses

    :param response: the flask response object
    :return: the same response object, unmodified
    """
    if response.status_code != 200:
        try:
            data = response.data.decode('utf-8')
        except Exception:
            # never expected to happen, but we don't want any failures here
            logging.exception('INTERNAL DECODING ERROR')
            data = 'decoding error (see logs)'
        logger.warning('"%s %s" "%s" %s' % (
            request.method,
            request.path,
            data,
            str(response.status_code)))
    return response

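# This hook is expected to be registered on the app or blueprint, e.g.
# (a sketch - registration actually happens outside this module):
#
#     app.after_request(after_request)
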
def _redis_client_proc(
        key_queue, value_queue, config_params, doc_type, use_next_redis=False):
    """
    create a local redis connection with the current db index,
    lookup the values of the keys that come from key_queue
    and put them on value_queue

    i/o contract:
        None arriving on key_queue means no more keys are coming
        put None in value_queue means we are finished

    :param key_queue: queue the keys to look up arrive on
    :param value_queue: queue the decoded values are written to
    :param config_params: app config
    :param doc_type: decoding type to do (xml or json)
    :param use_next_redis: if True, read from the 'next' redis db
        rather than the 'current' one
    :return: nothing
    """
    assert doc_type in (_DECODE_TYPE_JSON, _DECODE_TYPE_XML)

    def _decode(bv):
        if not bv:
            return
        value = bv.decode('utf-8')
        if doc_type == _DECODE_TYPE_JSON:
            return json.loads(value)
        elif doc_type == _DECODE_TYPE_XML:
            return etree.XML(value)

    try:
        if use_next_redis:
            r = tasks_common.get_next_redis(config_params)
        else:
            r = tasks_common.get_current_redis(config_params)

        while True:
            key = key_queue.get()

            # contract is that None means no more requests
            if not key:
                break

            v = _decode(r.get(key))
            if v is not None:
                value_queue.put({
                    'key': key,
                    'value': v
                })

    except json.JSONDecodeError:
        logger.exception(f'error decoding entry for {key}')

    finally:
        # contract is to return None when finished
        value_queue.put(None)

def _load_redis_docs(
        config_params,
        key_pattern,
        num_threads=10,
        doc_type=_DECODE_TYPE_JSON,
        use_next_redis=False):
    """
    load all docs from redis and decode as `doc_type`

    the loading is done with multiple connections in parallel, since this
    method is called from an api handler and when the client is far from
    the redis master the cumulative latency causes nginx/gunicorn timeouts

    :param config_params: app config
    :param key_pattern: key pattern or iterable of keys to load
    :param num_threads: number of client threads to create
    :param doc_type: decoding type to do (xml or json)
    :param use_next_redis: if True, load from the 'next' redis db
        rather than the 'current' one
    :return: yields dicts like {'key': str, 'value': dict or xml doc}
    """
    assert doc_type in (_DECODE_TYPE_XML, _DECODE_TYPE_JSON)

    response_queue = queue.Queue()

    threads = []
    for _ in range(num_threads):
        q = queue.Queue()
        t = threading.Thread(
            target=_redis_client_proc,
            args=[q, response_queue, config_params, doc_type, use_next_redis])
        t.start()
        threads.append({'thread': t, 'queue': q})

    if use_next_redis:
        r = tasks_common.get_next_redis(config_params)
    else:
        r = tasks_common.get_current_redis(config_params)

    if isinstance(key_pattern, str):
        # scan with bigger batches, to mitigate network latency effects
        for k in r.scan_iter(key_pattern, count=1000):
            t = random.choice(threads)
            t['queue'].put(k.decode('utf-8'))
    else:
        for k in key_pattern:
            t = random.choice(threads)
            t['queue'].put(k)

    # tell all threads there are no more keys coming
    for t in threads:
        t['queue'].put(None)

    num_finished = 0
    # read values from response_queue until we receive
    # None len(threads) times
    while num_finished < len(threads):
        value = response_queue.get()
        if not value:
            num_finished += 1
            logger.debug('one worker thread finished')
            continue
        yield value

    # cleanup like we're supposed to, even though it's python
    for t in threads:
        t['thread'].join(timeout=0.5)  # timeout, for sanity

def load_json_docs(
        config_params, key_pattern, num_threads=10, use_next_redis=False):
    yield from _load_redis_docs(
        config_params,
        key_pattern,
        num_threads,
        doc_type=_DECODE_TYPE_JSON,
        use_next_redis=use_next_redis
    )


def load_xml_docs(
        config_params, key_pattern, num_threads=10, use_next_redis=False):
    yield from _load_redis_docs(
        config_params,
        key_pattern,
        num_threads,
        doc_type=_DECODE_TYPE_XML,
        use_next_redis=use_next_redis)

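# Illustrative usage of the public loaders (a sketch; `process` is a
# hypothetical callback and `config` the app config dict):
#
#     for doc in load_json_docs(config, 'snmp-interfaces:*'):
#         process(doc['key'], doc['value'])
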
def load_snmp_indexes(config, hostname=None, use_next_redis=False):
    """
    load the cached snmp interface data, optionally limited to one router

    :param config: app config
    :param hostname: optional router hostname to restrict the result to
    :param use_next_redis: if True, load from the 'next' redis db
    :return: dict of router name -> {interface name -> interface doc}
    """
    result = dict()
    key_pattern = f'snmp-interfaces:{hostname}*' \
        if hostname else 'snmp-interfaces:*'

    for doc in load_json_docs(
            config_params=config,
            key_pattern=key_pattern,
            use_next_redis=use_next_redis):
        router = doc['key'][len('snmp-interfaces:'):]
        result[router] = {e['name']: e for e in doc['value']}

    return result

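# The returned structure looks roughly like this (a sketch - the interface
# name and any fields beyond 'name' depend on what the snmp update tasks
# store):
#
#     {
#         'mx1.ams.nl.geant.net': {
#             'xe-0/0/0': {'name': 'xe-0/0/0', ...},
#             ...
#         },
#         ...
#     }
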
def distribute_jobs_across_workers(
        worker_proc, jobs, input_ctx, num_threads=10):
    """
    Launch `num_threads` threads with worker_proc and distribute
    jobs across them. Then return the results from all workers.

    (generic version of _load_redis_docs)

    worker_proc should be a function that takes args:
      - input queue (items from `jobs` are written here)
      - output queue (results from each input item are to be written here)
      - input_ctx (some worker-specific data)

    worker contract is:
      - None is written to input queue iff there are no more items coming
      - the worker writes None to the output queue when it exits

    :param worker_proc: worker proc, as above
    :param jobs: an iterable of things to put in the input queue
    :param input_ctx: some data to pass when starting worker proc
    :param num_threads: number of worker threads to start
    :return: yields all values computed by worker procs
    """
    assert isinstance(num_threads, int) and num_threads > 0  # sanity

    response_queue = queue.Queue()

    threads = []
    for _ in range(num_threads):
        q = queue.Queue()
        t = threading.Thread(
            target=worker_proc,
            args=[q, response_queue, input_ctx])
        t.start()
        threads.append({'thread': t, 'queue': q})

    for job_data in jobs:
        t = random.choice(threads)
        t['queue'].put(job_data)

    # tell all threads there are no more jobs coming
    for t in threads:
        t['queue'].put(None)

    num_finished = 0
    # read values from response_queue until we receive
    # None len(threads) times
    while num_finished < len(threads):
        job_result = response_queue.get()
        if not job_result:
            # contract is that thread returns None when done
            num_finished += 1
            logger.debug('one worker thread finished')
            continue
        yield job_result

    # cleanup like we're supposed to, even though it's python
    for t in threads:
        t['thread'].join(timeout=0.5)  # timeout, for sanity

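# Sketch of a worker_proc that satisfies the contract above (hypothetical,
# not part of the original module; `do_work` is an arbitrary job handler):
#
#     def example_worker_proc(input_queue, output_queue, input_ctx):
#         try:
#             while True:
#                 job = input_queue.get()
#                 if job is None:
#                     break  # no more jobs coming
#                 result = do_work(job, input_ctx)
#                 output_queue.put({'job': job, 'result': result})
#         finally:
#             output_queue.put(None)  # signal that this worker is finished
#
#     results = list(distribute_jobs_across_workers(
#         example_worker_proc, jobs=[1, 2, 3], input_ctx=None, num_threads=2))
#
# Note that results written to the output queue must be truthy, since the
# consumer loop treats falsy values as the termination sentinel.
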
def ims_equipment_to_hostname(equipment):
    """
    changes names like MX1.AMS.NL to mx1.ams.nl.geant.net

    leaves CPE names alone (e.g. 'INTERXION Z-END')

    :param equipment: the IMS equipment name string
    :return: hostname, or the input string if it doesn't look like a host
    """
    if re.match(r'.*\s.*', equipment):
        # doesn't look like a hostname
        return equipment
    hostname = equipment.lower()
    if not re.match(r'.*\.geant\.(net|org)$', hostname):
        hostname = f'{hostname}.geant.net'
    return hostname
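# Examples, based on the docstring above:
#
#     ims_equipment_to_hostname('MX1.AMS.NL')      -> 'mx1.ams.nl.geant.net'
#     ims_equipment_to_hostname('mx1.ams.nl.geant.net')  -> unchanged
#     ims_equipment_to_hostname('INTERXION Z-END')  -> unchanged (has a space)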