Skip to content
Snippets Groups Projects
Commit 420ce3ab authored by Martin van Es's avatar Martin van Es
Browse files

Rewrite mdserver classes

parent 16f454fc
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python #!/usr/bin/env python
import requests import requests
import hashlib
from lxml import etree as ET from lxml import etree as ET
from flask import Flask from flask import Flask, Response
from urllib.parse import unquote from urllib.parse import unquote
from dateutil import parser, tz from dateutil import parser, tz
from datetime import datetime from datetime import datetime
...@@ -26,6 +24,10 @@ def serve(domain, eid): ...@@ -26,6 +24,10 @@ def serve(domain, eid):
else: else:
entityID = hasher(entityID) entityID = hasher(entityID)
response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml"
response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
cached[domain] = cached.get(domain, {}) cached[domain] = cached.get(domain, {})
if entityID in cached[domain]: if entityID in cached[domain]:
if cached[domain][entityID].valid_until > datetime.now(tz.tzutc()): if cached[domain][entityID].valid_until > datetime.now(tz.tzutc()):
...@@ -33,15 +35,22 @@ def serve(domain, eid): ...@@ -33,15 +35,22 @@ def serve(domain, eid):
return cached[domain][entityID].md return cached[domain][entityID].md
else: else:
print(f"request {entityID}") print(f"request {entityID}")
result = requests.get(f"{signer_url}/{domain}/entities/{{sha1}}{entityID}").text data = requests.get(f"{signer_url}/{domain}/entities/{{sha1}}{entityID}").text
parsed = ET.fromstring(result) try:
validUntil = parsed.get('validUntil') parsed = ET.fromstring(data)
# cacheDuration = parsed.get('cacheDuration') validUntil = parsed.get('validUntil')
cached_entity = Entity() # cacheDuration = parsed.get('cacheDuration')
cached_entity.md = result cached_entity = Entity()
cached_entity.valid_until = parser.isoparse(validUntil) cached_entity.md = data
cached[domain][entityID] = cached_entity cached_entity.valid_until = parser.isoparse(validUntil)
return result cached[domain][entityID] = cached_entity
except ET.XMLSyntaxError:
data = "No valid metadata\n"
response.headers['Content-type'] = "text/html"
response.status = 404
response.data = data
return response
app.run(host='0.0.0.0', port=5002) app.run(host='0.0.0.0', port=5002)
#!/usr/bin/env python #!/usr/bin/env python
from lxml import etree as ET from utils import read_config, server, event_notifier
from flask import Flask, Response from flask import Flask, Response
from urllib.parse import unquote
from dateutil import tz
from datetime import datetime
# import pyinotify
import traceback config = read_config()
from utils import read_config, read_domain, hasher, idps, \
signed, signer, Signers, Entity, event_notifier
signers = Signers()
app = Flask(__name__) app = Flask(__name__)
@app.route('/<domain>/entities/<path:eid>', methods=['GET']) @app.route('/<domain>/entities/<path:entity_id>', methods=['GET'])
def serve(domain, eid): def serve(domain, entity_id):
entityID = unquote(eid)
if entityID[:6] == "{sha1}":
sha1 = entityID[6:]
else:
sha1 = hasher(entityID)
response = Response() response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml" response.headers['Content-Type'] = "application/samlmetadata+xml"
response.headers['Content-Disposition'] = "filename = \"metadata.xml\"" response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
if sha1 in signed[domain]: try:
signed_entity = signed[domain][sha1] response.data = server[domain][entity_id]
if signed_entity.valid_until > datetime.now(tz.tzutc()): except Exception:
response.data = signed[domain][sha1].md
elif sha1 in idps[domain]:
try:
print(f"sign {domain} {sha1}")
valid_until = idps[domain][sha1].valid_until
if valid_until > datetime.now(tz.tzutc()):
signed_element = signers[signer[domain]](idps[domain][sha1].md)
signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
signed_entity = Entity()
signed_entity.md = signed_xml
signed_entity.valid_until = idps[domain][sha1].valid_until
signed[domain][sha1] = signed_entity
response.data = signed_xml
except Exception as e:
print(sha1)
print(f" {e}")
traceback.print_exc()
else:
response.data = "No valid metadata\n" response.data = "No valid metadata\n"
response.headers['Content-type'] = "text/html" response.headers['Content-type'] = "text/html"
response.status = 404 response.status = 404
return response
print(f"serve {domain} {sha1}")
return response return response
config = read_config()
for domain, values in config.items(): for domain, values in config.items():
print(f"domain: {domain}") print(f"domain: {domain}")
read_domain(domain, values) conf = (values['metadir'], values['signer'])
signer[domain] = values['signer'] server[domain] = conf
app.run(host='127.0.0.1', port=5001)
event_notifier.start() event_notifier.start()
app.run(host='127.0.0.1', port=5001)
...@@ -14,7 +14,7 @@ idps = [] ...@@ -14,7 +14,7 @@ idps = []
success = 0 success = 0
failed = 0 failed = 0
maxthreads = 8 maxthreads = 8
signer = Signers()['normal_signer'] signer = Signers('normal_signer')
def sign(xml, name): def sign(xml, name):
......
from signxml import XMLSigner
cert = open("meta.crt").read()
key = open("meta.key").read()
def Signers(signer):
def _normal_signer(xml):
print("Normal signer")
return XMLSigner().sign(xml, key=key, cert=cert)
def _test_signer(xml):
print("Test signer")
return XMLSigner().sign(xml, key=key, cert=cert)
def _foobar_signer(xml):
print("Foobar signer")
return XMLSigner().sign(xml, key=key, cert=cert)
signers = {
'normal_signer': _normal_signer,
'test_signer': _test_signer,
'foobar_signer': _foobar_signer
}
return signers[signer]
utils.py 100644 → 100755
import os import os
import copy
import yaml
from lxml import etree as ET from lxml import etree as ET
import hashlib
from signxml import XMLSigner
from isoduration import parse_duration
from dateutil import parser, tz from dateutil import parser, tz
from isoduration import parse_duration
from datetime import datetime from datetime import datetime
import hashlib
from urllib.parse import unquote
import yaml
import pyinotify import pyinotify
from signers import Signers
cert = open("meta.crt").read() watch_list = {}
key = open("meta.key").read()
watch_manager = pyinotify.WatchManager() watch_manager = pyinotify.WatchManager()
watch_list = {}
idps = {} def read_config():
signed = {} with open('mdserver.yaml') as f:
signer = {} config = yaml.safe_load(f)
return config
def hasher(entity_id): def hasher(entity_id):
...@@ -28,95 +26,119 @@ def hasher(entity_id): ...@@ -28,95 +26,119 @@ def hasher(entity_id):
return sha1_digest return sha1_digest
class Entity(object):
md = None
valid_until = 0
cache_duration = 0
class EventProcessor(pyinotify.ProcessEvent): class EventProcessor(pyinotify.ProcessEvent):
def process_IN_CLOSE_WRITE(self, event): def process_IN_CLOSE_WRITE(self, event):
domain = watch_list[event.path] domain = watch_list[event.path]
read_metadata(domain, event.path) print(f"Notify {domain} {event.path}")
server[domain].walk_location(event.path)
class Signers(dict): class Entity:
def __init__(self): def __init__(self):
self['normal_signer'] = self._normal_signer self.md = None
self['test_signer'] = self._test_signer self.valid_until = 0
self['foobar_signer'] = self._foobar_signer self.cache_duration = 0
def _normal_signer(self, xml):
print("Normal signer") class Resource:
return XMLSigner().sign(xml, key=key, cert=cert) def __init__(self, location, signer):
self.idps = {}
def _test_signer(self, xml): self.mdfiles = {}
print("Test signer") self.signer = Signers(signer)
return XMLSigner().sign(xml, key=key, cert=cert) self.walk_location(location)
def _foobar_signer(self, xml): def walk_location(self, location):
print("Foobar signer") files = os.listdir(location)
return XMLSigner().sign(xml, key=key, cert=cert) for file in files:
mdfile = os.path.realpath(os.path.join(location, file))
if os.path.isfile(mdfile):
def read_metadata(domain, mdfile): self.mdfiles[mdfile] = []
print("--- READ METADATA --") self._read_metadata(mdfile)
global idps, signed
found = 0 def _read_metadata(self, mdfile):
removed = 0 print("--- READ METADATA --")
old_idps = copy.deepcopy(idps.get(domain, {})) found = 0
idps[domain] = idps.get(domain, {}) removed = 0
signed[domain] = signed.get(domain, {}) old_idps = self.mdfiles[mdfile].copy()
tree = ET.ElementTree(file=mdfile) print(f"old_idps: {old_idps}")
root = tree.getroot() tree = ET.ElementTree(file=mdfile)
ns = copy.deepcopy(root.nsmap) root = tree.getroot()
ns['xml'] = 'http://www.w3.org/XML/1998/namespace' ns = root.nsmap.copy()
validUntil = root.get('validUntil') ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
cacheDuration = root.get('cacheDuration') validUntil = root.get('validUntil')
valid_until = parser.isoparse(validUntil) cacheDuration = root.get('cacheDuration')
cache_duration = parse_duration(cacheDuration) valid_until = parser.isoparse(validUntil)
if valid_until > datetime.now(tz.tzutc()): cache_duration = parse_duration(cacheDuration)
for entity_descriptor in root.findall('md:EntityDescriptor', ns): if valid_until > datetime.now(tz.tzutc()):
entityID = entity_descriptor.attrib.get('entityID', 'none') for entity_descriptor in root.findall('md:EntityDescriptor', ns):
entityID = entity_descriptor.attrib.get('entityID', 'none')
sha1 = hasher(entityID)
print(f"{{sha1}}{sha1} {entityID}")
entity_descriptor.set('validUntil', validUntil)
entity_descriptor.set('cacheDuration', cacheDuration)
entity = Entity()
entity.md = entity_descriptor
entity.valid_until = valid_until
entity.cache_duration = cache_duration
self.idps[sha1] = entity
self.__dict__.pop(sha1, None)
if sha1 in self.mdfiles[mdfile]:
self.mdfiles[mdfile].remove(sha1)
found += 1
for idp in old_idps:
self.idps.pop(idp, None)
self.__dict__.pop(idp, None)
removed += 1
self.mdfiles[mdfile] = self.idps.keys()
print(f"Found: {found} entities")
print(f"Removed: {removed} entities")
print(f"validUntil: {validUntil}")
def __getitem__(self, key):
entityID = unquote(key)
if entityID[:6] == "{sha1}":
sha1 = entityID[6:]
else:
sha1 = hasher(entityID) sha1 = hasher(entityID)
print(f"{{sha1}}{sha1} {entityID}")
entity_descriptor.set('validUntil', validUntil)
entity_descriptor.set('cacheDuration', cacheDuration)
entity = Entity()
entity.md = entity_descriptor
entity.valid_until = valid_until
entity.cache_duration = cache_duration
idps[domain][sha1] = entity
signed[domain].pop(sha1, None)
old_idps.pop(sha1, None)
found += 1
for idp in old_idps:
idps[domain].pop(idp, None)
signed[domain].pop(idp, None)
removed += 1
print(f"Found: {found} entities")
print(f"Removed: {removed} entities")
print(f"validUntil: {validUntil}")
def read_domain(domain, values):
metadir = values['metadir']
print(f" metadir: {metadir}")
files = os.listdir(metadir)
for file in files:
mdfile = os.path.realpath(os.path.join(metadir, file))
if os.path.isfile(mdfile):
print(f" file: {mdfile}")
read_metadata(domain, mdfile)
watch_list[mdfile] = domain
watch_manager.add_watch(mdfile, pyinotify.IN_CLOSE_WRITE)
def read_config():
with open('mdserver.yaml') as f:
config = yaml.safe_load(f)
return config
if sha1 in self.__dict__:
signed_entity = self.__dict__[sha1]
if signed_entity.valid_until > datetime.now(tz.tzutc()):
data = self.__dict__[sha1].md
elif sha1 in self.idps:
try:
print(f"sign {sha1}")
valid_until = self.idps[sha1].valid_until
if valid_until > datetime.now(tz.tzutc()):
signed_element = self.signer(self.idps[sha1].md)
signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
signed_entity = Entity()
signed_entity.md = signed_xml
signed_entity.valid_until = self.idps[sha1].valid_until
self.__dict__[sha1] = signed_entity
data = signed_xml
except Exception as e:
print(sha1)
print(f" {e}")
else:
raise KeyError
print(f"serve {sha1}")
return data
class Server:
def __setitem__(self, domain, conf):
location, signer = conf
self.__dict__[domain] = Resource(location, signer)
watch_list[location] = domain
watch_manager.add_watch(location, pyinotify.IN_CLOSE_WRITE)
def __getitem__(self, domain):
return self.__dict__[domain]
server = Server()
event_notifier = pyinotify.ThreadedNotifier(watch_manager, EventProcessor()) event_notifier = pyinotify.ThreadedNotifier(watch_manager, EventProcessor())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment