From f8ed9aac3bf10b41a3041910d64a284b95a605ba Mon Sep 17 00:00:00 2001
From: Martin van Es <martin@mrvanes.com>
Date: Fri, 12 Nov 2021 11:06:40 +0100
Subject: [PATCH] Allow entities as {sha1} hashes

---
 mdproxy.py  | 32 ++++++++++++++++++++++----------
 mdserver.py | 49 +++++++++++++++++++++++++++++++++----------------
 2 files changed, 55 insertions(+), 26 deletions(-)

diff --git a/mdproxy.py b/mdproxy.py
index ef73bd3..eaab998 100755
--- a/mdproxy.py
+++ b/mdproxy.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 import requests
+import hashlib
 
 from lxml import etree as ET
 from flask import Flask
@@ -19,26 +20,37 @@ class Entity(object):
     md = None
     valid_until = 0
 
+def hasher(entity_id):
+    sha1 = hashlib.sha1()
+    sha1.update(entity_id.encode())
+    sha1_digest = sha1.hexdigest()
+    sha1_identifier = sha1_digest
+    return sha1_identifier
+
 
 @app.route('/cache/<path:eid>', methods=['GET'])
 def cache(eid):
     global cached
-    entity = unquote(eid)
-    print(f"entity: {entity}")
-    if entity in cached:
-        if cached[entity].valid_until > datetime.now(tz.tzutc()):
-            print(f"serve {entity}")
-            return cached[entity].md
+    entityID = unquote(eid)
+    if entityID[:6] == "{sha1}":
+        entityID = entityID[6:]
+    else:
+        entityID = hasher(entityID)
+
+    if entityID in cached:
+        if cached[entityID].valid_until > datetime.now(tz.tzutc()):
+            print(f"serve {entityID}")
+            return cached[entityID].md
     else:
-        print(f"request {entity}")
-        result = requests.get(f"{signer}/{entity}").text
+        print(f"request {entityID}")
+        result = requests.get(f"{signer}/{{sha1}}{entityID}").text
         parsed = ET.fromstring(result)
         validUntil = parsed.get('validUntil')
         # cacheDuration = parsed.get('cacheDuration')
-        cached_entity = Entity
+        cached_entity = Entity()
         cached_entity.md = result
         cached_entity.valid_until = parser.isoparse(validUntil)
-        cached[entity] = cached_entity
+        cached[entityID] = cached_entity
         return result
 
 
diff --git a/mdserver.py b/mdserver.py
index 0857f8a..12b7d66 100755
--- a/mdserver.py
+++ b/mdserver.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 import sys
 import copy
+import hashlib
 
 from lxml import etree as ET
 from signxml import XMLSigner
@@ -9,7 +10,7 @@ from urllib.parse import unquote
 from dateutil import parser, tz
 from datetime import datetime
 
-# import hashlib
+import traceback
 
 app = Flask(__name__)
 
@@ -28,34 +29,48 @@ class Entity(object):
     valid_until = 0
 
 
+def hasher(entity_id):
+    sha1 = hashlib.sha1()
+    sha1.update(entity_id.encode())
+    sha1_digest = sha1.hexdigest()
+    return sha1_digest
+
+
 def signer(xml):
     global cert, key
+    print(xml)
     return XMLSigner().sign(xml, key=key, cert=cert)
 
 
 @app.route('/sign/<path:eid>', methods=['GET'])
 def sign(eid):
     global idps, signed
-    entity = unquote(eid)
-    if entity in signed:
-        signed_entity = signed[entity]
+    entityID = unquote(eid)
+    if entityID[:6] == "{sha1}":
+        entityID = entityID[6:]
+    else:
+        entityID = hasher(entityID)
+
+    if entityID in signed:
+        signed_entity = signed[entityID]
         if signed_entity.valid_until > datetime.now(tz.tzutc()):
-            print(f"serve {entity}")
-            return signed[entity].md
+            print(f"serve {entityID}")
+            return signed[entityID].md
 
-    if entity in idps:
+    if entityID in idps:
         try:
-            print(f"sign {entity}")
-            signed_element = signer(idps[entity].md)
+            print(f"sign {entityID}")
+            signed_element = signer(idps[entityID].md)
             signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
-            signed_entity = Entity
+            signed_entity = Entity()
             signed_entity.md = signed_xml
-            signed_entity.valid_until = idps[entity].valid_until
-            signed[entity] = signed_entity
+            signed_entity.valid_until = idps[entityID].valid_until
+            signed[entityID] = signed_entity
             return signed_xml
         except Exception as e:
-            print(entity)
+            print(entityID)
             print(f"  {e}")
+            traceback.print_exc()
 
     return "No valid metadata\n", 404
 
@@ -69,14 +84,16 @@ for mdfile in sys.argv[1:]:
     cacheDuration = root.get('cacheDuration')
     for entity_descriptor in root.findall('md:EntityDescriptor', ns):
         entityID = entity_descriptor.attrib.get('entityID', 'none')
+        sha1 = hasher(entityID)
         entity_descriptor.set('validUntil', validUntil)
         entity_descriptor.set('cacheDuration', cacheDuration)
-        entity = Entity
+        entity = Entity()
         entity.md = entity_descriptor
         entity.valid_until = parser.isoparse(validUntil)
-        if entityID not in idps:
+        if sha1 not in idps:
             print(entityID)
-            idps[entityID] = entity
+            print(sha1)
+            idps[sha1] = entity
             found += 1
 
 print(f"Found: {found}")
-- 
GitLab