diff --git a/compendium_v2/routes/survey.py b/compendium_v2/routes/survey.py index bfbcaab00ec7eb2a378cb59074d9705cc4bdb56d..bd2c6293c0eba5c5c72cd12144625b38bc6b979e 100644 --- a/compendium_v2/routes/survey.py +++ b/compendium_v2/routes/survey.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any, TypedDict, List, Dict from flask import Blueprint, jsonify, request -from flask_login import login_required # type: ignore +from flask_login import login_required, current_user # type: ignore from sqlalchemy import select from sqlalchemy.orm import joinedload, load_only @@ -11,7 +11,7 @@ from compendium_v2.db import db from compendium_v2.db.model import NREN from compendium_v2.db.survey_model import Survey, SurveyResponse, SurveyStatus, ResponseStatus from compendium_v2.routes import common -from compendium_v2.auth.session_management import admin_required +from compendium_v2.auth.session_management import admin_required, User routes = Blueprint('survey', __name__) @@ -68,6 +68,16 @@ class VerificationStatus(str, Enum): Edited = "edited" # a question for which last years answer was edited +def check_access_nren(user: User, nren: str) -> bool: + if user.is_anonymous: + return False + if user.is_admin: + return True + if nren == user.nren: + return True + return False + + @routes.route('/list', methods=['GET']) @common.require_accepts_json @admin_required @@ -132,14 +142,17 @@ def start_new_survey() -> Any: """ all_surveys = db.session.scalars(select(Survey).options(load_only(Survey.status))) if any([survey.status != SurveyStatus.published for survey in all_surveys]): - return "All earlier surveys should be published before starting a new one", 400 + return jsonify({ + 'success': False, + 'message': 'All earlier surveys should be published before starting a new one' + }), 400 last_survey = db.session.scalar( select(Survey).order_by(Survey.year.desc()).limit(1) ) if not last_survey: - return "No survey found", 404 + return jsonify({'success': False, 'message': 'No surveys found'}), 404 new_year = last_survey.year + 1 new_survey = last_survey.survey @@ -161,14 +174,14 @@ def open_survey(year) -> Any: """ survey = db.session.scalar(select(Survey).where(Survey.year == year)) if not survey: - return "Survey not found", 404 + return jsonify({'success': False, 'message': 'Survey not found'}), 404 if survey.status != SurveyStatus.closed: - return "Survey is not closed and can therefore not be opened", 400 + return jsonify({'success': False, 'message': 'Survey is not closed and can therefore not be opened'}), 400 all_surveys = db.session.scalars(select(Survey)) if any([s.status == SurveyStatus.open for s in all_surveys]): - return "There already is an open survey", 400 + return jsonify({'success': False, 'message': 'There already is an open survey'}), 400 survey.status = SurveyStatus.open db.session.commit() @@ -187,10 +200,10 @@ def close_survey(year) -> Any: """ survey = db.session.scalar(select(Survey).where(Survey.year == year)) if not survey: - return "Survey not found", 404 + return jsonify({'success': False, 'message': 'Survey not found'}), 404 if survey.status != SurveyStatus.open: - return "Survey is not open and can therefore not be closed", 400 + return jsonify({'success': False, 'message': 'Survey is not open and can therefore not be closed'}), 400 survey.status = SurveyStatus.closed db.session.commit() @@ -209,13 +222,16 @@ def publish_survey(year) -> Any: """ survey = db.session.scalar(select(Survey).where(Survey.year == year)) if not survey: - return "Survey not found", 404 + return jsonify({'success': False, 'message': 'Survey not found'}), 404 if survey.status not in [SurveyStatus.closed, SurveyStatus.published]: - return "Survey is not closed or published and can therefore not be published", 400 + return jsonify({ + 'success': False, + 'message': 'Survey is not closed or published and can therefore not be published' + }), 400 if any([response.status != ResponseStatus.checked for response in survey.responses]): - return "There are responses that arent checked yet", 400 + return jsonify({'success': False, 'message': 'There are responses that arent checked yet'}), 400 # TODO call new survey_publisher with all responses and the year @@ -313,7 +329,8 @@ def load_survey(year, nren_name) -> Any: if not survey: return "Survey not found", 404 - # TODO validation (if not admin) on year (is survey open?) and nren (logged in user is part of nren?) + if not check_access_nren(current_user, nren): + return jsonify({'success': False, 'message': 'You do not have permissions to access this survey.'}), 403 data = {} page = 0 @@ -362,7 +379,11 @@ def save_survey(year, nren_name) -> Any: if survey is None: return "Survey not found", 404 - # TODO validation (if not admin) on year (is survey open?) and nren (logged in user is part of nren?) + if not check_access_nren(current_user, nren): + return jsonify({'success': False, 'message': 'You do not have permission to edit this survey.'}), 403 + + if survey.status != SurveyStatus.open and not current_user.is_admin: + return jsonify({'success': False, 'message': 'Survey is closed'}), 400 response = db.session.scalar( select(SurveyResponse).where(SurveyResponse.survey_year == year).where(SurveyResponse.nren_id == nren.id) @@ -373,7 +394,7 @@ def save_survey(year, nren_name) -> Any: save_survey = request.json if not save_survey: - raise Exception("Invalid format") + return jsonify({'success': False, 'message': 'Invalid Survey Format'}), 400 response.answers = { "data": save_survey["data"], diff --git a/test/conftest.py b/test/conftest.py index be9e652dafe3d58f91cfca9c86663cc7214531ff..001aee506856395d69d023cb33d57a2aef8bb204 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,11 +3,12 @@ import os import pytest import random +from sqlalchemy import select from flask_login import LoginManager # type: ignore import compendium_v2 from compendium_v2.db import db, model, survey_model from compendium_v2.survey_db import model as survey_db_model -from compendium_v2.auth.session_management import setup_login_manager +from compendium_v2.auth.session_management import setup_login_manager, User, ROLES def _test_data_csv(filename): @@ -24,6 +25,38 @@ def dummy_config(): } +@pytest.fixture +def mocked_admin_user(app, mocker): + with app.app_context(): + user = User(email='testemail123@email.local', fullname='testfullname', oidc_sub='fakesub', roles=ROLES.admin) + + nren2 = db.session.scalar(select(model.NREN).filter(model.NREN.name == 'nren2')) + user.nrens.append(nren2) + db.session.add(user) + db.session.commit() + + def user_loader(*args): + return user + mocker.patch('flask_login.utils._get_user', user_loader) + yield user + + +@pytest.fixture +def mocked_user(app, mocker): + with app.app_context(): + user = User(email='testemail123@email.local', fullname='testfullname', oidc_sub='fakesub') + + nren2 = db.session.scalar(select(model.NREN).filter(model.NREN.name == 'nren2')) + user.nrens.append(nren2) + db.session.add(user) + db.session.commit() + + def user_loader(*args): + return user + mocker.patch('flask_login.utils._get_user', user_loader) + yield user + + @pytest.fixture def test_budget_data(app): with app.app_context(): @@ -126,7 +159,7 @@ def test_survey_data(app): survey2022 = survey_model.Survey( year=2022, survey={'part1': [{'title': 'ha', 'visibleIf': 'false'}]}, - status=survey_model.SurveyStatus.published + status=survey_model.SurveyStatus.open ) db.session.add_all([survey2021, survey2022]) diff --git a/test/test_survey.py b/test/test_survey.py index b4b4d4f6aa64d0d62bc353139cf131988732ac05..d832cc1a689b7d2dc6afc71f5a36deac1c379eb8 100644 --- a/test/test_survey.py +++ b/test/test_survey.py @@ -1,5 +1,7 @@ import json import jsonschema +from compendium_v2.db import db +from compendium_v2.db.survey_model import Survey, SurveyStatus from compendium_v2.routes.survey import LIST_SURVEYS_RESPONSE_SCHEMA, SURVEY_RESPONSE_SCHEMA, VerificationStatus @@ -13,7 +15,20 @@ def test_survey_route_list_response(client, test_survey_data): assert result -def test_survey_route_new(client, test_survey_data): +def test_survey_route_new(app, client, test_survey_data, mocked_user): + rv = client.post( + '/api/survey/new', + headers={'Accept': ['application/json']}) + assert rv.status_code == 400 + result = json.loads(rv.data.decode('utf-8')) + assert not result.get('success') + + # mark all surveys as published + with app.app_context(): + for survey in db.session.query(Survey).all(): + survey.status = SurveyStatus.published + db.session.commit() + rv = client.post( '/api/survey/new', headers={'Accept': ['application/json']}) @@ -27,13 +42,26 @@ def test_survey_route_new(client, test_survey_data): assert rv.status_code != 200 -def test_survey_route_open_close(client, test_survey_data): +def test_survey_route_open_close(app, client, test_survey_data, mocked_user): + rv = client.post( + '/api/survey/new', + headers={'Accept': ['application/json']}) + assert rv.status_code == 400 + result = json.loads(rv.data.decode('utf-8')) + assert not result.get('success') + + # mark all surveys as published + with app.app_context(): + for survey in db.session.query(Survey).all(): + survey.status = SurveyStatus.published + db.session.commit() + rv = client.post( '/api/survey/new', headers={'Accept': ['application/json']}) assert rv.status_code == 200 result = json.loads(rv.data.decode('utf-8')) - assert result == {'success': True} + assert result.get('success') rv = client.post( '/api/survey/open/2023', @@ -60,13 +88,25 @@ def test_survey_route_open_close(client, test_survey_data): assert rv.status_code != 200 -def test_survey_route_publish(client, test_survey_data): +def test_survey_route_publish(app, client, test_survey_data, mocked_admin_user): + rv = client.post( + '/api/survey/publish/2022', + headers={'Accept': ['application/json']}) + assert rv.status_code == 400 + result = json.loads(rv.data.decode('utf-8')) + assert not result.get('success') + + with app.app_context(): + survey = db.session.scalar(Survey.query.filter(Survey.year == 2022)) + survey.status = SurveyStatus.closed + db.session.commit() + rv = client.post( '/api/survey/publish/2022', headers={'Accept': ['application/json']}) assert rv.status_code == 200 result = json.loads(rv.data.decode('utf-8')) - assert result == {'success': True} + assert result.get('success') def test_survey_route_try_response(client, test_survey_data): @@ -89,7 +129,7 @@ def test_survey_route_inspect_response(client, test_survey_data): assert result -def test_survey_route_save_load_response(client, test_survey_data): +def test_survey_route_save_load_response(client, test_survey_data, mocked_user): rv = client.post( '/api/survey/save/2021/nren2', headers={'Accept': ['application/json']}, @@ -98,9 +138,10 @@ def test_survey_route_save_load_response(client, test_survey_data): 'page': 3, 'verification_status': {'q1': VerificationStatus.Verified} }) - assert rv.status_code == 200 + assert rv.status_code == 400 result = json.loads(rv.data.decode('utf-8')) - assert result == {'success': True} + assert not result.get('success') + assert result.get('message') == 'Survey is closed' rv = client.get( '/api/survey/load/2021/nren2', @@ -108,9 +149,21 @@ def test_survey_route_save_load_response(client, test_survey_data): assert rv.status_code == 200 result = json.loads(rv.data.decode('utf-8')) jsonschema.validate(result, SURVEY_RESPONSE_SCHEMA) - assert result['page'] == 3 - assert result['data'] == {'q1': 'yes', 'q2': ['no']} - assert result['verification_status'] == {'q1': VerificationStatus.Verified} + assert result['page'] == 0 + assert result['data'] == {} + assert result['verification_status'] == {} + + rv = client.post( + '/api/survey/save/2022/nren2', + headers={'Accept': ['application/json']}, + json={ + 'data': {'q1': 'yes', 'q2': ['no']}, + 'page': 3, + 'verification_status': {'q1': VerificationStatus.Verified} + }) + assert rv.status_code == 200 + result = json.loads(rv.data.decode('utf-8')) + assert result.get('success') rv = client.get( '/api/survey/load/2022/nren2', @@ -118,6 +171,6 @@ def test_survey_route_save_load_response(client, test_survey_data): assert rv.status_code == 200 result = json.loads(rv.data.decode('utf-8')) jsonschema.validate(result, SURVEY_RESPONSE_SCHEMA) - assert result['page'] == 0 + assert result['page'] == 3 assert result['data'] == {'q1': 'yes', 'q2': ['no']} - assert result['verification_status'] == {'q1': VerificationStatus.Unverified, 'q2': VerificationStatus.Unverified} + assert result['verification_status'] == {'q1': VerificationStatus.Verified}