"""Serve per-entity SAML metadata loaded from directories of metadata files.

A ``Resource`` parses every regular file in one directory, indexes the
``md:EntityDescriptor`` elements it finds by the SHA-1 of their entityID, and
signs them on demand. A ``Server`` maps domains to resources and re-reads a
resource's directory when inotify reports a file in it was closed after
writing.
"""

import hashlib
import os
import traceback
from datetime import datetime
from urllib.parse import unquote

import pyinotify
import yaml
from dateutil import parser, tz
from isoduration import parse_duration
from lxml import etree as ET

from signers import Signers

# Namespace URI of SAML 2.0 metadata elements such as EntityDescriptor.
SAML_MD_NS = 'urn:oasis:names:tc:SAML:2.0:metadata'


def read_config(config):
    """Load and return the YAML document stored at path *config*.

    The file is opened in binary mode so PyYAML detects the stream encoding
    itself instead of relying on the locale's default text encoding.
    """
    with open(config, 'rb') as stream:
        return yaml.safe_load(stream)


def hasher(entity_id):
    """Return the lowercase hex SHA-1 digest of the UTF-8 encoded *entity_id*."""
    return hashlib.sha1(entity_id.encode()).hexdigest()


class Entity:
    """Container for one entity's metadata and its validity information."""

    def __init__(self):
        # An lxml element for unsigned entries, or the serialized signed XML
        # string for entries in Resource.signed.
        self.md = None
        # Timezone-aware datetime once populated; 0 until then.
        self.valid_until = 0
        # Parsed cacheDuration once populated; 0 until then.
        self.cache_duration = 0


class Resource:
    """Entities read from the metadata files of one directory.

    Attributes:
        idps: sha1 of entityID -> Entity holding the unsigned EntityDescriptor.
        mdfiles: metadata file path -> set of sha1 keys that file contributed.
        signed: sha1 -> Entity holding the signed, serialized EntityDescriptor.
        signer: Signers instance used to sign EntityDescriptor elements.
    """

    def __init__(self, location, signer):
        self.idps = {}
        self.mdfiles = {}
        # Signed copies live in their own dict rather than the instance
        # __dict__, so a "{sha1}<attribute name>" lookup from a request can
        # never resolve to one of this object's own attributes.
        self.signed = {}
        self.signer = Signers(signer)
        self.walk_location(location)

    def walk_location(self, location):
        """(Re)read every regular file directly inside directory *location*."""
        for name in os.listdir(location):
            mdfile = os.path.realpath(os.path.join(location, name))
            if os.path.isfile(mdfile):
                self._read_metadata(mdfile)

    def _read_metadata(self, mdfile):
        """Parse *mdfile* and refresh the entities it contributes.

        If the file's validUntil is still in the future, every
        EntityDescriptor in it replaces the previous entry for its sha1 and
        invalidates any cached signed copy. Entities this file contributed on
        its previous read but no longer contains are dropped. When the file
        has expired, all of its entities are dropped.
        """
        print("\n--- READ METADATA --")
        print(mdfile)
        # Keys this file contributed last time; anything not seen again in
        # this read is stale and gets removed below.
        stale = set(self.mdfiles.get(mdfile, ()))
        current = set()
        found = 0

        root = ET.ElementTree(file=mdfile).getroot()
        ns = root.nsmap.copy()
        ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
        # Make "md:" resolvable even when the document does not itself
        # declare an "md" prefix.
        ns.setdefault('md', SAML_MD_NS)

        valid_until_raw = root.get('validUntil')
        cache_duration_raw = root.get('cacheDuration')
        valid_until = parser.isoparse(valid_until_raw)
        if valid_until.tzinfo is None:
            # SAML time values are UTC. Without this, a value lacking a zone
            # suffix parses as naive and comparing it with the aware "now"
            # below would raise TypeError.
            valid_until = valid_until.replace(tzinfo=tz.tzutc())
        cache_duration = parse_duration(cache_duration_raw)

        if valid_until > datetime.now(tz.tzutc()):
            for entity_descriptor in root.findall('md:EntityDescriptor', ns):
                entity_id = entity_descriptor.attrib.get('entityID', 'none')
                sha1 = hasher(entity_id)
                print(f" {{sha1}}{sha1} {entity_id}")
                # Copy the aggregate's validity onto the descriptor, since the
                # descriptor alone is what gets signed and served.
                entity_descriptor.set('validUntil', valid_until_raw)
                entity_descriptor.set('cacheDuration', cache_duration_raw)
                entity = Entity()
                entity.md = entity_descriptor
                entity.valid_until = valid_until
                entity.cache_duration = cache_duration
                self.idps[sha1] = entity
                # The metadata changed, so any previously signed copy is stale.
                self.signed.pop(sha1, None)
                current.add(sha1)
                found += 1

        stale -= current
        for sha1 in stale:
            self.idps.pop(sha1, None)
            self.signed.pop(sha1, None)
        self.mdfiles[mdfile] = current
        print(f"Found: {found} entities")
        print(f"Removed: {len(stale)} entities")
        print(f"validUntil: {valid_until_raw}")

    def __getitem__(self, key):
        """Return the signed XML string for the entity identified by *key*.

        *key* is URL-unquoted. It is either "{sha1}" followed by the hex
        SHA-1 of an entityID, or the entityID itself. A still-valid cached
        signature is reused; otherwise the entity is signed and the result
        cached.

        Raises:
            KeyError: the entity is unknown or its validUntil has passed.
            Exception: anything raised while signing or serializing is
                printed and then re-raised unchanged.
        """
        entity_id = unquote(key)
        prefix = "{sha1}"
        if entity_id.startswith(prefix):
            sha1 = entity_id[len(prefix):]
        else:
            sha1 = hasher(entity_id)
        now = datetime.now(tz.tzutc())

        cached = self.signed.get(sha1)
        if cached is not None:
            if cached.valid_until > now:
                print(f"serve {sha1}")
                return cached.md
            # Expired signature: discard it and try to re-sign below.
            self.signed.pop(sha1, None)

        entity = self.idps.get(sha1)
        if entity is None or entity.valid_until <= now:
            raise KeyError(key)

        print(f"sign {sha1}")
        try:
            signed_element = self.signer(entity.md)
            signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
        except Exception as e:
            print(sha1)
            print(f" {e}")
            raise

        signed_entity = Entity()
        signed_entity.md = signed_xml
        signed_entity.valid_until = entity.valid_until
        self.signed[sha1] = signed_entity
        print(f"serve {sha1}")
        return signed_xml


class EventProcessor(pyinotify.ProcessEvent):
    """pyinotify handler that forwards close-after-write events to a Server."""

    def __init__(self, server):
        # Let ProcessEvent initialise its own state (pevent, my_init hook).
        super().__init__()
        self.server = server

    def process_IN_CLOSE_WRITE(self, event):
        # pyinotify reports the watched directory in event.path, which is the
        # key Server.add_watch stored in watch_list.
        self.server.process(event.path)


class Server:
    """Registry of domain -> Resource that reloads resources on file changes."""

    def __init__(self):
        # Per-instance state, created before the notifier thread starts
        # because its callback (process) reads both dicts.
        self.watch_list = {}   # watched directory -> domain
        self.resources = {}    # domain -> Resource
        self.watch_manager = pyinotify.WatchManager()
        self.event_notifier = pyinotify.ThreadedNotifier(
            self.watch_manager, EventProcessor(self))
        self.event_notifier.start()

    def add_watch(self, domain, location):
        """Reload *domain*'s resource whenever a file in *location* is written."""
        self.watch_list[location] = domain
        self.watch_manager.add_watch(location, pyinotify.IN_CLOSE_WRITE)

    def process(self, location):
        """Re-read the resource registered for the watched directory *location*."""
        domain = self.watch_list.get(location)
        if domain is None:
            print(f"Notify unknown location {location}")
            return
        print(f"Notify {domain} {location}")
        try:
            self.resources[domain].walk_location(location)
        except Exception:
            # This runs in the notifier's callback: report the failure (e.g. a
            # half-written or malformed metadata file) instead of letting it
            # escape into pyinotify's event loop.
            traceback.print_exc()

    def __setitem__(self, domain, resource):
        self.resources[domain] = resource

    def __getitem__(self, domain):
        return self.resources[domain]