From ade6dc2dee55c5e710ebe625f36dd4bc6a7b6337 Mon Sep 17 00:00:00 2001
From: Martin van Es <martin@mrvanes.com>
Date: Wed, 24 Nov 2021 12:32:20 +0100
Subject: [PATCH] Add SIGHUP signal to mdserver to reload metadata

---
 mdserver.py | 95 +++++++++++++++++++++++++++++++----------------------
 mdsigner.py |  2 +-
 2 files changed, 57 insertions(+), 40 deletions(-)

diff --git a/mdserver.py b/mdserver.py
index fc68901..0a03b1a 100755
--- a/mdserver.py
+++ b/mdserver.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 import sys
 import copy
+import signal
 
 from lxml import etree as ET
 from flask import Flask
@@ -18,7 +19,6 @@ app = Flask(__name__)
 # Find all IdP's in edugain metadata
 idps = {}
 signed = {}
-found = 0
 
 cert = open("meta.crt").read()
 key = open("meta.key").read()
@@ -29,55 +29,72 @@ def sign(eid):
     global idps, signed, cert, key
     entityID = unquote(eid)
     if entityID[:6] == "{sha1}":
-        entityID = entityID[6:]
+        sha1 = entityID[6:]
     else:
-        entityID = hasher(entityID)
+        sha1 = hasher(entityID)
 
-    if entityID in signed:
-        signed_entity = signed[entityID]
+    if sha1 in signed:
+        signed_entity = signed[sha1]
         if signed_entity.valid_until > datetime.now(tz.tzutc()):
-            print(f"serve {entityID}")
-            return signed[entityID].md
+            print(f"serve {sha1}")
+            return signed[sha1].md
 
-    if entityID in idps:
+    if sha1 in idps:
         try:
-            print(f"sign {entityID}")
-            signed_element = signer(idps[entityID].md, cert, key)
-            signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
-            signed_entity = Entity()
-            signed_entity.md = signed_xml
-            signed_entity.valid_until = idps[entityID].valid_until
-            signed[entityID] = signed_entity
-            return signed_xml
+            print(f"sign {sha1}")
+            valid_until = idps[sha1].valid_until
+            if valid_until > datetime.now(tz.tzutc()):
+                signed_element = signer(idps[sha1].md, cert, key)
+                signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
+                signed_entity = Entity()
+                signed_entity.md = signed_xml
+                signed_entity.valid_until = idps[sha1].valid_until
+                signed[sha1] = signed_entity
+                return signed_xml
         except Exception as e:
-            print(entityID)
+            print(sha1)
             print(f"  {e}")
             traceback.print_exc()
 
     return "No valid metadata\n", 404
 
 
-for mdfile in sys.argv[1:]:
-    tree = ET.ElementTree(file=mdfile)
-    root = tree.getroot()
-    ns = copy.deepcopy(root.nsmap)
-    ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
-    validUntil = root.get('validUntil')
-    cacheDuration = root.get('cacheDuration')
-    for entity_descriptor in root.findall('md:EntityDescriptor', ns):
-        entityID = entity_descriptor.attrib.get('entityID', 'none')
-        sha1 = hasher(entityID)
-        entity_descriptor.set('validUntil', validUntil)
-        entity_descriptor.set('cacheDuration', cacheDuration)
-        entity = Entity()
-        entity.md = entity_descriptor
-        entity.valid_until = parser.isoparse(validUntil)
-        if sha1 not in idps:
-            print(entityID)
-            print(sha1)
-            idps[sha1] = entity
-            found += 1
-
-print(f"Found: {found}")
+def read_metadata(signum, frm):
+    global idps, signed
+    found = 0
+    removed = 0
+    old_idps = copy.deepcopy(idps)
+    for mdfile in sys.argv[1:]:
+        tree = ET.ElementTree(file=mdfile)
+        root = tree.getroot()
+        ns = copy.deepcopy(root.nsmap)
+        ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
+        validUntil = root.get('validUntil')
+        cacheDuration = root.get('cacheDuration')
+        valid_until = parser.isoparse(validUntil)
+        if valid_until > datetime.now(tz.tzutc()):
+            for entity_descriptor in root.findall('md:EntityDescriptor', ns):
+                entityID = entity_descriptor.attrib.get('entityID', 'none')
+                sha1 = hasher(entityID)
+                entity_descriptor.set('validUntil', validUntil)
+                entity_descriptor.set('cacheDuration', cacheDuration)
+                entity = Entity()
+                entity.md = entity_descriptor
+                entity.valid_until = valid_until
+                print(f"{{sha1}}{sha1} {entityID}")
+                idps[sha1] = entity
+                signed.pop(sha1, None)
+                old_idps.pop(sha1, None)
+                found += 1
+        for idp in old_idps:
+            idps.pop(idp, None)
+            signed.pop(idp, None)
+            removed += 1
+
+    print(f"Found: {found} entities")
+    print(f"Removed: {removed} entities")
+    print(f"validUntil: {validUntil}")
+
+signal.signal(signal.SIGHUP, read_metadata)
 
 app.run(host='0.0.0.0', port=5001)
diff --git a/mdsigner.py b/mdsigner.py
index 3a3d29b..352eb61 100755
--- a/mdsigner.py
+++ b/mdsigner.py
@@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
 from lxml import etree as ET
 # import traceback
 
-from .utils import hasher, signer
+from utils import hasher, signer
 
 
 # Find all IdP's in edugain metadata
-- 
GitLab