diff --git a/mdserver.py b/mdserver.py index 5942ecb8bc37da8607ad91c57575520c0dbb93bb..5636e5470febd606525b7ce52ff331835bf2598d 100755 --- a/mdserver.py +++ b/mdserver.py @@ -19,27 +19,15 @@ server = Server() strict_slashes=False, methods=['GET']) def serve_all(realm): - dirty = False - for key, resource in server.items(): - if resource.dirty: - dirty = True - resource.dirty = False - + print(f"all in {realm}") response = Response() response.headers['Content-Type'] = "application/samlmetadata+xml" response.headers['Content-Disposition'] = "filename = \"metadata.xml\"" - if server.all_entities is not None and not dirty: - print("cache all") - data = server.all_entities - response.data = data.md - else: - print("sign all") - data = server[realm].all_entities() - response.data = data.md - server.all_entities = data - - max_age = int((data.valid_until - datetime.now(tz.tzutc())).total_seconds()) + data = server[realm].all_entities() + response.data = data.md + max_age = int((data.valid_until - + datetime.now(tz.tzutc())).total_seconds()) response.headers['Cache-Control'] = f"max-age={max_age}" response.headers['Last-Modified'] = formatdate(timeval=mktime(data.last_modified.timetuple()), diff --git a/utils.py b/utils.py index 9ec536ff52585d96fc08eddf2b8d07098bb91d6b..5d9c7a1fcc8fab51a868e1e14fc639652e734138 100755 --- a/utils.py +++ b/utils.py @@ -62,6 +62,7 @@ class Server(dict): class Resource: watch_list = {} dirty = False + all_cache = None def __init__(self, location, signer): self.idps = {} @@ -120,7 +121,7 @@ class Resource: cache_duration = parse_duration(cacheDuration) last_modified = datetime.now(tz.tzutc()) if valid_until > datetime.now(tz.tzutc()): - self.dirty = True + self.all_cache = None for entity_descriptor in root.findall('md:EntityDescriptor', ns): entityID = entity_descriptor.attrib.get('entityID', 'none') sha1 = hasher(entityID) @@ -191,28 +192,35 @@ class Resource: return data def all_entities(self): - data = MData() - ns = NSMAP - root = ET.Element(f"{{{MD_NAMESPACE}}}EntitiesDescriptor", - nsmap=ns) - # We are going to minimize expires, so set to some inf value - valid_until = (datetime.now(tz.tzutc()) + - timedelta(days=365)) - cache_duration = parse_duration("P1D") - for sha1, entity in self.idps.items(): - valid_until = min(valid_until, entity.valid_until) - cache_duration = min(cache_duration, entity.cache_duration) - ET.strip_attributes(entity.md, 'validUntil', 'cacheDuration') - root.append(entity.md) - - vu_zulu = str(valid_until).replace('+00:00', 'Z') - root.set('validUntil', vu_zulu) - root.set('cacheDuration', duration_isoformat(cache_duration)) - last_modified = datetime.now(tz.tzutc()) - - signed_root = self.signer(root) - data.md = ET.tostring(signed_root, pretty_print=True) - data.valid_until = valid_until - data.last_modified = last_modified + if self.all_cache is not None: + print("cache all") + data = self.all_cache + else: + print("sign all") + data = MData() + ns = NSMAP + root = ET.Element(f"{{{MD_NAMESPACE}}}EntitiesDescriptor", + nsmap=ns) + # We are going to minimize expires, so set to some inf value + valid_until = (datetime.now(tz.tzutc()) + + timedelta(days=365)) + cache_duration = parse_duration("P1D") + for sha1, entity in self.idps.items(): + valid_until = min(valid_until, entity.valid_until) + cache_duration = min(cache_duration, entity.cache_duration) + ET.strip_attributes(entity.md, 'validUntil', 'cacheDuration') + root.append(entity.md) + + vu_zulu = str(valid_until).replace('+00:00', 'Z') + root.set('validUntil', vu_zulu) + root.set('cacheDuration', duration_isoformat(cache_duration)) + last_modified = datetime.now(tz.tzutc()) + + signed_root = self.signer(root) + data.md = ET.tostring(signed_root, pretty_print=True) + data.valid_until = valid_until + data.last_modified = last_modified + + self.all_cache = data return data