Skip to content
Snippets Groups Projects
Commit 10df9b45 authored by Martin van Es's avatar Martin van Es
Browse files

Move all entities logic to Resource class

parent 32ee2257
No related branches found
No related tags found
No related merge requests found
......@@ -19,27 +19,15 @@ server = Server()
strict_slashes=False,
methods=['GET'])
def serve_all(realm):
dirty = False
for key, resource in server.items():
if resource.dirty:
dirty = True
resource.dirty = False
print(f"all in {realm}")
response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml"
response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
if server.all_entities is not None and not dirty:
print("cache all")
data = server.all_entities
response.data = data.md
else:
print("sign all")
data = server[realm].all_entities()
response.data = data.md
server.all_entities = data
max_age = int((data.valid_until - datetime.now(tz.tzutc())).total_seconds())
data = server[realm].all_entities()
response.data = data.md
max_age = int((data.valid_until -
datetime.now(tz.tzutc())).total_seconds())
response.headers['Cache-Control'] = f"max-age={max_age}"
response.headers['Last-Modified'] = formatdate(timeval=mktime(data.last_modified.timetuple()),
......
......@@ -62,6 +62,7 @@ class Server(dict):
class Resource:
watch_list = {}
dirty = False
all_cache = None
def __init__(self, location, signer):
self.idps = {}
......@@ -120,7 +121,7 @@ class Resource:
cache_duration = parse_duration(cacheDuration)
last_modified = datetime.now(tz.tzutc())
if valid_until > datetime.now(tz.tzutc()):
self.dirty = True
self.all_cache = None
for entity_descriptor in root.findall('md:EntityDescriptor', ns):
entityID = entity_descriptor.attrib.get('entityID', 'none')
sha1 = hasher(entityID)
......@@ -191,28 +192,35 @@ class Resource:
return data
def all_entities(self):
data = MData()
ns = NSMAP
root = ET.Element(f"{{{MD_NAMESPACE}}}EntitiesDescriptor",
nsmap=ns)
# We are going to minimize expires, so set to some inf value
valid_until = (datetime.now(tz.tzutc()) +
timedelta(days=365))
cache_duration = parse_duration("P1D")
for sha1, entity in self.idps.items():
valid_until = min(valid_until, entity.valid_until)
cache_duration = min(cache_duration, entity.cache_duration)
ET.strip_attributes(entity.md, 'validUntil', 'cacheDuration')
root.append(entity.md)
vu_zulu = str(valid_until).replace('+00:00', 'Z')
root.set('validUntil', vu_zulu)
root.set('cacheDuration', duration_isoformat(cache_duration))
last_modified = datetime.now(tz.tzutc())
signed_root = self.signer(root)
data.md = ET.tostring(signed_root, pretty_print=True)
data.valid_until = valid_until
data.last_modified = last_modified
if self.all_cache is not None:
print("cache all")
data = self.all_cache
else:
print("sign all")
data = MData()
ns = NSMAP
root = ET.Element(f"{{{MD_NAMESPACE}}}EntitiesDescriptor",
nsmap=ns)
# We are going to minimize expires, so set to some inf value
valid_until = (datetime.now(tz.tzutc()) +
timedelta(days=365))
cache_duration = parse_duration("P1D")
for sha1, entity in self.idps.items():
valid_until = min(valid_until, entity.valid_until)
cache_duration = min(cache_duration, entity.cache_duration)
ET.strip_attributes(entity.md, 'validUntil', 'cacheDuration')
root.append(entity.md)
vu_zulu = str(valid_until).replace('+00:00', 'Z')
root.set('validUntil', vu_zulu)
root.set('cacheDuration', duration_isoformat(cache_duration))
last_modified = datetime.now(tz.tzutc())
signed_root = self.signer(root)
data.md = ET.tostring(signed_root, pretty_print=True)
data.valid_until = valid_until
data.last_modified = last_modified
self.all_cache = data
return data
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment