Skip to content
Snippets Groups Projects
Commit 4a93eb39 authored by Martin van Es's avatar Martin van Es
Browse files

Add per-domain signer, inotiify listener

parent c703d147
No related branches found
No related tags found
No related merge requests found
......@@ -6,3 +6,4 @@ __pycache__
meta.crt
meta.key
*.xml
mdserver.yaml
......@@ -16,8 +16,9 @@ Reads source metadata file(s) and outputs them signed to filesystem
## ```mdserver.py [mdfile] [mdfile] [mdfile] ...```
Starts a metadata signer server.
Reads source metadata files(s) on SIGHUP signal.
Serves and caches them signed from memory, on request
Reads source metadata files(s) from mdsigner.yaml configuration, see example.
Reloads metadata on inotify CLOSE_WRITE of metadata file.
Serves and caches signed by domain signer from memory, on request
## ```mdproxy.py```
Caches signed and cached ```mdserver.py``` metadata requests
......
#!/usr/bin/env python
import sys
import copy
import signal
from lxml import etree as ET
from flask import Flask, Response
from urllib.parse import unquote
from dateutil import parser, tz
from dateutil import tz
from datetime import datetime
from isoduration import parse_duration
# import pyinotify
import traceback
from utils import hasher, signer, Entity
from utils import read_config, read_domain, hasher, idps, \
signed, signer, Signers, Entity
app = Flask(__name__)
# Find all IdP's in edugain metadata
idps = {}
signed = {}
cert = open("meta.crt").read()
key = open("meta.key").read()
@app.route('/sign/<path:eid>', methods=['GET'])
def sign(eid):
global idps, signed, cert, key
@app.route('/<domain>/entities/<path:eid>', methods=['GET'])
def sign(domain, eid):
entityID = unquote(eid)
if entityID[:6] == "{sha1}":
sha1 = entityID[6:]
......@@ -37,21 +24,21 @@ def sign(eid):
response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml"
if sha1 in signed:
signed_entity = signed[sha1]
if sha1 in signed[domain]:
signed_entity = signed[domain][sha1]
if signed_entity.valid_until > datetime.now(tz.tzutc()):
response.data = signed[sha1].md
elif sha1 in idps:
response.data = signed[domain][sha1].md
elif sha1 in idps[domain]:
try:
print(f"sign {sha1}")
valid_until = idps[sha1].valid_until
print(f"sign {domain} {sha1}")
valid_until = idps[domain][sha1].valid_until
if valid_until > datetime.now(tz.tzutc()):
signed_element = signer(idps[sha1].md, cert, key)
signed_element = Signers()[signer[domain]](idps[domain][sha1].md)
signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
signed_entity = Entity()
signed_entity.md = signed_xml
signed_entity.valid_until = idps[sha1].valid_until
signed[sha1] = signed_entity
signed_entity.valid_until = idps[domain][sha1].valid_until
signed[domain][sha1] = signed_entity
response.data = signed_xml
except Exception as e:
print(sha1)
......@@ -63,51 +50,15 @@ def sign(eid):
response.status = 404
return response
print(f"serve {sha1}")
print(f"serve {domain} {sha1}")
return response
def read_metadata(signum, frm):
print("\n--- SIGHUP ---")
global idps, signed
found = 0
removed = 0
old_idps = copy.deepcopy(idps)
for mdfile in sys.argv[1:]:
tree = ET.ElementTree(file=mdfile)
root = tree.getroot()
ns = copy.deepcopy(root.nsmap)
ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
validUntil = root.get('validUntil')
cacheDuration = root.get('cacheDuration')
valid_until = parser.isoparse(validUntil)
cache_duration = parse_duration(cacheDuration)
if valid_until > datetime.now(tz.tzutc()):
for entity_descriptor in root.findall('md:EntityDescriptor', ns):
entityID = entity_descriptor.attrib.get('entityID', 'none')
sha1 = hasher(entityID)
print(f"{{sha1}}{sha1} {entityID}")
entity_descriptor.set('validUntil', validUntil)
entity_descriptor.set('cacheDuration', cacheDuration)
entity = Entity()
entity.md = entity_descriptor
entity.valid_until = valid_until
entity.cache_duration = cache_duration
idps[sha1] = entity
old_idps.pop(sha1, None)
# signed.pop(sha1, None)
found += 1
for idp in old_idps:
idps.pop(idp, None)
signed.pop(idp, None)
removed += 1
print(f"Found: {found} entities")
print(f"Removed: {removed} entities")
print(f"validUntil: {validUntil}")
signal.signal(signal.SIGHUP, read_metadata)
config = read_config()
read_metadata(None, None)
for domain, values in config.items():
print(f"domain: {domain}")
read_domain(domain, values)
signer[domain] = values['signer']
app.run(host='127.0.0.1', port=5001)
---
test:
signer: test_signer
metadir: metadata/test
foobar:
signer: foobar_signer
metadir: metadata/foobar
......@@ -4,3 +4,5 @@ flask
requests
python-dateutil
isoduration
pyyaml
pyinotify
import os
import copy
import yaml
from lxml import etree as ET
import hashlib
from signxml import XMLSigner
from isoduration import parse_duration
from dateutil import parser, tz
from datetime import datetime
import pyinotify
cert = open("meta.crt").read()
key = open("meta.key").read()
watch_manager = pyinotify.WatchManager()
watch_list = {}
idps = {}
signed = {}
signer = {}
class Entity(object):
......@@ -8,6 +27,91 @@ class Entity(object):
cache_duration = 0
class EventProcessor(pyinotify.ProcessEvent):
def process_IN_CLOSE_WRITE(self, event):
domain = watch_list[event.path]
read_metadata(domain, event.path)
class Signers(dict):
def __init__(self):
self['normal_signer'] = self._normal_signer
self['test_signer'] = self._test_signer
self['foobar_signer'] = self._foobar_signer
def _normal_signer(self, xml):
print("Normal signer")
return XMLSigner().sign(xml, key=key, cert=cert)
def _test_signer(self, xml):
print("Test signer")
return XMLSigner().sign(xml, key=key, cert=cert)
def _foobar_signer(self, xml):
print("Foobar signer")
return XMLSigner().sign(xml, key=key, cert=cert)
def read_metadata(domain, mdfile):
print("--- READ METADATA --")
global idps, signed
found = 0
removed = 0
old_idps = copy.deepcopy(idps.get(domain, {}))
idps[domain] = idps.get(domain, {})
signed[domain] = signed.get(domain, {})
tree = ET.ElementTree(file=mdfile)
root = tree.getroot()
ns = copy.deepcopy(root.nsmap)
ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
validUntil = root.get('validUntil')
cacheDuration = root.get('cacheDuration')
valid_until = parser.isoparse(validUntil)
cache_duration = parse_duration(cacheDuration)
if valid_until > datetime.now(tz.tzutc()):
for entity_descriptor in root.findall('md:EntityDescriptor', ns):
entityID = entity_descriptor.attrib.get('entityID', 'none')
sha1 = hasher(entityID)
print(f"{{sha1}}{sha1} {entityID}")
entity_descriptor.set('validUntil', validUntil)
entity_descriptor.set('cacheDuration', cacheDuration)
entity = Entity()
entity.md = entity_descriptor
entity.valid_until = valid_until
entity.cache_duration = cache_duration
idps[domain][sha1] = entity
signed[domain].pop(sha1, None)
old_idps.pop(sha1, None)
found += 1
for idp in old_idps:
idps[domain].pop(idp, None)
signed[domain].pop(idp, None)
removed += 1
print(f"Found: {found} entities")
print(f"Removed: {removed} entities")
print(f"validUntil: {validUntil}")
def read_domain(domain, values):
metadir = values['metadir']
print(f" metadir: {metadir}")
files = os.listdir(metadir)
for file in files:
mdfile = os.path.realpath(os.path.join(metadir, file))
if os.path.isfile(mdfile):
print(f" file: {mdfile}")
read_metadata(domain, mdfile)
watch_list[mdfile] = domain
watch_manager.add_watch(mdfile, pyinotify.IN_CLOSE_WRITE)
def read_config():
with open('mdserver.yaml') as f:
config = yaml.safe_load(f)
return config
def hasher(entity_id):
sha1 = hashlib.sha1()
sha1.update(entity_id.encode())
......@@ -15,5 +119,6 @@ def hasher(entity_id):
return sha1_digest
def signer(xml, cert, key):
return XMLSigner().sign(xml, key=key, cert=cert)
event_notifier = pyinotify.ThreadedNotifier(watch_manager, EventProcessor())
event_notifier.start()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment