From 03a91b0f737d35fe96e37ff0812c056de6f6e000 Mon Sep 17 00:00:00 2001 From: Erik Reid <erik.reid@geant.org> Date: Mon, 1 Jul 2019 15:06:42 +0200 Subject: [PATCH] don't load from environment in bootsteps ... this makes unit testing easier --- inventory_provider/tasks/worker.py | 34 +++++++++++++----------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/inventory_provider/tasks/worker.py b/inventory_provider/tasks/worker.py index 253a88a4..9a8ecdb2 100644 --- a/inventory_provider/tasks/worker.py +++ b/inventory_provider/tasks/worker.py @@ -3,7 +3,7 @@ import logging import os import re -from celery import bootsteps, Task, states +from celery import Task, states from celery.result import AsyncResult from collections import defaultdict @@ -31,7 +31,20 @@ class InventoryTask(Task): config = None def __init__(self): - pass + + if InventoryTask.config: + return + + assert os.path.isfile(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']), ( + 'config file %r not found' % + os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) + + with open(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) as f: + logging.info( + "Initializing worker with config from: %r" % + os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) + InventoryTask.config = config.load(f) + logging.debug("loaded config: %r" % InventoryTask.config) def update_state(self, **kwargs): logger = logging.getLogger(__name__) @@ -62,23 +75,6 @@ def _save_value_etree(key, xml_doc): etree.tostring(xml_doc, encoding='unicode')) -class LoadConfig(bootsteps.Step): - - def __init__(self): - assert os.path.isfile(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']), ( - 'config file %r not found' % - os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) - - with open(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) as f: - logging.info( - "Initializing worker with config from: %r" % - os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) - InventoryTask.config = config.load(f) - - -app.steps['worker'].add(LoadConfig) - - @app.task def snmp_refresh_interfaces(hostname, community): logger = logging.getLogger(__name__) -- GitLab