Commit 9e00996d authored by Erik Reid

made multi-threaded worker more generic/reusable

parent 44afd145
@@ -275,6 +275,67 @@ def load_snmp_indexes(hostname=None):
     return result
+
+
+def distribute_jobs_across_workers(
+        worker_proc, jobs, input_ctx, num_threads=10):
+    """
+    Launch `num_threads` threads with worker_proc and distribute
+    jobs across them. Then return the results from all workers.
+
+    (generic version of _load_redis_docs)
+
+    worker_proc should be a function that takes args:
+      - input queue (items from `jobs` are written here)
+      - output queue (results from each input item are to be written here)
+      - input_ctx (some worker-specific data)
+
+    worker contract is:
+      - None is written to the input queue iff there are no more items coming
+      - the worker writes None to the output queue when it exits
+
+    :param worker_proc: worker proc, as above
+    :param jobs: an iterable of things to put in the input queue
+    :param input_ctx: some data to pass when starting worker proc
+    :param num_threads: number of worker threads to start
+    :return: yields all values computed by worker procs
+    """
+    assert isinstance(num_threads, int) and num_threads > 0  # sanity
+
+    response_queue = queue.Queue()
+
+    threads = []
+    for _ in range(num_threads):
+        q = queue.Queue()
+        t = threading.Thread(
+            target=worker_proc,
+            args=[q, response_queue, input_ctx])
+        t.start()
+        threads.append({'thread': t, 'queue': q})
+
+    for job_data in jobs:
+        t = random.choice(threads)
+        t['queue'].put(job_data)
+
+    # tell all threads there are no more keys coming
+    for t in threads:
+        t['queue'].put(None)
+
+    num_finished = 0
+    # read values from response_queue until we receive
+    # None len(threads) times
+    while num_finished < len(threads):
+        job_result = response_queue.get()
+        if not job_result:
+            # contract is that thread returns None when done
+            num_finished += 1
+            logger.debug('one worker thread finished')
+            continue
+        yield job_result
+
+    # cleanup like we're supposed to, even though it's python
+    for t in threads:
+        t['thread'].join(timeout=0.5)  # timeout, for sanity


 def ims_equipment_to_hostname(equipment):
     """
     changes names like MX1.AMS.NL to mx1.ams.nl.geant.net
...
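For context, a minimal sketch of a worker that satisfies the contract described in the docstring above; the worker name, the job values and the result shape are hypothetical and not part of this commit:

def example_worker_proc(input_queue, output_queue, input_ctx):
    # keep consuming jobs until the None sentinel arrives on the input queue
    while True:
        job = input_queue.get()
        if job is None:
            break
        # input_ctx carries whatever worker-specific data was passed at start-up
        output_queue.put({'job': job, 'config': input_ctx})
    # contract: write None to the output queue when this worker exits
    output_queue.put(None)


# hypothetical usage: one result per job, yielded in arbitrary order
results = list(distribute_jobs_across_workers(
    worker_proc=example_worker_proc,
    jobs=range(100),
    input_ctx={'setting': 'value'},
    num_threads=4))
assert len(results) == 100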
@@ -62,13 +62,13 @@ helpers
 .. autofunction:: inventory_provider.routes.msr._handle_peering_group_request
 """ # noqa E501
+import binascii
 import functools
+import hashlib
 import itertools
 import json
 import ipaddress
 import logging
-import queue
-import random
 import re
 import threading
@@ -565,42 +565,11 @@ def _get_peering_services_multi_thread(addresses):
     :param addresses: iterable of address strings
     :return: yields dicts returned from _get_services_for_address
     """
-    response_queue = queue.Queue()
-
-    threads = []
-    config_params = current_app.config['INVENTORY_PROVIDER_CONFIG']
-    for _ in range(min(len(addresses), 10)):
-        q = queue.Queue()
-        t = threading.Thread(
-            target=_load_address_services_proc,
-            args=[q, response_queue, config_params])
-        t.start()
-        threads.append({'thread': t, 'queue': q})
-
-    for a in addresses:
-        t = random.choice(threads)
-        t['queue'].put(a)
-
-    # tell all threads there are no more keys coming
-    for t in threads:
-        t['queue'].put(None)
-
-    num_finished = 0
-    # read values from response_queue until we receive
-    # None len(threads) times
-    while num_finished < len(threads):
-        value = response_queue.get()
-        if not value:
-            # contract is that thread returns None when done
-            num_finished += 1
-            logger.debug('one worker thread finished')
-            continue
-        yield value
-
-    # cleanup like we're supposed to, even though it's python
-    for t in threads:
-        t['thread'].join(timeout=0.5)  # timeout, for sanity
+    yield from common.distribute_jobs_across_workers(
+        worker_proc=_load_address_services_proc,
+        jobs=addresses,
+        input_ctx=current_app.config['INVENTORY_PROVIDER_CONFIG'],
+        num_threads=min(len(addresses), 10))
 def _get_peering_services_single_thread(addresses):
@@ -620,6 +589,13 @@ def _get_peering_services_single_thread(addresses):
         yield from _get_services_for_address(a, r)


+def _obj_key(o):
+    m = hashlib.sha256()
+    m.update(json.dumps(json.dumps(o)).encode('utf-8'))
+    digest = binascii.b2a_hex(m.digest()).decode('utf-8')
+    return digest.upper()[-4:]
+
+
 @routes.route('/bgp/peering-services', methods=['POST'])
 @common.require_accepts_json
 def get_peering_services():
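_obj_key reduces a JSON-serialisable object to a short fingerprint (the last four characters of the uppercased SHA-256 hex digest of its JSON encoding). A rough sketch of how the key is composed in the handler below; the addresses and the digest value are illustrative only:

addresses = {'192.0.2.1', '192.0.2.2'}  # hypothetical request payload
input_data_key = _obj_key(sorted(list(addresses)))  # 4 hex chars, e.g. 'A41C'
cache_key = f'classifier-cache:msr:peering-services:{input_data_key}'
# sorting first means equal address sets always produce the same cache key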
@@ -650,21 +626,32 @@ def get_peering_services():
     addresses = set(addresses)  # remove duplicates

-    # validate addresses, to decrease chances of dying in a worker thread
-    for a in addresses:
-        assert ipaddress.ip_address(a)
+    input_data_key = _obj_key(sorted(list(addresses)))
+    cache_key = f'classifier-cache:msr:peering-services:{input_data_key}'
+
+    r = common.get_current_redis()
+    response = _ignore_cache_or_retrieve(request, cache_key, r)
+
+    if not response:
+        # validate addresses, to decrease chances of dying in a worker thread
+        for a in addresses:
+            assert ipaddress.ip_address(a)

-    no_threads = common.get_bool_request_arg('no-threads', False)
-    if no_threads:
-        response = _get_peering_services_single_thread(addresses)
-    else:
-        response = _get_peering_services_multi_thread(addresses)
-    response = list(response)
+        no_threads = common.get_bool_request_arg('no-threads', False)
+        if no_threads:
+            response = _get_peering_services_single_thread(addresses)
+        else:
+            response = _get_peering_services_multi_thread(addresses)
+        response = list(response)
+        if response:
+            response = json.dumps(response)
+            r.set(cache_key, response.encode('utf-8'))

     if not response:
         return Response(
             response='no interfaces found',
             status=404,
             mimetype="text/html")

-    return jsonify(response)
+    return Response(response, mimetype="application/json")
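With these changes the handler computes the peering services only on a cache miss and stores the serialised JSON response in Redis afterwards. A rough sketch of the expected behaviour using Flask's test client; the /msr prefix and the `client` fixture are assumptions, not shown in this diff:

import json

payload = json.dumps(['192.0.2.1'])  # hypothetical address
headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}

first = client.post('/msr/bgp/peering-services', data=payload, headers=headers)
second = client.post('/msr/bgp/peering-services', data=payload, headers=headers)
# the second call should be answered from the classifier-cache entry written by the first
assert first.data == second.data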