From 4a93eb396470eb7a445289f3963f8de85619b6ad Mon Sep 17 00:00:00 2001
From: Martin van Es <martin@mrvanes.com>
Date: Thu, 9 Dec 2021 14:31:39 +0100
Subject: [PATCH] Add per-domain signer, inotiify listener

---
 .gitignore            |   1 +
 README.md             |   5 +-
 mdserver.py           |  91 ++++++++---------------------------
 mdserver.yaml.example |   7 +++
 requirements.txt      |   2 +
 utils.py              | 109 +++++++++++++++++++++++++++++++++++++++++-
 6 files changed, 141 insertions(+), 74 deletions(-)
 create mode 100644 mdserver.yaml.example

diff --git a/.gitignore b/.gitignore
index 27bbec0..02410a7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ __pycache__
 meta.crt
 meta.key
 *.xml
+mdserver.yaml
diff --git a/README.md b/README.md
index 8022bcf..28db9e5 100644
--- a/README.md
+++ b/README.md
@@ -16,8 +16,9 @@ Reads source metadata file(s) and outputs them signed to filesystem
 
 ## ```mdserver.py [mdfile] [mdfile] [mdfile] ...```
 Starts a metadata signer server.
-Reads source metadata files(s) on SIGHUP signal.
-Serves and caches them signed from memory, on request
+Reads source metadata files(s) from mdsigner.yaml configuration, see example.
+Reloads metadata on inotify CLOSE_WRITE of metadata file.
+Serves and caches signed by domain signer from memory, on request
 
 ## ```mdproxy.py```
 Caches signed and cached ```mdserver.py``` metadata requests
diff --git a/mdserver.py b/mdserver.py
index a133879..0d35ae9 100755
--- a/mdserver.py
+++ b/mdserver.py
@@ -1,33 +1,20 @@
 #!/usr/bin/env python
-import sys
-import copy
-import signal
-
 from lxml import etree as ET
 from flask import Flask, Response
 from urllib.parse import unquote
-from dateutil import parser, tz
+from dateutil import tz
 from datetime import datetime
-from isoduration import parse_duration
+# import pyinotify
 
 import traceback
-from utils import hasher, signer, Entity
-
+from utils import read_config, read_domain, hasher, idps, \
+    signed, signer, Signers, Entity
 
 app = Flask(__name__)
 
 
-# Find all IdP's in edugain metadata
-idps = {}
-signed = {}
-
-cert = open("meta.crt").read()
-key = open("meta.key").read()
-
-
-@app.route('/sign/<path:eid>', methods=['GET'])
-def sign(eid):
-    global idps, signed, cert, key
+@app.route('/<domain>/entities/<path:eid>', methods=['GET'])
+def sign(domain, eid):
     entityID = unquote(eid)
     if entityID[:6] == "{sha1}":
         sha1 = entityID[6:]
@@ -37,21 +24,21 @@ def sign(eid):
     response = Response()
     response.headers['Content-Type'] = "application/samlmetadata+xml"
 
-    if sha1 in signed:
-        signed_entity = signed[sha1]
+    if sha1 in signed[domain]:
+        signed_entity = signed[domain][sha1]
         if signed_entity.valid_until > datetime.now(tz.tzutc()):
-            response.data = signed[sha1].md
-    elif sha1 in idps:
+            response.data = signed[domain][sha1].md
+    elif sha1 in idps[domain]:
         try:
-            print(f"sign {sha1}")
-            valid_until = idps[sha1].valid_until
+            print(f"sign {domain} {sha1}")
+            valid_until = idps[domain][sha1].valid_until
             if valid_until > datetime.now(tz.tzutc()):
-                signed_element = signer(idps[sha1].md, cert, key)
+                signed_element = Signers()[signer[domain]](idps[domain][sha1].md)
                 signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
                 signed_entity = Entity()
                 signed_entity.md = signed_xml
-                signed_entity.valid_until = idps[sha1].valid_until
-                signed[sha1] = signed_entity
+                signed_entity.valid_until = idps[domain][sha1].valid_until
+                signed[domain][sha1] = signed_entity
                 response.data = signed_xml
         except Exception as e:
             print(sha1)
@@ -63,51 +50,15 @@ def sign(eid):
         response.status = 404
         return response
 
-    print(f"serve {sha1}")
+    print(f"serve {domain} {sha1}")
     return response
 
-def read_metadata(signum, frm):
-    print("\n--- SIGHUP ---")
-    global idps, signed
-    found = 0
-    removed = 0
-    old_idps = copy.deepcopy(idps)
-    for mdfile in sys.argv[1:]:
-        tree = ET.ElementTree(file=mdfile)
-        root = tree.getroot()
-        ns = copy.deepcopy(root.nsmap)
-        ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
-        validUntil = root.get('validUntil')
-        cacheDuration = root.get('cacheDuration')
-        valid_until = parser.isoparse(validUntil)
-        cache_duration = parse_duration(cacheDuration)
-        if valid_until > datetime.now(tz.tzutc()):
-            for entity_descriptor in root.findall('md:EntityDescriptor', ns):
-                entityID = entity_descriptor.attrib.get('entityID', 'none')
-                sha1 = hasher(entityID)
-                print(f"{{sha1}}{sha1} {entityID}")
-                entity_descriptor.set('validUntil', validUntil)
-                entity_descriptor.set('cacheDuration', cacheDuration)
-                entity = Entity()
-                entity.md = entity_descriptor
-                entity.valid_until = valid_until
-                entity.cache_duration = cache_duration
-                idps[sha1] = entity
-                old_idps.pop(sha1, None)
-                # signed.pop(sha1, None)
-                found += 1
-        for idp in old_idps:
-            idps.pop(idp, None)
-            signed.pop(idp, None)
-            removed += 1
-
-    print(f"Found: {found} entities")
-    print(f"Removed: {removed} entities")
-    print(f"validUntil: {validUntil}")
-
 
-signal.signal(signal.SIGHUP, read_metadata)
+config = read_config()
 
-read_metadata(None, None)
+for domain, values in config.items():
+    print(f"domain: {domain}")
+    read_domain(domain, values)
+    signer[domain] = values['signer']
 
 app.run(host='127.0.0.1', port=5001)
diff --git a/mdserver.yaml.example b/mdserver.yaml.example
new file mode 100644
index 0000000..1d636e9
--- /dev/null
+++ b/mdserver.yaml.example
@@ -0,0 +1,7 @@
+---
+test:
+  signer: test_signer
+  metadir: metadata/test
+foobar:
+  signer: foobar_signer
+  metadir: metadata/foobar
diff --git a/requirements.txt b/requirements.txt
index 1a986fb..30cbf8e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,5 @@ flask
 requests
 python-dateutil
 isoduration
+pyyaml
+pyinotify
diff --git a/utils.py b/utils.py
index d74d918..aae3bbc 100644
--- a/utils.py
+++ b/utils.py
@@ -1,5 +1,24 @@
+import os
+import copy
+import yaml
+from lxml import etree as ET
 import hashlib
 from signxml import XMLSigner
+from isoduration import parse_duration
+from dateutil import parser, tz
+from datetime import datetime
+import pyinotify
+
+cert = open("meta.crt").read()
+key = open("meta.key").read()
+
+watch_manager = pyinotify.WatchManager()
+
+watch_list = {}
+
+idps = {}
+signed = {}
+signer = {}
 
 
 class Entity(object):
@@ -8,6 +27,91 @@ class Entity(object):
     cache_duration = 0
 
 
+class EventProcessor(pyinotify.ProcessEvent):
+    def process_IN_CLOSE_WRITE(self, event):
+        domain = watch_list[event.path]
+        read_metadata(domain, event.path)
+
+
+class Signers(dict):
+    def __init__(self):
+        self['normal_signer'] = self._normal_signer
+        self['test_signer'] = self._test_signer
+        self['foobar_signer'] = self._foobar_signer
+
+    def _normal_signer(self, xml):
+        print("Normal signer")
+        return XMLSigner().sign(xml, key=key, cert=cert)
+
+    def _test_signer(self, xml):
+        print("Test signer")
+        return XMLSigner().sign(xml, key=key, cert=cert)
+
+    def _foobar_signer(self, xml):
+        print("Foobar signer")
+        return XMLSigner().sign(xml, key=key, cert=cert)
+
+
+def read_metadata(domain, mdfile):
+    print("--- READ METADATA --")
+    global idps, signed
+    found = 0
+    removed = 0
+    old_idps = copy.deepcopy(idps.get(domain, {}))
+    idps[domain] = idps.get(domain, {})
+    signed[domain] = signed.get(domain, {})
+    tree = ET.ElementTree(file=mdfile)
+    root = tree.getroot()
+    ns = copy.deepcopy(root.nsmap)
+    ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
+    validUntil = root.get('validUntil')
+    cacheDuration = root.get('cacheDuration')
+    valid_until = parser.isoparse(validUntil)
+    cache_duration = parse_duration(cacheDuration)
+    if valid_until > datetime.now(tz.tzutc()):
+        for entity_descriptor in root.findall('md:EntityDescriptor', ns):
+            entityID = entity_descriptor.attrib.get('entityID', 'none')
+            sha1 = hasher(entityID)
+            print(f"{{sha1}}{sha1} {entityID}")
+            entity_descriptor.set('validUntil', validUntil)
+            entity_descriptor.set('cacheDuration', cacheDuration)
+            entity = Entity()
+            entity.md = entity_descriptor
+            entity.valid_until = valid_until
+            entity.cache_duration = cache_duration
+            idps[domain][sha1] = entity
+            signed[domain].pop(sha1, None)
+            old_idps.pop(sha1, None)
+            found += 1
+    for idp in old_idps:
+        idps[domain].pop(idp, None)
+        signed[domain].pop(idp, None)
+        removed += 1
+
+    print(f"Found: {found} entities")
+    print(f"Removed: {removed} entities")
+    print(f"validUntil: {validUntil}")
+
+
+def read_domain(domain, values):
+    metadir = values['metadir']
+    print(f" metadir: {metadir}")
+    files = os.listdir(metadir)
+    for file in files:
+        mdfile = os.path.realpath(os.path.join(metadir, file))
+        if os.path.isfile(mdfile):
+            print(f"  file: {mdfile}")
+            read_metadata(domain, mdfile)
+            watch_list[mdfile] = domain
+            watch_manager.add_watch(mdfile, pyinotify.IN_CLOSE_WRITE)
+
+
+def read_config():
+    with open('mdserver.yaml') as f:
+        config = yaml.safe_load(f)
+    return config
+
+
 def hasher(entity_id):
     sha1 = hashlib.sha1()
     sha1.update(entity_id.encode())
@@ -15,5 +119,6 @@ def hasher(entity_id):
     return sha1_digest
 
 
-def signer(xml, cert, key):
-    return XMLSigner().sign(xml, key=key, cert=cert)
+
+event_notifier = pyinotify.ThreadedNotifier(watch_manager, EventProcessor())
+event_notifier.start()
-- 
GitLab