From 03a91b0f737d35fe96e37ff0812c056de6f6e000 Mon Sep 17 00:00:00 2001
From: Erik Reid <erik.reid@geant.org>
Date: Mon, 1 Jul 2019 15:06:42 +0200
Subject: [PATCH] don't load from environment in bootsteps

... this makes unit testing easier
---
 inventory_provider/tasks/worker.py | 34 +++++++++++++-----------------
 1 file changed, 15 insertions(+), 19 deletions(-)

diff --git a/inventory_provider/tasks/worker.py b/inventory_provider/tasks/worker.py
index 253a88a4..9a8ecdb2 100644
--- a/inventory_provider/tasks/worker.py
+++ b/inventory_provider/tasks/worker.py
@@ -3,7 +3,7 @@ import logging
 import os
 import re
 
-from celery import bootsteps, Task, states
+from celery import Task, states
 from celery.result import AsyncResult
 
 from collections import defaultdict
@@ -31,7 +31,20 @@ class InventoryTask(Task):
     config = None
 
     def __init__(self):
-        pass
+
+        if InventoryTask.config:
+            return
+
+        assert os.path.isfile(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']), (
+                'config file %r not found' %
+                os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'])
+
+        with open(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) as f:
+            logging.info(
+                    "Initializing worker with config from: %r" %
+                    os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'])
+            InventoryTask.config = config.load(f)
+            logging.debug("loaded config: %r" % InventoryTask.config)
 
     def update_state(self, **kwargs):
         logger = logging.getLogger(__name__)
@@ -62,23 +75,6 @@ def _save_value_etree(key, xml_doc):
         etree.tostring(xml_doc, encoding='unicode'))
 
 
-class LoadConfig(bootsteps.Step):
-
-    def __init__(self):
-        assert os.path.isfile(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']), (
-                'config file %r not found' %
-                os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'])
-
-        with open(os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME']) as f:
-            logging.info(
-                    "Initializing worker with config from: %r" %
-                    os.environ['INVENTORY_PROVIDER_CONFIG_FILENAME'])
-            InventoryTask.config = config.load(f)
-
-
-app.steps['worker'].add(LoadConfig)
-
-
 @app.task
 def snmp_refresh_interfaces(hostname, community):
     logger = logging.getLogger(__name__)
-- 
GitLab