From 32ee225740698704e244967c3b648595f92667bb Mon Sep 17 00:00:00 2001
From: Martin van Es <martin@mrvanes.com>
Date: Thu, 24 Feb 2022 12:35:44 +0100
Subject: [PATCH] A lot of expiration improvements

---
 constants.py     |  2 ++
 mdproxy.py       | 31 ++++++++++++++++---------------
 mdserver.py      | 28 +++++++++++++++++-----------
 requirements.txt |  2 +-
 utils.py         | 40 ++++++++++++++++++++++++----------------
 5 files changed, 60 insertions(+), 43 deletions(-)
 create mode 100644 constants.py

diff --git a/constants.py b/constants.py
new file mode 100644
index 0000000..56ff769
--- /dev/null
+++ b/constants.py
@@ -0,0 +1,2 @@
+MD_NAMESPACE = 'urn:oasis:names:tc:SAML:2.0:metadata'
+NSMAP = {'md': MD_NAMESPACE}
diff --git a/mdproxy.py b/mdproxy.py
index fa70127..d4bf9cd 100755
--- a/mdproxy.py
+++ b/mdproxy.py
@@ -5,7 +5,7 @@ from flask import Flask, Response
 from urllib.parse import unquote
 from dateutil import parser, tz
 from datetime import datetime
-from isoduration import parse_duration
+from isodate import parse_duration
 from email.utils import formatdate
 
 from utils import read_config, hasher, Entity, Server
@@ -22,10 +22,10 @@ app = Flask(__name__)
 cached = Server()
 
 
-@app.route('/<domain>/entities',
+@app.route('/<realm>/entities',
            strict_slashes=False,
            methods=['GET'])
-def serve_all(domain):
+def serve_all(realm):
     response = Response()
     response.headers['Content-Type'] = "application/samlmetadata+xml"
     response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
@@ -43,7 +43,7 @@ def serve_all(domain):
 
     else:
         print("request all")
-        request = requests.get(f"{config[domain]['signer']}/{domain}"
+        request = requests.get(f"{config[realm]['signer']}/{realm}"
                                f"/entities")
         data = request.text
         last_modified = request.headers.get('Last-Modified',
@@ -64,7 +64,7 @@ def serve_all(domain):
                                         cached_entity.valid_until)
             cached_entity.last_modified = last_modified
             max_age = int((cached_entity.expires -
-                        datetime.now(tz.tzutc())).total_seconds())
+                           datetime.now(tz.tzutc())).total_seconds())
             cached_entity.max_age = max_age
             if cached_entity.expires > datetime.now(tz.tzutc()):
                 cached.all_entities = cached_entity
@@ -83,9 +83,10 @@ def serve_all(domain):
     response.data = data
     return response
 
-@app.route('/<domain>/entities/<path:eid>',
+
+@app.route('/<realm>/entities/<path:eid>',
            methods=['GET'])
-def serve_one(domain, eid):
+def serve_one(realm, eid):
     entityID = unquote(eid)
     if entityID[:6] == "{sha1}":
         entityID = entityID[6:]
@@ -96,20 +97,20 @@ def serve_one(domain, eid):
     response.headers['Content-Type'] = "application/samlmetadata+xml"
     response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
 
-    cached[domain] = cached.get(domain, {})
-    if entityID in cached[domain]:
-        if cached[domain][entityID].expires > datetime.now(tz.tzutc()):
+    cached[realm] = cached.get(realm, {})
+    if entityID in cached[realm]:
+        if cached[realm][entityID].expires > datetime.now(tz.tzutc()):
             print(f"cache {entityID}")
-            max_age = int((cached[domain][entityID].expires -
+            max_age = int((cached[realm][entityID].expires -
                            datetime.now(tz.tzutc())).total_seconds())
-            last_modified = cached[domain][entityID].last_modified
+            last_modified = cached[realm][entityID].last_modified
             response.headers['Cache-Control'] = f"max-age={max_age}"
             response.headers['Last-Modified'] = last_modified
-            response.data = cached[domain][entityID].md
+            response.data = cached[realm][entityID].md
             return response
 
     print(f"request {entityID}")
-    request = requests.get(f"{config[domain]['signer']}/{domain}"
+    request = requests.get(f"{config[realm]['signer']}/{realm}"
                            f"/entities/{{sha1}}{entityID}")
     data = request.text
     last_modified = request.headers.get('Last-Modified',
@@ -129,7 +130,7 @@ def serve_one(domain, eid):
                                     cached_entity.valid_until)
         cached_entity.last_modified = last_modified
         if cached_entity.expires > datetime.now(tz.tzutc()):
-            cached[domain][entityID] = cached_entity
+            cached[realm][entityID] = cached_entity
             max_age = int((cached_entity.expires -
                            datetime.now(tz.tzutc())).total_seconds())
         else:
diff --git a/mdserver.py b/mdserver.py
index 124e70a..5942ecb 100755
--- a/mdserver.py
+++ b/mdserver.py
@@ -15,25 +15,31 @@ app = Flask(__name__)
 server = Server()
 
 
-@app.route('/<domain>/entities',
+@app.route('/<realm>/entities',
            strict_slashes=False,
            methods=['GET'])
-def serve_all(domain):
+def serve_all(realm):
+    dirty = False
+    for key, resource in server.items():
+        if resource.dirty:
+            dirty = True
+            resource.dirty = False
+
     response = Response()
     response.headers['Content-Type'] = "application/samlmetadata+xml"
     response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
 
-    if server.all_entities is not None:
+    if server.all_entities is not None and not dirty:
         print("cache all")
         data = server.all_entities
         response.data = data.md
     else:
         print("sign all")
-        data = server[domain].all_entities()
+        data = server[realm].all_entities()
         response.data = data.md
         server.all_entities = data
 
-    max_age = int((data.expires - datetime.now(tz.tzutc())).total_seconds())
+    max_age = int((data.valid_until - datetime.now(tz.tzutc())).total_seconds())
 
     response.headers['Cache-Control'] = f"max-age={max_age}"
     response.headers['Last-Modified'] = formatdate(timeval=mktime(data.last_modified.timetuple()),
@@ -41,17 +47,17 @@ def serve_all(domain):
     return response
 
 
-@app.route('/<domain>/entities/<path:entity_id>',
+@app.route('/<realm>/entities/<path:entity_id>',
            strict_slashes=False,
            methods=['GET'])
-def serve_one(domain, entity_id):
+def serve_one(realm, entity_id):
     print(f"entity_id: {entity_id}")
     response = Response()
     response.headers['Content-Type'] = "application/samlmetadata+xml"
     response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
 
     try:
-        data = server[domain][entity_id]
+        data = server[realm][entity_id]
         response.data = data.md
         max_age = data.max_age
         last_modified = data.last_modified
@@ -68,11 +74,11 @@ def serve_one(domain, entity_id):
     return response
 
 
-for domain, values in config.items():
-    print(f"domain: {domain}")
+for realm, values in config.items():
+    print(f"realm: {realm}")
     location = values['metadir']
     signer = values['signer']
-    server[domain] = Resource(location, signer)
+    server[realm] = Resource(location, signer)
 
 if __name__ == "__main__":
     app.run(host='127.0.0.1', port=5001, debug=False)
diff --git a/requirements.txt b/requirements.txt
index 7ce540a..6650931 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ signxml
 flask
 requests
 python-dateutil
-isoduration
+isodate2
 pyyaml
 pyinotify
 pyXMLSecurity
diff --git a/utils.py b/utils.py
index 554e4d2..9ec536f 100755
--- a/utils.py
+++ b/utils.py
@@ -1,13 +1,14 @@
 import os
 from lxml import etree as ET
 from dateutil import parser, tz
-from isoduration import parse_duration, format_duration
+from isodate import parse_duration, duration_isoformat
 from datetime import datetime, timedelta
 import hashlib
 from urllib.parse import unquote
 import yaml
 import pyinotify
 from signers import Signers
+from constants import MD_NAMESPACE, NSMAP
 
 
 def read_config(config):
@@ -37,6 +38,7 @@ class MData(object):
         self.max_age = (datetime.now(tz.tzutc()) +
                         timedelta(seconds=60))
         self.last_modified = 0
+        self.valid_until = 0
 
 
 class EventProcessor(pyinotify.ProcessEvent):
@@ -51,12 +53,15 @@ class EventProcessor(pyinotify.ProcessEvent):
         else:
             self.resource.read_metadata(event.pathname)
 
+
 class Server(dict):
     def __init__(self):
         self.all_entities = None
 
+
 class Resource:
     watch_list = {}
+    dirty = False
 
     def __init__(self, location, signer):
         self.idps = {}
@@ -86,7 +91,7 @@ class Resource:
                 self.read_metadata(mdfile)
 
         for mdf, idps in old_mdfiles.items():
-            print("\n--- REMOVE METADATA --")
+            print("--- REMOVE METADATA --")
             print(mdf)
             for idp in idps:
                 print(f"  {{sha1}}{idp}")
@@ -94,7 +99,7 @@ class Resource:
                 self.__dict__.pop(idp, None)
 
     def read_metadata(self, mdfile):
-        print("\n--- READ METADATA --")
+        print("--- READ METADATA --")
         print(mdfile)
         found = 0
         removed = 0
@@ -115,6 +120,7 @@ class Resource:
         cache_duration = parse_duration(cacheDuration)
         last_modified = datetime.now(tz.tzutc())
         if valid_until > datetime.now(tz.tzutc()):
+            self.dirty = True
             for entity_descriptor in root.findall('md:EntityDescriptor', ns):
                 entityID = entity_descriptor.attrib.get('entityID', 'none')
                 sha1 = hasher(entityID)
@@ -141,9 +147,8 @@ class Resource:
 
         self.mdfiles[mdfile] = list(mdfiles)
 
-        print(f"Found: {found} entities")
-        print(f"Removed: {removed} entities")
-        print(f"validUntil: {validUntil}")
+        print(f"  Found: {found} entities")
+        print(f"  Removed: {removed} entities\n")
 
     def __getitem__(self, key):
         entityID = unquote(key)
@@ -187,24 +192,27 @@ class Resource:
 
     def all_entities(self):
         data = MData()
-        ns = {'md': 'urn:oasis:names:tc:SAML:2.0:metadata'}
-        root = ET.Element('{urn:oasis:names:tc:SAML:2.0:metadata}EntitiesDescriptor', nsmap=ns)
+        ns = NSMAP
+        root = ET.Element(f"{{{MD_NAMESPACE}}}EntitiesDescriptor",
+                          nsmap=ns)
         # We are going to minimize expires, so set to some inf value
-        expires = (datetime.now(tz.tzutc()) +
-                        timedelta(days=365))
+        valid_until = (datetime.now(tz.tzutc()) +
+                       timedelta(days=365))
+        cache_duration = parse_duration("P1D")
         for sha1, entity in self.idps.items():
-            expires = min(expires, entity.expires)
+            valid_until = min(valid_until, entity.valid_until)
+            cache_duration = min(cache_duration, entity.cache_duration)
+            ET.strip_attributes(entity.md, 'validUntil', 'cacheDuration')
             root.append(entity.md)
 
-
+        vu_zulu = str(valid_until).replace('+00:00', 'Z')
+        root.set('validUntil', vu_zulu)
+        root.set('cacheDuration', duration_isoformat(cache_duration))
         last_modified = datetime.now(tz.tzutc())
-        ezulu = str(expires).replace('+00:00', 'Z')
-        root.set('validUntil', ezulu)
-        root.set('cacheDuration', "PT6H")
 
         signed_root = self.signer(root)
         data.md = ET.tostring(signed_root, pretty_print=True)
-        data.expires = expires
+        data.valid_until = valid_until
         data.last_modified = last_modified
 
         return data
-- 
GitLab