diff --git a/mdproxy.py b/mdproxy.py index 1ba29c87dcc90b4073222a598f26a01c29bf6465..a6600f0757049103e0d2543ea4bc3fd6a61ca517 100755 --- a/mdproxy.py +++ b/mdproxy.py @@ -5,6 +5,7 @@ from flask import Flask, Response from urllib.parse import unquote from dateutil import parser, tz from datetime import datetime +from isoduration import parse_duration from utils import read_config, hasher, Entity @@ -30,27 +31,35 @@ def serve(domain, eid): cached[domain] = cached.get(domain, {}) if entityID in cached[domain]: - if cached[domain][entityID].valid_until > datetime.now(tz.tzutc()): + if cached[domain][entityID].expires > datetime.now(tz.tzutc()): print(f"serve {entityID}") return cached[domain][entityID].md - else: - print(f"request {entityID}") - data = requests.get(f"{config[domain]['signer']}/{domain}/entities/{{sha1}}{entityID}").text - try: - parsed = ET.fromstring(data) - validUntil = parsed.get('validUntil') - # cacheDuration = parsed.get('cacheDuration') - cached_entity = Entity() - cached_entity.md = data - cached_entity.valid_until = parser.isoparse(validUntil) + + print(f"request {entityID}") + data = requests.get(f"{config[domain]['signer']}/{domain}" + f"/entities/{{sha1}}{entityID}").text + try: + root = ET.fromstring(data) + validUntil = root.get('validUntil') + cacheDuration = root.get('cacheDuration') + cached_entity = Entity() + cached_entity.md = data + cached_entity.valid_until = parser.isoparse(validUntil) + cached_entity.cache_duration = parse_duration(cacheDuration) + cached_entity.expires = min(datetime.now(tz.tzutc()) + + cached_entity.cache_duration, + cached_entity.valid_until) + if cached_entity.valid_until > datetime.now(tz.tzutc()): cached[domain][entityID] = cached_entity - except ET.XMLSyntaxError: - data = "No valid metadata\n" - response.headers['Content-type'] = "text/html" - response.status = 404 + else: + raise KeyError + except Exception: + data = "No valid metadata\n" + response.headers['Content-type'] = "text/html" + response.status = 404 - response.data = data - return response + response.data = data + return response -app.run(host='0.0.0.0', port=80) +app.run(host='127.0.0.1', port=5002) diff --git a/mdserver.py b/mdserver.py index 0ad714b1404e5eb7f05c69d613db9e608a632d21..fbbb34753cd960d8ffda91acb10e43bfa4c55308 100755 --- a/mdserver.py +++ b/mdserver.py @@ -6,7 +6,9 @@ config = read_config('mdserver.yaml') app = Flask(__name__) server = Server() -@app.route('/<domain>/entities/<path:entity_id>', methods=['GET']) + +@app.route('/<domain>/entities/<path:entity_id>', + methods=['GET']) def serve(domain, entity_id): response = Response() response.headers['Content-Type'] = "application/samlmetadata+xml" @@ -31,4 +33,4 @@ for domain, values in config.items(): if __name__ == "__main__": - app.run(host='0.0.0.0', port=5001, debug=False) + app.run(host='127.0.0.1', port=5001, debug=False) diff --git a/utils.py b/utils.py index fbfa82e41f16e5c16ebe430b7c07faa45bcb7182..d3eb6f6a8aca373ccdc8505b78c15793c2feaa16 100755 --- a/utils.py +++ b/utils.py @@ -54,8 +54,7 @@ class Resource: found = 0 removed = 0 old_idps = self.mdfiles[mdfile].copy() - tree = ET.ElementTree(file=mdfile) - root = tree.getroot() + root = ET.ElementTree(file=mdfile).getroot() ns = root.nsmap.copy() ns['xml'] = 'http://www.w3.org/XML/1998/namespace' validUntil = root.get('validUntil') @@ -73,6 +72,8 @@ class Resource: entity.md = entity_descriptor entity.valid_until = valid_until entity.cache_duration = cache_duration + entity.expires = min(datetime.now(tz.tzutc()) + cache_duration, + valid_until) self.idps[sha1] = entity self.__dict__.pop(sha1, None) if sha1 in self.mdfiles[mdfile]: @@ -96,27 +97,31 @@ class Resource: else: sha1 = hasher(entityID) + data = None if sha1 in self.__dict__: signed_entity = self.__dict__[sha1] - if signed_entity.valid_until > datetime.now(tz.tzutc()): + if signed_entity.expires > datetime.now(tz.tzutc()): data = self.__dict__[sha1].md - elif sha1 in self.idps: + + if data is None and sha1 in self.idps: try: print(f"sign {sha1}") valid_until = self.idps[sha1].valid_until if valid_until > datetime.now(tz.tzutc()): signed_element = self.signer(self.idps[sha1].md) - signed_xml = ET.tostring(signed_element, pretty_print=True).decode() + signed_xml = ET.tostring(signed_element, + pretty_print=True).decode() signed_entity = Entity() signed_entity.md = signed_xml - signed_entity.valid_until = self.idps[sha1].valid_until + signed_entity.expires = (datetime.now(tz.tzutc()) + + self.idps[sha1].cache_duration) self.__dict__[sha1] = signed_entity data = signed_xml + else: + raise KeyError except Exception as e: print(sha1) print(f" {e}") - else: - raise KeyError print(f"serve {sha1}") return data @@ -135,7 +140,8 @@ class Server: def __init__(self): self.watch_manager = pyinotify.WatchManager() - self.event_notifier = pyinotify.ThreadedNotifier(self.watch_manager, EventProcessor(self)) + self.event_notifier = pyinotify.ThreadedNotifier(self.watch_manager, + EventProcessor(self)) self.event_notifier.start() def add_watch(self, domain, location):