Skip to content
Snippets Groups Projects
session_management.py 1.94 KiB
from datetime import datetime, timezone

from flask import jsonify
from flask_login import LoginManager  # type: ignore
from sqlalchemy import select

from compendium_v2.db import session_scope
from compendium_v2.db.auth_model import User


def create_user(email: str, fullname: str, oidc_sub: str):
    """
    Create a new ``User`` and add it to a fresh database session.

    The user is built and added inside ``session_scope()``. It is returned
    once that scope has exited. Persisting the row presumably happens when
    ``session_scope`` commits on exit; that behaviour is not visible here.

    :param email: The email of the user
    :param fullname: The full name of the user
    :param oidc_sub: The OIDC subject identifier (ID) of the user
    :return: The user object
    """
    with session_scope() as session:
        new_user = User(
            email=email,
            fullname=fullname,
            oidc_sub=oidc_sub,
        )
        session.add(new_user)
    return new_user


def fetch_user(profile: dict):
    """
    Resolve an OIDC profile to an existing user in the database.

    The user is looked up by matching ``profile['sub']`` against
    ``User.oidc_sub``. If a match is found, its ``last_login`` is set to the
    current UTC time as a naive datetime. The change is presumably persisted
    when ``session_scope`` exits; confirm against ``session_scope``.

    :param profile: OIDC profile information; must contain a ``'sub'`` key.
    :raises KeyError: If ``profile`` has no ``'sub'`` entry.
    :return: User object if the user exists, None otherwise.
    """

    with session_scope() as session:
        sub_id = profile['sub']
        user = session.scalar(select(User).where(User.oidc_sub == sub_id))
        if user is None:
            return None
        # datetime.utcnow() is deprecated since Python 3.12. This expression
        # yields the same naive UTC value without the deprecation warning.
        user.last_login = datetime.now(timezone.utc).replace(tzinfo=None)
        return user


def user_loader(user_id: str):
    """
    Load the internal user model for a user ID handed over by Flask-Login.

    This function is registered as the Flask-Login user loader in
    ``setup_login_manager``. When a matching user exists, its ``last_login``
    is set to the current UTC time as a naive datetime.

    NOTE(review): Flask-Login normally calls the user loader on every request
    that carries a logged-in session. As a result, ``last_login`` is
    refreshed on each such request, not only at login time. Confirm this is
    intended.

    :param user_id: The user ID as stored in the session (a string). It is
        compared directly against ``User.id``.
    :return: User object if the user exists, None otherwise.
    """

    with session_scope() as session:
        user = session.scalar(select(User).where(User.id == user_id))
        if user is None:
            return None
        # datetime.utcnow() is deprecated since Python 3.12. This expression
        # yields the same naive UTC value without the deprecation warning.
        user.last_login = datetime.now(timezone.utc).replace(tzinfo=None)
        return user


def unauth_handler():
    """
    Build the response sent when an unauthenticated client is rejected.

    :return: A ``(response, status)`` pair. The response is a JSON body with
        ``success=False``, ``data={'login_required': True}`` and an
        explanatory message. The status is HTTP 401.
    """
    body = jsonify(
        success=False,
        data={'login_required': True},
        message='Authorize to access this page.',
    )
    return body, 401


def setup_login_manager(login_manager: LoginManager):
    """
    Wire this module's callbacks into a Flask-Login ``LoginManager``.

    Registers ``user_loader`` as the manager's user loader, then
    ``unauth_handler`` as its unauthorized handler.

    :param login_manager: The ``LoginManager`` instance to configure.
    """
    registrations = (
        (login_manager.user_loader, user_loader),
        (login_manager.unauthorized_handler, unauth_handler),
    )
    for register, callback in registrations:
        register(callback)