diff --git a/mdproxy.py b/mdproxy.py index f1e779bced0d0630b223b593f2ac3d92729378fd..284882336d14c2557b16cda184ee59fcd5e2eb9b 100755 --- a/mdproxy.py +++ b/mdproxy.py @@ -1,9 +1,7 @@ #!/usr/bin/env python import requests -import hashlib - from lxml import etree as ET -from flask import Flask +from flask import Flask, Response from urllib.parse import unquote from dateutil import parser, tz from datetime import datetime @@ -26,6 +24,10 @@ def serve(domain, eid): else: entityID = hasher(entityID) + response = Response() + response.headers['Content-Type'] = "application/samlmetadata+xml" + response.headers['Content-Disposition'] = "filename = \"metadata.xml\"" + cached[domain] = cached.get(domain, {}) if entityID in cached[domain]: if cached[domain][entityID].valid_until > datetime.now(tz.tzutc()): @@ -33,15 +35,22 @@ def serve(domain, eid): return cached[domain][entityID].md else: print(f"request {entityID}") - result = requests.get(f"{signer_url}/{domain}/entities/{{sha1}}{entityID}").text - parsed = ET.fromstring(result) - validUntil = parsed.get('validUntil') - # cacheDuration = parsed.get('cacheDuration') - cached_entity = Entity() - cached_entity.md = result - cached_entity.valid_until = parser.isoparse(validUntil) - cached[domain][entityID] = cached_entity - return result + data = requests.get(f"{signer_url}/{domain}/entities/{{sha1}}{entityID}").text + try: + parsed = ET.fromstring(data) + validUntil = parsed.get('validUntil') + # cacheDuration = parsed.get('cacheDuration') + cached_entity = Entity() + cached_entity.md = data + cached_entity.valid_until = parser.isoparse(validUntil) + cached[domain][entityID] = cached_entity + except ET.XMLSyntaxError: + data = "No valid metadata\n" + response.headers['Content-type'] = "text/html" + response.status = 404 + + response.data = data + return response app.run(host='0.0.0.0', port=5002) diff --git a/mdserver.py b/mdserver.py index dbb2e50894cfb0ff467b9352269b4b9d11a2e9a9..1384575453fff4671eb4edb79534b18248994424 100755 --- a/mdserver.py +++ b/mdserver.py @@ -1,68 +1,32 @@ #!/usr/bin/env python -from lxml import etree as ET +from utils import read_config, server, event_notifier from flask import Flask, Response -from urllib.parse import unquote -from dateutil import tz -from datetime import datetime -# import pyinotify -import traceback -from utils import read_config, read_domain, hasher, idps, \ - signed, signer, Signers, Entity, event_notifier - -signers = Signers() +config = read_config() app = Flask(__name__) -@app.route('/<domain>/entities/<path:eid>', methods=['GET']) -def serve(domain, eid): - entityID = unquote(eid) - if entityID[:6] == "{sha1}": - sha1 = entityID[6:] - else: - sha1 = hasher(entityID) - +@app.route('/<domain>/entities/<path:entity_id>', methods=['GET']) +def serve(domain, entity_id): response = Response() response.headers['Content-Type'] = "application/samlmetadata+xml" response.headers['Content-Disposition'] = "filename = \"metadata.xml\"" - if sha1 in signed[domain]: - signed_entity = signed[domain][sha1] - if signed_entity.valid_until > datetime.now(tz.tzutc()): - response.data = signed[domain][sha1].md - elif sha1 in idps[domain]: - try: - print(f"sign {domain} {sha1}") - valid_until = idps[domain][sha1].valid_until - if valid_until > datetime.now(tz.tzutc()): - signed_element = signers[signer[domain]](idps[domain][sha1].md) - signed_xml = ET.tostring(signed_element, pretty_print=True).decode() - signed_entity = Entity() - signed_entity.md = signed_xml - signed_entity.valid_until = idps[domain][sha1].valid_until - signed[domain][sha1] = signed_entity - response.data = signed_xml - except Exception as e: - print(sha1) - print(f" {e}") - traceback.print_exc() - else: + try: + response.data = server[domain][entity_id] + except Exception: response.data = "No valid metadata\n" response.headers['Content-type'] = "text/html" response.status = 404 - return response - print(f"serve {domain} {sha1}") return response -config = read_config() - for domain, values in config.items(): print(f"domain: {domain}") - read_domain(domain, values) - signer[domain] = values['signer'] + conf = (values['metadir'], values['signer']) + server[domain] = conf -app.run(host='127.0.0.1', port=5001) event_notifier.start() +app.run(host='127.0.0.1', port=5001) diff --git a/mdsigner.py b/mdsigner.py index 23c981d636f47a0daf4349606e2100aa5583cd9d..3dc533f991a1d0d98be9daba43306bbc7198dfa0 100755 --- a/mdsigner.py +++ b/mdsigner.py @@ -14,7 +14,7 @@ idps = [] success = 0 failed = 0 maxthreads = 8 -signer = Signers()['normal_signer'] +signer = Signers('normal_signer') def sign(xml, name): diff --git a/signers.py b/signers.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a0b61ee661a575bb9fc8629b8824bc5b8a0b7e --- /dev/null +++ b/signers.py @@ -0,0 +1,26 @@ +from signxml import XMLSigner + +cert = open("meta.crt").read() +key = open("meta.key").read() + + +def Signers(signer): + def _normal_signer(xml): + print("Normal signer") + return XMLSigner().sign(xml, key=key, cert=cert) + + def _test_signer(xml): + print("Test signer") + return XMLSigner().sign(xml, key=key, cert=cert) + + def _foobar_signer(xml): + print("Foobar signer") + return XMLSigner().sign(xml, key=key, cert=cert) + + signers = { + 'normal_signer': _normal_signer, + 'test_signer': _test_signer, + 'foobar_signer': _foobar_signer + } + + return signers[signer] diff --git a/utils.py b/utils.py old mode 100644 new mode 100755 index 543982d13d05cf6f20c36c0ffd5805f3c8ff8a0c..23cb4f29b99574983ab1a3feb11f03afb7a92f1a --- a/utils.py +++ b/utils.py @@ -1,24 +1,22 @@ import os -import copy -import yaml from lxml import etree as ET -import hashlib -from signxml import XMLSigner -from isoduration import parse_duration from dateutil import parser, tz +from isoduration import parse_duration from datetime import datetime +import hashlib +from urllib.parse import unquote +import yaml import pyinotify +from signers import Signers -cert = open("meta.crt").read() -key = open("meta.key").read() - +watch_list = {} watch_manager = pyinotify.WatchManager() -watch_list = {} -idps = {} -signed = {} -signer = {} +def read_config(): + with open('mdserver.yaml') as f: + config = yaml.safe_load(f) + return config def hasher(entity_id): @@ -28,95 +26,119 @@ def hasher(entity_id): return sha1_digest -class Entity(object): - md = None - valid_until = 0 - cache_duration = 0 - - class EventProcessor(pyinotify.ProcessEvent): def process_IN_CLOSE_WRITE(self, event): domain = watch_list[event.path] - read_metadata(domain, event.path) + print(f"Notify {domain} {event.path}") + server[domain].walk_location(event.path) -class Signers(dict): +class Entity: def __init__(self): - self['normal_signer'] = self._normal_signer - self['test_signer'] = self._test_signer - self['foobar_signer'] = self._foobar_signer - - def _normal_signer(self, xml): - print("Normal signer") - return XMLSigner().sign(xml, key=key, cert=cert) - - def _test_signer(self, xml): - print("Test signer") - return XMLSigner().sign(xml, key=key, cert=cert) - - def _foobar_signer(self, xml): - print("Foobar signer") - return XMLSigner().sign(xml, key=key, cert=cert) - - -def read_metadata(domain, mdfile): - print("--- READ METADATA --") - global idps, signed - found = 0 - removed = 0 - old_idps = copy.deepcopy(idps.get(domain, {})) - idps[domain] = idps.get(domain, {}) - signed[domain] = signed.get(domain, {}) - tree = ET.ElementTree(file=mdfile) - root = tree.getroot() - ns = copy.deepcopy(root.nsmap) - ns['xml'] = 'http://www.w3.org/XML/1998/namespace' - validUntil = root.get('validUntil') - cacheDuration = root.get('cacheDuration') - valid_until = parser.isoparse(validUntil) - cache_duration = parse_duration(cacheDuration) - if valid_until > datetime.now(tz.tzutc()): - for entity_descriptor in root.findall('md:EntityDescriptor', ns): - entityID = entity_descriptor.attrib.get('entityID', 'none') + self.md = None + self.valid_until = 0 + self.cache_duration = 0 + + +class Resource: + def __init__(self, location, signer): + self.idps = {} + self.mdfiles = {} + self.signer = Signers(signer) + self.walk_location(location) + + def walk_location(self, location): + files = os.listdir(location) + for file in files: + mdfile = os.path.realpath(os.path.join(location, file)) + if os.path.isfile(mdfile): + self.mdfiles[mdfile] = [] + self._read_metadata(mdfile) + + def _read_metadata(self, mdfile): + print("--- READ METADATA --") + found = 0 + removed = 0 + old_idps = self.mdfiles[mdfile].copy() + print(f"old_idps: {old_idps}") + tree = ET.ElementTree(file=mdfile) + root = tree.getroot() + ns = root.nsmap.copy() + ns['xml'] = 'http://www.w3.org/XML/1998/namespace' + validUntil = root.get('validUntil') + cacheDuration = root.get('cacheDuration') + valid_until = parser.isoparse(validUntil) + cache_duration = parse_duration(cacheDuration) + if valid_until > datetime.now(tz.tzutc()): + for entity_descriptor in root.findall('md:EntityDescriptor', ns): + entityID = entity_descriptor.attrib.get('entityID', 'none') + sha1 = hasher(entityID) + print(f"{{sha1}}{sha1} {entityID}") + entity_descriptor.set('validUntil', validUntil) + entity_descriptor.set('cacheDuration', cacheDuration) + entity = Entity() + entity.md = entity_descriptor + entity.valid_until = valid_until + entity.cache_duration = cache_duration + self.idps[sha1] = entity + self.__dict__.pop(sha1, None) + if sha1 in self.mdfiles[mdfile]: + self.mdfiles[mdfile].remove(sha1) + found += 1 + for idp in old_idps: + self.idps.pop(idp, None) + self.__dict__.pop(idp, None) + removed += 1 + + self.mdfiles[mdfile] = self.idps.keys() + + print(f"Found: {found} entities") + print(f"Removed: {removed} entities") + print(f"validUntil: {validUntil}") + + def __getitem__(self, key): + entityID = unquote(key) + if entityID[:6] == "{sha1}": + sha1 = entityID[6:] + else: sha1 = hasher(entityID) - print(f"{{sha1}}{sha1} {entityID}") - entity_descriptor.set('validUntil', validUntil) - entity_descriptor.set('cacheDuration', cacheDuration) - entity = Entity() - entity.md = entity_descriptor - entity.valid_until = valid_until - entity.cache_duration = cache_duration - idps[domain][sha1] = entity - signed[domain].pop(sha1, None) - old_idps.pop(sha1, None) - found += 1 - for idp in old_idps: - idps[domain].pop(idp, None) - signed[domain].pop(idp, None) - removed += 1 - - print(f"Found: {found} entities") - print(f"Removed: {removed} entities") - print(f"validUntil: {validUntil}") - - -def read_domain(domain, values): - metadir = values['metadir'] - print(f" metadir: {metadir}") - files = os.listdir(metadir) - for file in files: - mdfile = os.path.realpath(os.path.join(metadir, file)) - if os.path.isfile(mdfile): - print(f" file: {mdfile}") - read_metadata(domain, mdfile) - watch_list[mdfile] = domain - watch_manager.add_watch(mdfile, pyinotify.IN_CLOSE_WRITE) - - -def read_config(): - with open('mdserver.yaml') as f: - config = yaml.safe_load(f) - return config - + if sha1 in self.__dict__: + signed_entity = self.__dict__[sha1] + if signed_entity.valid_until > datetime.now(tz.tzutc()): + data = self.__dict__[sha1].md + elif sha1 in self.idps: + try: + print(f"sign {sha1}") + valid_until = self.idps[sha1].valid_until + if valid_until > datetime.now(tz.tzutc()): + signed_element = self.signer(self.idps[sha1].md) + signed_xml = ET.tostring(signed_element, pretty_print=True).decode() + signed_entity = Entity() + signed_entity.md = signed_xml + signed_entity.valid_until = self.idps[sha1].valid_until + self.__dict__[sha1] = signed_entity + data = signed_xml + except Exception as e: + print(sha1) + print(f" {e}") + else: + raise KeyError + + print(f"serve {sha1}") + return data + + +class Server: + def __setitem__(self, domain, conf): + location, signer = conf + self.__dict__[domain] = Resource(location, signer) + watch_list[location] = domain + watch_manager.add_watch(location, pyinotify.IN_CLOSE_WRITE) + + def __getitem__(self, domain): + return self.__dict__[domain] + + +server = Server() event_notifier = pyinotify.ThreadedNotifier(watch_manager, EventProcessor())