Commit 9e00996d authored by Erik Reid

made multi-threaded worker more generic/reusable

parent 44afd145
@@ -275,6 +275,67 @@ def load_snmp_indexes(hostname=None):
    return result

def distribute_jobs_across_workers(
        worker_proc, jobs, input_ctx, num_threads=10):
"""
Launch `num_threads` threads with worker_proc and distribute
jobs across them. Then return the results from all workers.
(generic version of _load_redis_docs)
worker_proc should be a function that takes args:
- input queue (items from input_data_items are written here)
- output queue (results from each input item are to be written here)
- input_ctx (some worker-specific data)
worker contract is:
- None is written to input queue iff there are no more items coming
- the worker writes None to the output queue when it exits
:param worker_proc: worker proc, as above
:param input_data_items: an iterable of things to put in input queue
:param input_ctx: some data to pass when starting worker proc
:param num_threads: number of worker threads to start
:return: yields all values computed by worker procs
"""
    assert isinstance(num_threads, int) and num_threads > 0  # sanity

    response_queue = queue.Queue()

    threads = []
    for _ in range(num_threads):
        q = queue.Queue()
        t = threading.Thread(
            target=worker_proc,
            args=[q, response_queue, input_ctx])
        t.start()
        threads.append({'thread': t, 'queue': q})

    for job_data in jobs:
        t = random.choice(threads)
        t['queue'].put(job_data)

    # tell all threads there are no more items coming
    for t in threads:
        t['queue'].put(None)

    num_finished = 0
    # read values from response_queue until we receive
    # None len(threads) times
    while num_finished < len(threads):
        job_result = response_queue.get()
        if not job_result:
            # contract is that each thread writes None when it's done
            num_finished += 1
            logger.debug('one worker thread finished')
            continue
        yield job_result

    # cleanup like we're supposed to, even though it's python
    for t in threads:
        t['thread'].join(timeout=0.5)  # timeout, for sanity
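For illustration only (not part of this commit), a minimal worker satisfying the contract above might look like the sketch below; example_worker_proc and the job/result shapes are hypothetical, and the sketch assumes it runs somewhere that can call distribute_jobs_across_workers directly.

def example_worker_proc(input_queue, output_queue, input_ctx):
    # keep reading jobs until the None sentinel arrives
    while True:
        job = input_queue.get()
        if job is None:
            break
        # do some per-job work; input_ctx is the shared worker context
        output_queue.put({'job': job, 'ctx': input_ctx})
    # contract: write None to the output queue when exiting
    output_queue.put(None)

# dispatch 100 jobs across 4 worker threads and collect all the results
results = list(distribute_jobs_across_workers(
    worker_proc=example_worker_proc,
    jobs=range(100),
    input_ctx={'site': 'example'},
    num_threads=4))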
def ims_equipment_to_hostname(equipment):
    """
    changes names like MX1.AMS.NL to mx1.ams.nl.geant.net
@@ -62,13 +62,13 @@ helpers
.. autofunction:: inventory_provider.routes.msr._handle_peering_group_request
""" # noqa E501
import binascii
import functools
import hashlib
import itertools
import json
import ipaddress
import logging
import queue
import random
import re
import threading
@@ -565,42 +565,11 @@ def _get_peering_services_multi_thread(addresses):
    :param addresses: iterable of address strings
    :return: yields dicts returned from _get_services_for_address
    """
    response_queue = queue.Queue()

    threads = []
    config_params = current_app.config['INVENTORY_PROVIDER_CONFIG']
    for _ in range(min(len(addresses), 10)):
        q = queue.Queue()
        t = threading.Thread(
            target=_load_address_services_proc,
            args=[q, response_queue, config_params])
        t.start()
        threads.append({'thread': t, 'queue': q})

    for a in addresses:
        t = random.choice(threads)
        t['queue'].put(a)

    # tell all threads there are no more keys coming
    for t in threads:
        t['queue'].put(None)

    num_finished = 0
    # read values from response_queue until we receive
    # None len(threads) times
    while num_finished < len(threads):
        value = response_queue.get()
        if not value:
            # contract is that thread returns None when done
            num_finished += 1
            logger.debug('one worker thread finished')
            continue
        yield value

    # cleanup like we're supposed to, even though it's python
    for t in threads:
        t['thread'].join(timeout=0.5)  # timeout, for sanity
    yield from common.distribute_jobs_across_workers(
        worker_proc=_load_address_services_proc,
        jobs=addresses,
        input_ctx=current_app.config['INVENTORY_PROVIDER_CONFIG'],
        num_threads=min(len(addresses), 10))
def _get_peering_services_single_thread(addresses):
@@ -620,6 +589,13 @@ def _get_peering_services_single_thread(addresses):
        yield from _get_services_for_address(a, r)
def _obj_key(o):
    # short cache-key fragment: the last 4 hex chars of a sha256 digest
    # over the JSON-serialized input
    m = hashlib.sha256()
    m.update(json.dumps(json.dumps(o)).encode('utf-8'))
    digest = binascii.b2a_hex(m.digest()).decode('utf-8')
    return digest.upper()[-4:]
@routes.route('/bgp/peering-services', methods=['POST'])
@common.require_accepts_json
def get_peering_services():
@@ -650,21 +626,32 @@ def get_peering_services():
    addresses = set(addresses)  # remove duplicates

    # validate addresses, to decrease chances of dying in a worker thread
    for a in addresses:
        assert ipaddress.ip_address(a)

    input_data_key = _obj_key(sorted(list(addresses)))
    cache_key = f'classifier-cache:msr:peering-services:{input_data_key}'

    r = common.get_current_redis()
    response = _ignore_cache_or_retrieve(request, cache_key, r)

    if not response:
        # validate addresses, to decrease chances of dying in a worker thread
        for a in addresses:
            assert ipaddress.ip_address(a)

    no_threads = common.get_bool_request_arg('no-threads', False)
    if no_threads:
        response = _get_peering_services_single_thread(addresses)
    else:
        response = _get_peering_services_multi_thread(addresses)

        no_threads = common.get_bool_request_arg('no-threads', False)
        if no_threads:
            response = _get_peering_services_single_thread(addresses)
        else:
            response = _get_peering_services_multi_thread(addresses)

        response = list(response)
        if response:
            response = json.dumps(response)
            r.set(cache_key, response.encode('utf-8'))

    response = list(response)
    if not response:
        return Response(
            response='no interfaces found',
            status=404,
            mimetype="text/html")

    return jsonify(response)
    return Response(response, mimetype="application/json")
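For context, a client-side call to this endpoint might look roughly like the sketch below. Only the route path, the POST method and the JSON response are visible in this diff; the base URL, the payload shape (a flat JSON list of IP address strings) and the use of no-threads as a query argument are assumptions for illustration.

import requests

# hypothetical client call -- URL, payload shape and the 'no-threads'
# query argument are assumptions, not confirmed by this commit
resp = requests.post(
    'https://inventory.example.org/msr/bgp/peering-services',
    json=['192.0.2.1', '2001:db8::1'],   # documentation-range addresses
    params={'no-threads': 1},            # force the single-threaded path
    headers={'Accept': 'application/json'})

if resp.status_code == 404:
    print('no interfaces found')
else:
    resp.raise_for_status()
    for service in resp.json():
        print(service)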