Skip to content
Snippets Groups Projects
Commit 10df9b45 authored by Martin van Es's avatar Martin van Es
Browse files

Move all entities logic to Resource class

parent 32ee2257
No related branches found
No related tags found
No related merge requests found
...@@ -19,27 +19,15 @@ server = Server() ...@@ -19,27 +19,15 @@ server = Server()
strict_slashes=False, strict_slashes=False,
methods=['GET']) methods=['GET'])
def serve_all(realm): def serve_all(realm):
dirty = False print(f"all in {realm}")
for key, resource in server.items():
if resource.dirty:
dirty = True
resource.dirty = False
response = Response() response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml" response.headers['Content-Type'] = "application/samlmetadata+xml"
response.headers['Content-Disposition'] = "filename = \"metadata.xml\"" response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
if server.all_entities is not None and not dirty: data = server[realm].all_entities()
print("cache all") response.data = data.md
data = server.all_entities max_age = int((data.valid_until -
response.data = data.md datetime.now(tz.tzutc())).total_seconds())
else:
print("sign all")
data = server[realm].all_entities()
response.data = data.md
server.all_entities = data
max_age = int((data.valid_until - datetime.now(tz.tzutc())).total_seconds())
response.headers['Cache-Control'] = f"max-age={max_age}" response.headers['Cache-Control'] = f"max-age={max_age}"
response.headers['Last-Modified'] = formatdate(timeval=mktime(data.last_modified.timetuple()), response.headers['Last-Modified'] = formatdate(timeval=mktime(data.last_modified.timetuple()),
......
...@@ -62,6 +62,7 @@ class Server(dict): ...@@ -62,6 +62,7 @@ class Server(dict):
class Resource: class Resource:
watch_list = {} watch_list = {}
dirty = False dirty = False
all_cache = None
def __init__(self, location, signer): def __init__(self, location, signer):
self.idps = {} self.idps = {}
...@@ -120,7 +121,7 @@ class Resource: ...@@ -120,7 +121,7 @@ class Resource:
cache_duration = parse_duration(cacheDuration) cache_duration = parse_duration(cacheDuration)
last_modified = datetime.now(tz.tzutc()) last_modified = datetime.now(tz.tzutc())
if valid_until > datetime.now(tz.tzutc()): if valid_until > datetime.now(tz.tzutc()):
self.dirty = True self.all_cache = None
for entity_descriptor in root.findall('md:EntityDescriptor', ns): for entity_descriptor in root.findall('md:EntityDescriptor', ns):
entityID = entity_descriptor.attrib.get('entityID', 'none') entityID = entity_descriptor.attrib.get('entityID', 'none')
sha1 = hasher(entityID) sha1 = hasher(entityID)
...@@ -191,28 +192,35 @@ class Resource: ...@@ -191,28 +192,35 @@ class Resource:
return data return data
def all_entities(self): def all_entities(self):
data = MData() if self.all_cache is not None:
ns = NSMAP print("cache all")
root = ET.Element(f"{{{MD_NAMESPACE}}}EntitiesDescriptor", data = self.all_cache
nsmap=ns) else:
# We are going to minimize expires, so set to some inf value print("sign all")
valid_until = (datetime.now(tz.tzutc()) + data = MData()
timedelta(days=365)) ns = NSMAP
cache_duration = parse_duration("P1D") root = ET.Element(f"{{{MD_NAMESPACE}}}EntitiesDescriptor",
for sha1, entity in self.idps.items(): nsmap=ns)
valid_until = min(valid_until, entity.valid_until) # We are going to minimize expires, so set to some inf value
cache_duration = min(cache_duration, entity.cache_duration) valid_until = (datetime.now(tz.tzutc()) +
ET.strip_attributes(entity.md, 'validUntil', 'cacheDuration') timedelta(days=365))
root.append(entity.md) cache_duration = parse_duration("P1D")
for sha1, entity in self.idps.items():
vu_zulu = str(valid_until).replace('+00:00', 'Z') valid_until = min(valid_until, entity.valid_until)
root.set('validUntil', vu_zulu) cache_duration = min(cache_duration, entity.cache_duration)
root.set('cacheDuration', duration_isoformat(cache_duration)) ET.strip_attributes(entity.md, 'validUntil', 'cacheDuration')
last_modified = datetime.now(tz.tzutc()) root.append(entity.md)
signed_root = self.signer(root) vu_zulu = str(valid_until).replace('+00:00', 'Z')
data.md = ET.tostring(signed_root, pretty_print=True) root.set('validUntil', vu_zulu)
data.valid_until = valid_until root.set('cacheDuration', duration_isoformat(cache_duration))
data.last_modified = last_modified last_modified = datetime.now(tz.tzutc())
signed_root = self.signer(root)
data.md = ET.tostring(signed_root, pretty_print=True)
data.valid_until = valid_until
data.last_modified = last_modified
self.all_cache = data
return data return data
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment