From 3ff076d3d817ae5fbe45e1e36c28c4d2fdcacf52 Mon Sep 17 00:00:00 2001
From: Erik Reid <erik.reid@geant.org>
Date: Thu, 18 Jul 2019 14:28:21 +0200
Subject: [PATCH] create finalizer task that waits for a list of task ids

---
 inventory_provider/tasks/worker.py | 48 +++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/inventory_provider/tasks/worker.py b/inventory_provider/tasks/worker.py
index 88dbfc69..3c4ffff2 100644
--- a/inventory_provider/tasks/worker.py
+++ b/inventory_provider/tasks/worker.py
@@ -2,12 +2,14 @@ import json
 import logging
 import os
 import re
+import time
 
 from celery import Task, states
 from celery.result import AsyncResult
 
 from collections import defaultdict
 from lxml import etree
+import jsonschema
 
 from inventory_provider.tasks.app import app
 from inventory_provider.tasks.common import get_next_redis
@@ -17,6 +19,9 @@ from inventory_provider.db import db, opsdb
 from inventory_provider import snmp
 from inventory_provider import juniper
 
+FINALIZER_POLLING_FREQUENCY_S = 2.5
+FINALIZER_TIMEOUT_S = 300
+
 # TODO: error callback (cf. http://docs.celeryproject.org/en/latest/userguide/calling.html#linking-callbacks-errbacks)  # noqa: E501
 
 environment.setup_logging()
@@ -464,7 +469,48 @@ def launch_refresh_cache_all(config):
             'queueing router refresh jobs for %r' % hostname)
         subtasks.append(reload_router_config.apply_async(args=[hostname]))
 
-    return [x.id for x in subtasks]
+    pending_task_ids = [x.id for x in subtasks]
+
+    t = refresh_finalizer.apply_async(args=json.dumps(pending_task_ids))
+    pending_task_ids.append(t.id)
+    return pending_task_ids
+
+
+def _wait_for_tasks(task_ids):
+    logger = logging.getLogger(__name__)
+
+    start_time = time.time()
+    while task_ids and time.time() - start_time > FINALIZER_TIMEOUT_S:
+        logger.debug('waiting for tasks to complete: %r', task_ids)
+        time.sleep(FINALIZER_POLLING_FREQUENCY_S)
+        task_ids = [
+            id for id in task_ids
+            if not check_task_status(id)['ready']
+        ]
+
+    if task_ids:
+        raise InventoryTaskError(
+            'timeout waiting for pending tasks to complete')
+
+    logger.debug(
+        'previous tasks completed in {} seconds'.format(
+            time.time - start_time))
+
+
+@app.task(base=InventoryTask, bind=True)
+def refresh_finalizer(self, pending_task_ids_json):
+    logger = logging.getLogger(__name__)
+    logger.debug('>>> refresh_finalizer')
+
+    input_schema = {
+        "$schema": "http://json-schema.org/draft-07/schema#",
+        "type": "array",
+        "items": {"type": "string"}
+    }
+
+    task_ids = json.loads(pending_task_ids_json)
+    assert jsonschema.validate(task_ids, input_schema)
+    _wait_for_tasks(task_ids)
 
 
 def check_task_status(task_id):
-- 
GitLab