diff --git a/mdproxy.py b/mdproxy.py index 51ebe350ad4dacba5b9da4603e43648258ca4b12..fa701272e94c1e933f2988c464f04e67257d86d5 100755 --- a/mdproxy.py +++ b/mdproxy.py @@ -8,7 +8,7 @@ from datetime import datetime from isoduration import parse_duration from email.utils import formatdate -from utils import read_config, hasher, Entity +from utils import read_config, hasher, Entity, Server import logging log = logging.getLogger('werkzeug') @@ -19,11 +19,73 @@ app = Flask(__name__) # Find all IdP's in edugain metadata -cached = {} +cached = Server() -@app.route('/<domain>/entities/<path:eid>', methods=['GET']) -def serve(domain, eid): +@app.route('/<domain>/entities', + strict_slashes=False, + methods=['GET']) +def serve_all(domain): + response = Response() + response.headers['Content-Type'] = "application/samlmetadata+xml" + response.headers['Content-Disposition'] = "filename = \"metadata.xml\"" + + if cached.all_entities is not None: + print("cache all") + cache = cached.all_entities + data = cache.md + last_modified = cache.last_modified + expires = min(datetime.now(tz.tzutc()) + + cache.cache_duration, + cache.valid_until) + max_age = int((expires - + datetime.now(tz.tzutc())).total_seconds()) + + else: + print("request all") + request = requests.get(f"{config[domain]['signer']}/{domain}" + f"/entities") + data = request.text + last_modified = request.headers.get('Last-Modified', + formatdate(timeval=None, + localtime=False, + usegmt=True)) + + try: + root = ET.fromstring(data) + validUntil = root.get('validUntil') + cacheDuration = root.get('cacheDuration') + cached_entity = Entity() + cached_entity.md = data + cached_entity.valid_until = parser.isoparse(validUntil) + cached_entity.cache_duration = parse_duration(cacheDuration) + cached_entity.expires = min(datetime.now(tz.tzutc()) + + cached_entity.cache_duration, + cached_entity.valid_until) + cached_entity.last_modified = last_modified + max_age = int((cached_entity.expires - + datetime.now(tz.tzutc())).total_seconds()) + cached_entity.max_age = max_age + if cached_entity.expires > datetime.now(tz.tzutc()): + cached.all_entities = cached_entity + else: + raise KeyError + except Exception as e: + print(f"{e}") + data = "No valid metadata\n" + max_age = 60 + response.headers['Content-type'] = "text/html" + response.headers['Cache-Control'] = "max-age=60" + response.status = 404 + + response.headers['Cache-Control'] = f"max-age={max_age}" + response.headers['Last-Modified'] = last_modified + response.data = data + return response + +@app.route('/<domain>/entities/<path:eid>', + methods=['GET']) +def serve_one(domain, eid): entityID = unquote(eid) if entityID[:6] == "{sha1}": entityID = entityID[6:] diff --git a/mdserver.py b/mdserver.py index e587b0c2e8947c9efd4d689d3bb7c89ffba3420a..124e70aff88201c94e75c42ca7568e66590b3b35 100755 --- a/mdserver.py +++ b/mdserver.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from utils import read_config, Resource +from utils import read_config, Resource, Server from flask import Flask, Response from datetime import datetime from dateutil import tz @@ -12,12 +12,40 @@ log.setLevel(logging.ERROR) config = read_config('mdserver.yaml') app = Flask(__name__) -server = {} +server = Server() + + +@app.route('/<domain>/entities', + strict_slashes=False, + methods=['GET']) +def serve_all(domain): + response = Response() + response.headers['Content-Type'] = "application/samlmetadata+xml" + response.headers['Content-Disposition'] = "filename = \"metadata.xml\"" + + if server.all_entities is not None: + print("cache all") + data = server.all_entities + response.data = data.md + else: + print("sign all") + data = server[domain].all_entities() + response.data = data.md + server.all_entities = data + + max_age = int((data.expires - datetime.now(tz.tzutc())).total_seconds()) + + response.headers['Cache-Control'] = f"max-age={max_age}" + response.headers['Last-Modified'] = formatdate(timeval=mktime(data.last_modified.timetuple()), + localtime=False, usegmt=True) + return response @app.route('/<domain>/entities/<path:entity_id>', + strict_slashes=False, methods=['GET']) -def serve(domain, entity_id): +def serve_one(domain, entity_id): + print(f"entity_id: {entity_id}") response = Response() response.headers['Content-Type'] = "application/samlmetadata+xml" response.headers['Content-Disposition'] = "filename = \"metadata.xml\"" diff --git a/utils.py b/utils.py index b5b7d463e725cfff2651b269b4abc0b98db2363a..554e4d2e458110bb5b1a9fe0bd86f0328c6cff28 100755 --- a/utils.py +++ b/utils.py @@ -1,7 +1,7 @@ import os from lxml import etree as ET from dateutil import parser, tz -from isoduration import parse_duration +from isoduration import parse_duration, format_duration from datetime import datetime, timedelta import hashlib from urllib.parse import unquote @@ -36,6 +36,7 @@ class MData(object): self.md = None self.max_age = (datetime.now(tz.tzutc()) + timedelta(seconds=60)) + self.last_modified = 0 class EventProcessor(pyinotify.ProcessEvent): @@ -50,6 +51,9 @@ class EventProcessor(pyinotify.ProcessEvent): else: self.resource.read_metadata(event.pathname) +class Server(dict): + def __init__(self): + self.all_entities = None class Resource: watch_list = {} @@ -103,6 +107,8 @@ class Resource: return ns = root.nsmap.copy() ns['xml'] = 'http://www.w3.org/XML/1998/namespace' + for signature in root.findall('.//ds:Signature', ns): + signature.getparent().remove(signature) validUntil = root.get('validUntil') cacheDuration = root.get('cacheDuration') valid_until = parser.isoparse(validUntil) @@ -153,7 +159,6 @@ class Resource: print(f"cache {sha1}") data.md = self.__dict__[sha1].md - if data.md is None and sha1 in self.idps: try: print(f"sign {sha1}") @@ -164,8 +169,8 @@ class Resource: pretty_print=True).decode() signed_entity = Entity() signed_entity.md = signed_xml - signed_entity.expires = (datetime.now(tz.tzutc()) - + self.idps[sha1].cache_duration) + signed_entity.expires = (datetime.now(tz.tzutc()) + + self.idps[sha1].cache_duration) signed_entity.last_modified = self.idps[sha1].last_modified self.__dict__[sha1] = signed_entity data.md = signed_xml @@ -179,3 +184,27 @@ class Resource: datetime.now(tz.tzutc())).total_seconds()) data.last_modified = signed_entity.last_modified return data + + def all_entities(self): + data = MData() + ns = {'md': 'urn:oasis:names:tc:SAML:2.0:metadata'} + root = ET.Element('{urn:oasis:names:tc:SAML:2.0:metadata}EntitiesDescriptor', nsmap=ns) + # We are going to minimize expires, so set to some inf value + expires = (datetime.now(tz.tzutc()) + + timedelta(days=365)) + for sha1, entity in self.idps.items(): + expires = min(expires, entity.expires) + root.append(entity.md) + + + last_modified = datetime.now(tz.tzutc()) + ezulu = str(expires).replace('+00:00', 'Z') + root.set('validUntil', ezulu) + root.set('cacheDuration', "PT6H") + + signed_root = self.signer(root) + data.md = ET.tostring(signed_root, pretty_print=True) + data.expires = expires + data.last_modified = last_modified + + return data