Skip to content
Snippets Groups Projects
Commit 2f36f31d authored by Martin van Es's avatar Martin van Es
Browse files

Added request for all entities in a realm

parent 7362b875
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,7 @@ from datetime import datetime ...@@ -8,7 +8,7 @@ from datetime import datetime
from isoduration import parse_duration from isoduration import parse_duration
from email.utils import formatdate from email.utils import formatdate
from utils import read_config, hasher, Entity from utils import read_config, hasher, Entity, Server
import logging import logging
log = logging.getLogger('werkzeug') log = logging.getLogger('werkzeug')
...@@ -19,11 +19,73 @@ app = Flask(__name__) ...@@ -19,11 +19,73 @@ app = Flask(__name__)
# Find all IdP's in edugain metadata # Find all IdP's in edugain metadata
cached = {} cached = Server()
@app.route('/<domain>/entities/<path:eid>', methods=['GET']) @app.route('/<domain>/entities',
def serve(domain, eid): strict_slashes=False,
methods=['GET'])
def serve_all(domain):
response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml"
response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
if cached.all_entities is not None:
print("cache all")
cache = cached.all_entities
data = cache.md
last_modified = cache.last_modified
expires = min(datetime.now(tz.tzutc()) +
cache.cache_duration,
cache.valid_until)
max_age = int((expires -
datetime.now(tz.tzutc())).total_seconds())
else:
print("request all")
request = requests.get(f"{config[domain]['signer']}/{domain}"
f"/entities")
data = request.text
last_modified = request.headers.get('Last-Modified',
formatdate(timeval=None,
localtime=False,
usegmt=True))
try:
root = ET.fromstring(data)
validUntil = root.get('validUntil')
cacheDuration = root.get('cacheDuration')
cached_entity = Entity()
cached_entity.md = data
cached_entity.valid_until = parser.isoparse(validUntil)
cached_entity.cache_duration = parse_duration(cacheDuration)
cached_entity.expires = min(datetime.now(tz.tzutc())
+ cached_entity.cache_duration,
cached_entity.valid_until)
cached_entity.last_modified = last_modified
max_age = int((cached_entity.expires -
datetime.now(tz.tzutc())).total_seconds())
cached_entity.max_age = max_age
if cached_entity.expires > datetime.now(tz.tzutc()):
cached.all_entities = cached_entity
else:
raise KeyError
except Exception as e:
print(f"{e}")
data = "No valid metadata\n"
max_age = 60
response.headers['Content-type'] = "text/html"
response.headers['Cache-Control'] = "max-age=60"
response.status = 404
response.headers['Cache-Control'] = f"max-age={max_age}"
response.headers['Last-Modified'] = last_modified
response.data = data
return response
@app.route('/<domain>/entities/<path:eid>',
methods=['GET'])
def serve_one(domain, eid):
entityID = unquote(eid) entityID = unquote(eid)
if entityID[:6] == "{sha1}": if entityID[:6] == "{sha1}":
entityID = entityID[6:] entityID = entityID[6:]
......
#!/usr/bin/env python #!/usr/bin/env python
from utils import read_config, Resource from utils import read_config, Resource, Server
from flask import Flask, Response from flask import Flask, Response
from datetime import datetime from datetime import datetime
from dateutil import tz from dateutil import tz
...@@ -12,12 +12,40 @@ log.setLevel(logging.ERROR) ...@@ -12,12 +12,40 @@ log.setLevel(logging.ERROR)
config = read_config('mdserver.yaml') config = read_config('mdserver.yaml')
app = Flask(__name__) app = Flask(__name__)
server = {} server = Server()
@app.route('/<domain>/entities',
strict_slashes=False,
methods=['GET'])
def serve_all(domain):
response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml"
response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
if server.all_entities is not None:
print("cache all")
data = server.all_entities
response.data = data.md
else:
print("sign all")
data = server[domain].all_entities()
response.data = data.md
server.all_entities = data
max_age = int((data.expires - datetime.now(tz.tzutc())).total_seconds())
response.headers['Cache-Control'] = f"max-age={max_age}"
response.headers['Last-Modified'] = formatdate(timeval=mktime(data.last_modified.timetuple()),
localtime=False, usegmt=True)
return response
@app.route('/<domain>/entities/<path:entity_id>', @app.route('/<domain>/entities/<path:entity_id>',
strict_slashes=False,
methods=['GET']) methods=['GET'])
def serve(domain, entity_id): def serve_one(domain, entity_id):
print(f"entity_id: {entity_id}")
response = Response() response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml" response.headers['Content-Type'] = "application/samlmetadata+xml"
response.headers['Content-Disposition'] = "filename = \"metadata.xml\"" response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
......
import os import os
from lxml import etree as ET from lxml import etree as ET
from dateutil import parser, tz from dateutil import parser, tz
from isoduration import parse_duration from isoduration import parse_duration, format_duration
from datetime import datetime, timedelta from datetime import datetime, timedelta
import hashlib import hashlib
from urllib.parse import unquote from urllib.parse import unquote
...@@ -36,6 +36,7 @@ class MData(object): ...@@ -36,6 +36,7 @@ class MData(object):
self.md = None self.md = None
self.max_age = (datetime.now(tz.tzutc()) + self.max_age = (datetime.now(tz.tzutc()) +
timedelta(seconds=60)) timedelta(seconds=60))
self.last_modified = 0
class EventProcessor(pyinotify.ProcessEvent): class EventProcessor(pyinotify.ProcessEvent):
...@@ -50,6 +51,9 @@ class EventProcessor(pyinotify.ProcessEvent): ...@@ -50,6 +51,9 @@ class EventProcessor(pyinotify.ProcessEvent):
else: else:
self.resource.read_metadata(event.pathname) self.resource.read_metadata(event.pathname)
class Server(dict):
def __init__(self):
self.all_entities = None
class Resource: class Resource:
watch_list = {} watch_list = {}
...@@ -103,6 +107,8 @@ class Resource: ...@@ -103,6 +107,8 @@ class Resource:
return return
ns = root.nsmap.copy() ns = root.nsmap.copy()
ns['xml'] = 'http://www.w3.org/XML/1998/namespace' ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
for signature in root.findall('.//ds:Signature', ns):
signature.getparent().remove(signature)
validUntil = root.get('validUntil') validUntil = root.get('validUntil')
cacheDuration = root.get('cacheDuration') cacheDuration = root.get('cacheDuration')
valid_until = parser.isoparse(validUntil) valid_until = parser.isoparse(validUntil)
...@@ -153,7 +159,6 @@ class Resource: ...@@ -153,7 +159,6 @@ class Resource:
print(f"cache {sha1}") print(f"cache {sha1}")
data.md = self.__dict__[sha1].md data.md = self.__dict__[sha1].md
if data.md is None and sha1 in self.idps: if data.md is None and sha1 in self.idps:
try: try:
print(f"sign {sha1}") print(f"sign {sha1}")
...@@ -164,8 +169,8 @@ class Resource: ...@@ -164,8 +169,8 @@ class Resource:
pretty_print=True).decode() pretty_print=True).decode()
signed_entity = Entity() signed_entity = Entity()
signed_entity.md = signed_xml signed_entity.md = signed_xml
signed_entity.expires = (datetime.now(tz.tzutc()) signed_entity.expires = (datetime.now(tz.tzutc()) +
+ self.idps[sha1].cache_duration) self.idps[sha1].cache_duration)
signed_entity.last_modified = self.idps[sha1].last_modified signed_entity.last_modified = self.idps[sha1].last_modified
self.__dict__[sha1] = signed_entity self.__dict__[sha1] = signed_entity
data.md = signed_xml data.md = signed_xml
...@@ -179,3 +184,27 @@ class Resource: ...@@ -179,3 +184,27 @@ class Resource:
datetime.now(tz.tzutc())).total_seconds()) datetime.now(tz.tzutc())).total_seconds())
data.last_modified = signed_entity.last_modified data.last_modified = signed_entity.last_modified
return data return data
def all_entities(self):
data = MData()
ns = {'md': 'urn:oasis:names:tc:SAML:2.0:metadata'}
root = ET.Element('{urn:oasis:names:tc:SAML:2.0:metadata}EntitiesDescriptor', nsmap=ns)
# We are going to minimize expires, so set to some inf value
expires = (datetime.now(tz.tzutc()) +
timedelta(days=365))
for sha1, entity in self.idps.items():
expires = min(expires, entity.expires)
root.append(entity.md)
last_modified = datetime.now(tz.tzutc())
ezulu = str(expires).replace('+00:00', 'Z')
root.set('validUntil', ezulu)
root.set('cacheDuration', "PT6H")
signed_root = self.signer(root)
data.md = ET.tostring(signed_root, pretty_print=True)
data.expires = expires
data.last_modified = last_modified
return data
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment