Skip to content
Snippets Groups Projects
session_management.py 3.11 KiB
from datetime import datetime, timezone
from functools import wraps

from flask import jsonify, current_app
from flask_login import LoginManager, current_user  # type: ignore
from sqlalchemy import select

from compendium_v2.db import session_scope
from compendium_v2.db.auth_model import User, ROLES
from compendium_v2.email import send_admin_signup_notification, send_user_signup_notification


def admin_required(func):
    """
    Restrict a view to administrators.

    Unless the application config sets ``LOGIN_DISABLED``, the wrapped view
    only runs for an authenticated user whose ``roles`` equals
    ``ROLES.admin``. Anonymous users receive a 401 JSON response flagged with
    ``login_required``. Authenticated non-admins receive a 401 JSON response
    flagged with ``admin_required``.

    :param func: The view function to decorate.
    :return: The decorated view function, carrying ``func``'s name/docstring.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # With LOGIN_DISABLED set, every access check is skipped.
        if current_app.config.get('LOGIN_DISABLED'):
            return func(*args, **kwargs)

        if not current_user.is_authenticated:
            return jsonify(success=False,
                           data={'login_required': True},
                           message='Authorize to access this page.'), 401

        # NOTE(review): 403 is the conventional status for "authenticated but
        # not allowed"; 401 is kept because clients may rely on it.
        if current_user.roles != ROLES.admin:
            return jsonify(success=False,
                           data={'admin_required': True},
                           message='Admin privileges required to access this page.'), 401

        return func(*args, **kwargs)

    return wrapper


def create_user(email: str, fullname: str, oidc_sub: str, given_name: str):
    """
    Function used to create a new user in the database.

    The new row is flushed before any notification e-mail is sent. This has
    two effects:

    * a database error raised by the INSERT (for example a constraint
      violation) aborts signup before anyone is e-mailed;
    * database-generated columns (presumably including ``id``) are populated
      on ``user`` by the time the notification helpers receive it.

    :param email: The email of the user
    :param fullname: The full name of the user
    :param oidc_sub: The OIDC subject identifier (ID) of the user
    :param given_name: The given name of the user; it is passed to the user's
        signup notification and is not stored on the User row
    :return: The user object
    """

    with session_scope() as session:
        user = User(email=email, fullname=fullname, oidc_sub=oidc_sub)
        session.add(user)
        # Issue the INSERT now rather than at commit time, so failures surface
        # before notifications go out.
        # NOTE(review): the commit itself presumably still happens when
        # session_scope exits, so a commit-time failure can follow sent
        # e-mails; confirm against session_scope.
        session.flush()
        send_admin_signup_notification(user)
        send_user_signup_notification(user, given_name)
        return user


def fetch_user(profile: dict):
    """
    Function used to resolve an OIDC profile to a user in the database.

    The user is matched on ``User.oidc_sub == profile['sub']``. On a match,
    ``last_login`` is set to the current UTC time as a naive datetime.

    :param profile: OIDC profile information; must contain a ``'sub'`` key.
    :return: User object if the user exists, None otherwise.
    :raises KeyError: If ``profile`` has no ``'sub'`` entry.
    """

    with session_scope() as session:
        sub_id = profile['sub']
        user = session.scalar(select(User).where(User.oidc_sub == sub_id))
        if user is None:
            return None
        # Same naive-UTC value datetime.utcnow() produced, without relying on
        # utcnow(), which is deprecated since Python 3.12.
        user.last_login = datetime.now(timezone.utc).replace(tzinfo=None)
        return user


def user_loader(user_id: str):
    """
    Function used to retrieve the internal user model for the user attempting login.

    This function is registered as Flask-Login's user loader in
    ``setup_login_manager``. On a match, ``last_login`` is set to the current
    UTC time as a naive datetime.

    NOTE(review): if Flask-Login invokes this on every authenticated request,
    ``last_login`` tracks last activity rather than last login. Confirm that
    this is intended.

    :param user_id: The ID of the user attempting login.
    :return: User object if the user exists, None otherwise.
    """

    with session_scope() as session:
        user = session.scalar(select(User).where(User.id == user_id))
        if user is None:
            return None
        # Same naive-UTC value datetime.utcnow() produced, without relying on
        # utcnow(), which is deprecated since Python 3.12.
        user.last_login = datetime.now(timezone.utc).replace(tzinfo=None)
        return user


def unauth_handler():
    """
    Build the response used as Flask-Login's unauthorized handler.

    The handler is registered in ``setup_login_manager``.

    :return: A ``(response, 401)`` tuple whose JSON body has ``success=False``
        and ``data={'login_required': True}``.
    """
    status_code = 401
    payload = jsonify(
        success=False,
        data={'login_required': True},
        message='Authorize to access this page.',
    )
    return payload, status_code


def setup_login_manager(login_manager: LoginManager):
    """
    Register this module's callbacks on a Flask-Login ``LoginManager``.

    :param login_manager: The login manager to configure. ``user_loader`` is
        registered as its user loader and ``unauth_handler`` as its
        unauthorized handler.
    """
    registrations = (
        (login_manager.user_loader, user_loader),
        (login_manager.unauthorized_handler, unauth_handler),
    )
    for register, callback in registrations:
        register(callback)