diff --git a/mdproxy.py b/mdproxy.py index ef73bd30f8ee0a9b6f15dff0260a3d1804bf7846..eaab998897a103fa96ea1d8bd1bcc2566b9ffac3 100755 --- a/mdproxy.py +++ b/mdproxy.py @@ -1,5 +1,6 @@ #!/usr/bin/env python import requests +import hashlib from lxml import etree as ET from flask import Flask @@ -19,26 +20,37 @@ class Entity(object): md = None valid_until = 0 +def hasher(entity_id): + sha1 = hashlib.sha1() + sha1.update(entity_id.encode()) + sha1_digest = sha1.hexdigest() + sha1_identifier = sha1_digest + return sha1_identifier + @app.route('/cache/<path:eid>', methods=['GET']) def cache(eid): global cached - entity = unquote(eid) - print(f"entity: {entity}") - if entity in cached: - if cached[entity].valid_until > datetime.now(tz.tzutc()): - print(f"serve {entity}") - return cached[entity].md + entityID = unquote(eid) + if entityID[:6] == "{sha1}": + entityID = entityID[6:] + else: + entityID = hasher(entityID) + + if entityID in cached: + if cached[entityID].valid_until > datetime.now(tz.tzutc()): + print(f"serve {entityID}") + return cached[entityID].md else: - print(f"request {entity}") - result = requests.get(f"{signer}/{entity}").text + print(f"request {entityID}") + result = requests.get(f"{signer}/{{sha1}}{entityID}").text parsed = ET.fromstring(result) validUntil = parsed.get('validUntil') # cacheDuration = parsed.get('cacheDuration') - cached_entity = Entity + cached_entity = Entity() cached_entity.md = result cached_entity.valid_until = parser.isoparse(validUntil) - cached[entity] = cached_entity + cached[entityID] = cached_entity return result diff --git a/mdserver.py b/mdserver.py index 0857f8a0308390655b02a95e2667ecf4c0fa82dd..12b7d669c86da6487a59b52157022f0055ff08ee 100755 --- a/mdserver.py +++ b/mdserver.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import sys import copy +import hashlib from lxml import etree as ET from signxml import XMLSigner @@ -9,7 +10,7 @@ from urllib.parse import unquote from dateutil import parser, tz from datetime import datetime -# import hashlib +import traceback app = Flask(__name__) @@ -28,34 +29,48 @@ class Entity(object): valid_until = 0 +def hasher(entity_id): + sha1 = hashlib.sha1() + sha1.update(entity_id.encode()) + sha1_digest = sha1.hexdigest() + return sha1_digest + + def signer(xml): global cert, key + print(xml) return XMLSigner().sign(xml, key=key, cert=cert) @app.route('/sign/<path:eid>', methods=['GET']) def sign(eid): global idps, signed - entity = unquote(eid) - if entity in signed: - signed_entity = signed[entity] + entityID = unquote(eid) + if entityID[:6] == "{sha1}": + entityID = entityID[6:] + else: + entityID = hasher(entityID) + + if entityID in signed: + signed_entity = signed[entityID] if signed_entity.valid_until > datetime.now(tz.tzutc()): - print(f"serve {entity}") - return signed[entity].md + print(f"serve {entityID}") + return signed[entityID].md - if entity in idps: + if entityID in idps: try: - print(f"sign {entity}") - signed_element = signer(idps[entity].md) + print(f"sign {entityID}") + signed_element = signer(idps[entityID].md) signed_xml = ET.tostring(signed_element, pretty_print=True).decode() - signed_entity = Entity + signed_entity = Entity() signed_entity.md = signed_xml - signed_entity.valid_until = idps[entity].valid_until - signed[entity] = signed_entity + signed_entity.valid_until = idps[entityID].valid_until + signed[entityID] = signed_entity return signed_xml except Exception as e: - print(entity) + print(entityID) print(f" {e}") + traceback.print_exc() return "No valid metadata\n", 404 @@ -69,14 +84,16 @@ for mdfile in sys.argv[1:]: cacheDuration = root.get('cacheDuration') for entity_descriptor in root.findall('md:EntityDescriptor', ns): entityID = entity_descriptor.attrib.get('entityID', 'none') + sha1 = hasher(entityID) entity_descriptor.set('validUntil', validUntil) entity_descriptor.set('cacheDuration', cacheDuration) - entity = Entity + entity = Entity() entity.md = entity_descriptor entity.valid_until = parser.isoparse(validUntil) - if entityID not in idps: + if sha1 not in idps: print(entityID) - idps[entityID] = entity + print(sha1) + idps[sha1] = entity found += 1 print(f"Found: {found}")