import os
from lxml import etree as ET
from dateutil import parser, tz
from isoduration import parse_duration
from datetime import datetime
import hashlib
from urllib.parse import unquote
import yaml
import pyinotify
from signers import Signers

# watch_list = {}
# watch_manager = pyinotify.WatchManager()


def read_config(config):
    """Load and return the YAML document stored at path *config*.

    The file is opened in binary mode so PyYAML detects the encoding itself
    (UTF-8 by default, UTF-16 via BOM) instead of decoding with whatever the
    platform's locale encoding happens to be.
    """
    with open(config, "rb") as stream:
        return yaml.safe_load(stream)


def hasher(entity_id):
    """Return the lowercase hex SHA-1 digest of *entity_id*.

    The string is encoded with ``str.encode()``'s default (UTF-8) before
    hashing; the result is a 40-character hex string.
    """
    return hashlib.sha1(entity_id.encode()).hexdigest()


class Entity:
    """Mutable record holding a metadata payload and its validity data.

    Every field starts empty; the owner assigns the payload (``md``) and the
    validity values after construction.
    """

    def __init__(self):
        # md starts as None; valid_until / cache_duration start as 0.
        self.md, self.valid_until, self.cache_duration = None, 0, 0


class Resource:
    """Signed-metadata store for one directory of SAML metadata files.

    Each ``md:EntityDescriptor`` found directly under a file's root element is
    indexed by ``hasher(entityID)`` in ``self.idps``. ``resource[key]`` signs
    the descriptor with ``self.signer`` on first use and caches the serialized
    result until its ``validUntil`` passes or the source file is re-read.
    """

    # Namespace of SAML 2.0 metadata elements (EntityDescriptor etc.).
    _MD_NS = 'urn:oasis:names:tc:SAML:2.0:metadata'

    def __init__(self, location, signer):
        """Load every regular file in directory *location*.

        *signer* is handed to ``Signers``; the resulting object is called with
        an EntityDescriptor element to sign it.
        """
        # sha1 -> Entity wrapping the unsigned EntityDescriptor element.
        self.idps = {}
        # realpath of metadata file -> sha1s that file provided on its last read.
        self.mdfiles = {}
        # sha1 -> Entity wrapping the signed, serialized XML. Kept apart from
        # the instance __dict__ so a request key such as "{sha1}idps" can never
        # be mistaken for (or clobber) a real attribute.
        self._signed = {}
        self.signer = Signers(signer)
        self.walk_location(location)

    def walk_location(self, location):
        """(Re)read every regular file directly inside directory *location*."""
        for name in os.listdir(location):
            mdfile = os.path.realpath(os.path.join(location, name))
            if os.path.isfile(mdfile):
                # The file's previous entity list is intentionally kept so
                # _read_metadata can detect entities that disappeared.
                self._read_metadata(mdfile)

    @staticmethod
    def _as_utc(moment):
        """Return *moment* timezone-aware, treating a naive value as UTC."""
        # SAML time values are UTC; a naive datetime would make every
        # comparison against datetime.now(tz.tzutc()) raise TypeError.
        if moment.tzinfo is None:
            return moment.replace(tzinfo=tz.tzutc())
        return moment

    def _listed_elsewhere(self, sha1, mdfile):
        """Return True if a metadata file other than *mdfile* lists *sha1*."""
        return any(sha1 in shas for path, shas in self.mdfiles.items() if path != mdfile)

    def _read_metadata(self, mdfile):
        """Parse *mdfile* and refresh the entities it contributes.

        The root's ``validUntil`` and ``cacheDuration`` attributes are copied
        onto each EntityDescriptor. If that ``validUntil`` has already passed,
        the file contributes no entities. Entities this file listed on its
        previous read but no longer lists are dropped, together with any
        cached signature, unless another file still lists them.

        Raises:
            ValueError: the root element lacks ``validUntil`` or ``cacheDuration``.
            Errors from parsing the XML, date, or duration propagate unchanged.
        """
        print("\n--- READ METADATA --")
        print(mdfile)
        old_idps = set(self.mdfiles.get(mdfile, ()))
        current = []
        root = ET.ElementTree(file=mdfile).getroot()
        # Bind "md" explicitly so lookup works whatever prefix (or default
        # namespace) the document itself uses.
        ns = {'md': self._MD_NS}
        validUntil = root.get('validUntil')
        cacheDuration = root.get('cacheDuration')
        if validUntil is None or cacheDuration is None:
            raise ValueError(f"{mdfile}: root element needs validUntil and cacheDuration")
        valid_until = self._as_utc(parser.isoparse(validUntil))
        cache_duration = parse_duration(cacheDuration)
        if valid_until > datetime.now(tz.tzutc()):
            for entity_descriptor in root.findall('md:EntityDescriptor', ns):
                entityID = entity_descriptor.attrib.get('entityID', 'none')
                sha1 = hasher(entityID)
                print(f"  {{sha1}}{sha1} {entityID}")
                entity_descriptor.set('validUntil', validUntil)
                entity_descriptor.set('cacheDuration', cacheDuration)
                entity = Entity()
                entity.md = entity_descriptor
                entity.valid_until = valid_until
                entity.cache_duration = cache_duration
                self.idps[sha1] = entity
                # Forget any signature made from the previous version.
                self._signed.pop(sha1, None)
                current.append(sha1)
        self.mdfiles[mdfile] = current

        removed = 0
        for sha1 in old_idps.difference(current):
            if self._listed_elsewhere(sha1, mdfile):
                continue
            self.idps.pop(sha1, None)
            self._signed.pop(sha1, None)
            removed += 1

        print(f"Found: {len(current)} entities")
        print(f"Removed: {removed} entities")
        print(f"validUntil: {validUntil}")

    def __getitem__(self, key):
        """Return the signed metadata XML (str) for the entity named by *key*.

        *key* is URL-unquoted first; ``"{sha1}<digest>"`` selects an entity by
        its hash, anything else is treated as an entityID and hashed.

        Raises:
            KeyError: the entity is unknown or its validUntil has passed.
            Exception: whatever signing or serialization raised (logged first).
        """
        entityID = unquote(key)
        if entityID.startswith("{sha1}"):
            sha1 = entityID[len("{sha1}"):]
        else:
            sha1 = hasher(entityID)

        now = datetime.now(tz.tzutc())
        cached = self._signed.get(sha1)
        if cached is not None:
            if not cached.valid_until > now:
                raise KeyError(key)
            data = cached.md
        elif sha1 in self.idps:
            entity = self.idps[sha1]
            print(f"sign {sha1}")
            if not entity.valid_until > now:
                raise KeyError(key)
            try:
                signed_element = self.signer(entity.md)
                signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
            except Exception as e:
                print(sha1)
                print(f"  {e}")
                raise
            signed_entity = Entity()
            signed_entity.md = signed_xml
            signed_entity.valid_until = entity.valid_until
            signed_entity.cache_duration = entity.cache_duration
            self._signed[sha1] = signed_entity
            data = signed_xml
        else:
            raise KeyError(key)

        print(f"serve {sha1}")
        return data


class EventProcessor(pyinotify.ProcessEvent):
    """pyinotify handler that forwards IN_CLOSE_WRITE events to a Server."""

    def __init__(self, server):
        # Run the base initializer so ProcessEvent's own state (pevent,
        # my_init hook) is set up as the library expects.
        super().__init__()
        self.server = server

    def process_IN_CLOSE_WRITE(self, event):
        # event.path is the watched path (the directory registered with
        # add_watch), which is the key Server.process looks up.
        self.server.process(event.path)


class Server:
    """Map domains to Resources and reload them when their files change.

    ``__init__`` starts a pyinotify ThreadedNotifier, so ``process`` is called
    from that notifier's thread rather than the caller's.
    """

    def __init__(self):
        # Watched directory (normalized) -> domain. Per instance so separate
        # Servers never share watches.
        self.watch_list = {}
        # domain -> Resource, kept out of the instance __dict__ so a domain
        # such as "watch_manager" cannot clobber a real attribute.
        self._resources = {}
        self.watch_manager = pyinotify.WatchManager()
        self.event_notifier = pyinotify.ThreadedNotifier(self.watch_manager, EventProcessor(self))
        self.event_notifier.start()

    def add_watch(self, domain, location):
        """Reload *domain* whenever a file in directory *location* is closed after writing."""
        # Normalize so the stored key matches the path reported by events even
        # when *location* has a trailing slash or "./" segments.
        location = os.path.normpath(location)
        self.watch_list[location] = domain
        self.watch_manager.add_watch(location, pyinotify.IN_CLOSE_WRITE)

    def process(self, location):
        """Re-walk *location* for the domain registered with add_watch."""
        location = os.path.normpath(location)
        try:
            domain = self.watch_list[location]
            print(f"Notify {domain} {location}")
            self._resources[domain].walk_location(location)
        except Exception as e:
            # Runs on the notifier thread: log instead of letting one bad file
            # propagate into pyinotify's event loop and stop later reloads.
            print(f"Reload of {location} failed: {e!r}")

    def stop(self):
        """Stop the background notifier thread."""
        self.event_notifier.stop()

    def __setitem__(self, domain, resource):
        self._resources[domain] = resource

    def __getitem__(self, domain):
        return self._resources[domain]