Skip to content
Snippets Groups Projects
Commit a350f798 authored by Bjarke Madsen's avatar Bjarke Madsen
Browse files

Fix tests after adding admin/nren/user checks

parent ee597098
Branches
Tags
No related merge requests found
...@@ -3,7 +3,7 @@ from enum import Enum ...@@ -3,7 +3,7 @@ from enum import Enum
from typing import Any, TypedDict, List, Dict from typing import Any, TypedDict, List, Dict
from flask import Blueprint, jsonify, request from flask import Blueprint, jsonify, request
from flask_login import login_required # type: ignore from flask_login import login_required, current_user # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import joinedload, load_only from sqlalchemy.orm import joinedload, load_only
...@@ -11,7 +11,7 @@ from compendium_v2.db import db ...@@ -11,7 +11,7 @@ from compendium_v2.db import db
from compendium_v2.db.model import NREN from compendium_v2.db.model import NREN
from compendium_v2.db.survey_model import Survey, SurveyResponse, SurveyStatus, ResponseStatus from compendium_v2.db.survey_model import Survey, SurveyResponse, SurveyStatus, ResponseStatus
from compendium_v2.routes import common from compendium_v2.routes import common
from compendium_v2.auth.session_management import admin_required from compendium_v2.auth.session_management import admin_required, User
routes = Blueprint('survey', __name__) routes = Blueprint('survey', __name__)
...@@ -68,6 +68,16 @@ class VerificationStatus(str, Enum): ...@@ -68,6 +68,16 @@ class VerificationStatus(str, Enum):
Edited = "edited" # a question for which last years answer was edited Edited = "edited" # a question for which last years answer was edited
def check_access_nren(user: User, nren: str) -> bool:
if user.is_anonymous:
return False
if user.is_admin:
return True
if nren == user.nren:
return True
return False
@routes.route('/list', methods=['GET']) @routes.route('/list', methods=['GET'])
@common.require_accepts_json @common.require_accepts_json
@admin_required @admin_required
...@@ -132,14 +142,17 @@ def start_new_survey() -> Any: ...@@ -132,14 +142,17 @@ def start_new_survey() -> Any:
""" """
all_surveys = db.session.scalars(select(Survey).options(load_only(Survey.status))) all_surveys = db.session.scalars(select(Survey).options(load_only(Survey.status)))
if any([survey.status != SurveyStatus.published for survey in all_surveys]): if any([survey.status != SurveyStatus.published for survey in all_surveys]):
return "All earlier surveys should be published before starting a new one", 400 return jsonify({
'success': False,
'message': 'All earlier surveys should be published before starting a new one'
}), 400
last_survey = db.session.scalar( last_survey = db.session.scalar(
select(Survey).order_by(Survey.year.desc()).limit(1) select(Survey).order_by(Survey.year.desc()).limit(1)
) )
if not last_survey: if not last_survey:
return "No survey found", 404 return jsonify({'success': False, 'message': 'No surveys found'}), 404
new_year = last_survey.year + 1 new_year = last_survey.year + 1
new_survey = last_survey.survey new_survey = last_survey.survey
...@@ -161,14 +174,14 @@ def open_survey(year) -> Any: ...@@ -161,14 +174,14 @@ def open_survey(year) -> Any:
""" """
survey = db.session.scalar(select(Survey).where(Survey.year == year)) survey = db.session.scalar(select(Survey).where(Survey.year == year))
if not survey: if not survey:
return "Survey not found", 404 return jsonify({'success': False, 'message': 'Survey not found'}), 404
if survey.status != SurveyStatus.closed: if survey.status != SurveyStatus.closed:
return "Survey is not closed and can therefore not be opened", 400 return jsonify({'success': False, 'message': 'Survey is not closed and can therefore not be opened'}), 400
all_surveys = db.session.scalars(select(Survey)) all_surveys = db.session.scalars(select(Survey))
if any([s.status == SurveyStatus.open for s in all_surveys]): if any([s.status == SurveyStatus.open for s in all_surveys]):
return "There already is an open survey", 400 return jsonify({'success': False, 'message': 'There already is an open survey'}), 400
survey.status = SurveyStatus.open survey.status = SurveyStatus.open
db.session.commit() db.session.commit()
...@@ -187,10 +200,10 @@ def close_survey(year) -> Any: ...@@ -187,10 +200,10 @@ def close_survey(year) -> Any:
""" """
survey = db.session.scalar(select(Survey).where(Survey.year == year)) survey = db.session.scalar(select(Survey).where(Survey.year == year))
if not survey: if not survey:
return "Survey not found", 404 return jsonify({'success': False, 'message': 'Survey not found'}), 404
if survey.status != SurveyStatus.open: if survey.status != SurveyStatus.open:
return "Survey is not open and can therefore not be closed", 400 return jsonify({'success': False, 'message': 'Survey is not open and can therefore not be closed'}), 400
survey.status = SurveyStatus.closed survey.status = SurveyStatus.closed
db.session.commit() db.session.commit()
...@@ -209,13 +222,16 @@ def publish_survey(year) -> Any: ...@@ -209,13 +222,16 @@ def publish_survey(year) -> Any:
""" """
survey = db.session.scalar(select(Survey).where(Survey.year == year)) survey = db.session.scalar(select(Survey).where(Survey.year == year))
if not survey: if not survey:
return "Survey not found", 404 return jsonify({'success': False, 'message': 'Survey not found'}), 404
if survey.status not in [SurveyStatus.closed, SurveyStatus.published]: if survey.status not in [SurveyStatus.closed, SurveyStatus.published]:
return "Survey is not closed or published and can therefore not be published", 400 return jsonify({
'success': False,
'message': 'Survey is not closed or published and can therefore not be published'
}), 400
if any([response.status != ResponseStatus.checked for response in survey.responses]): if any([response.status != ResponseStatus.checked for response in survey.responses]):
return "There are responses that arent checked yet", 400 return jsonify({'success': False, 'message': 'There are responses that arent checked yet'}), 400
# TODO call new survey_publisher with all responses and the year # TODO call new survey_publisher with all responses and the year
...@@ -313,7 +329,8 @@ def load_survey(year, nren_name) -> Any: ...@@ -313,7 +329,8 @@ def load_survey(year, nren_name) -> Any:
if not survey: if not survey:
return "Survey not found", 404 return "Survey not found", 404
# TODO validation (if not admin) on year (is survey open?) and nren (logged in user is part of nren?) if not check_access_nren(current_user, nren):
return jsonify({'success': False, 'message': 'You do not have permissions to access this survey.'}), 403
data = {} data = {}
page = 0 page = 0
...@@ -362,7 +379,11 @@ def save_survey(year, nren_name) -> Any: ...@@ -362,7 +379,11 @@ def save_survey(year, nren_name) -> Any:
if survey is None: if survey is None:
return "Survey not found", 404 return "Survey not found", 404
# TODO validation (if not admin) on year (is survey open?) and nren (logged in user is part of nren?) if not check_access_nren(current_user, nren):
return jsonify({'success': False, 'message': 'You do not have permission to edit this survey.'}), 403
if survey.status != SurveyStatus.open and not current_user.is_admin:
return jsonify({'success': False, 'message': 'Survey is closed'}), 400
response = db.session.scalar( response = db.session.scalar(
select(SurveyResponse).where(SurveyResponse.survey_year == year).where(SurveyResponse.nren_id == nren.id) select(SurveyResponse).where(SurveyResponse.survey_year == year).where(SurveyResponse.nren_id == nren.id)
...@@ -373,7 +394,7 @@ def save_survey(year, nren_name) -> Any: ...@@ -373,7 +394,7 @@ def save_survey(year, nren_name) -> Any:
save_survey = request.json save_survey = request.json
if not save_survey: if not save_survey:
raise Exception("Invalid format") return jsonify({'success': False, 'message': 'Invalid Survey Format'}), 400
response.answers = { response.answers = {
"data": save_survey["data"], "data": save_survey["data"],
......
...@@ -3,11 +3,12 @@ import os ...@@ -3,11 +3,12 @@ import os
import pytest import pytest
import random import random
from sqlalchemy import select
from flask_login import LoginManager # type: ignore from flask_login import LoginManager # type: ignore
import compendium_v2 import compendium_v2
from compendium_v2.db import db, model, survey_model from compendium_v2.db import db, model, survey_model
from compendium_v2.survey_db import model as survey_db_model from compendium_v2.survey_db import model as survey_db_model
from compendium_v2.auth.session_management import setup_login_manager from compendium_v2.auth.session_management import setup_login_manager, User, ROLES
def _test_data_csv(filename): def _test_data_csv(filename):
...@@ -24,6 +25,38 @@ def dummy_config(): ...@@ -24,6 +25,38 @@ def dummy_config():
} }
@pytest.fixture
def mocked_admin_user(app, mocker):
with app.app_context():
user = User(email='testemail123@email.local', fullname='testfullname', oidc_sub='fakesub', roles=ROLES.admin)
nren2 = db.session.scalar(select(model.NREN).filter(model.NREN.name == 'nren2'))
user.nrens.append(nren2)
db.session.add(user)
db.session.commit()
def user_loader(*args):
return user
mocker.patch('flask_login.utils._get_user', user_loader)
yield user
@pytest.fixture
def mocked_user(app, mocker):
with app.app_context():
user = User(email='testemail123@email.local', fullname='testfullname', oidc_sub='fakesub')
nren2 = db.session.scalar(select(model.NREN).filter(model.NREN.name == 'nren2'))
user.nrens.append(nren2)
db.session.add(user)
db.session.commit()
def user_loader(*args):
return user
mocker.patch('flask_login.utils._get_user', user_loader)
yield user
@pytest.fixture @pytest.fixture
def test_budget_data(app): def test_budget_data(app):
with app.app_context(): with app.app_context():
...@@ -126,7 +159,7 @@ def test_survey_data(app): ...@@ -126,7 +159,7 @@ def test_survey_data(app):
survey2022 = survey_model.Survey( survey2022 = survey_model.Survey(
year=2022, year=2022,
survey={'part1': [{'title': 'ha', 'visibleIf': 'false'}]}, survey={'part1': [{'title': 'ha', 'visibleIf': 'false'}]},
status=survey_model.SurveyStatus.published status=survey_model.SurveyStatus.open
) )
db.session.add_all([survey2021, survey2022]) db.session.add_all([survey2021, survey2022])
......
import json import json
import jsonschema import jsonschema
from compendium_v2.db import db
from compendium_v2.db.survey_model import Survey, SurveyStatus
from compendium_v2.routes.survey import LIST_SURVEYS_RESPONSE_SCHEMA, SURVEY_RESPONSE_SCHEMA, VerificationStatus from compendium_v2.routes.survey import LIST_SURVEYS_RESPONSE_SCHEMA, SURVEY_RESPONSE_SCHEMA, VerificationStatus
...@@ -13,7 +15,20 @@ def test_survey_route_list_response(client, test_survey_data): ...@@ -13,7 +15,20 @@ def test_survey_route_list_response(client, test_survey_data):
assert result assert result
def test_survey_route_new(client, test_survey_data): def test_survey_route_new(app, client, test_survey_data, mocked_user):
rv = client.post(
'/api/survey/new',
headers={'Accept': ['application/json']})
assert rv.status_code == 400
result = json.loads(rv.data.decode('utf-8'))
assert not result.get('success')
# mark all surveys as published
with app.app_context():
for survey in db.session.query(Survey).all():
survey.status = SurveyStatus.published
db.session.commit()
rv = client.post( rv = client.post(
'/api/survey/new', '/api/survey/new',
headers={'Accept': ['application/json']}) headers={'Accept': ['application/json']})
...@@ -27,13 +42,26 @@ def test_survey_route_new(client, test_survey_data): ...@@ -27,13 +42,26 @@ def test_survey_route_new(client, test_survey_data):
assert rv.status_code != 200 assert rv.status_code != 200
def test_survey_route_open_close(client, test_survey_data): def test_survey_route_open_close(app, client, test_survey_data, mocked_user):
rv = client.post(
'/api/survey/new',
headers={'Accept': ['application/json']})
assert rv.status_code == 400
result = json.loads(rv.data.decode('utf-8'))
assert not result.get('success')
# mark all surveys as published
with app.app_context():
for survey in db.session.query(Survey).all():
survey.status = SurveyStatus.published
db.session.commit()
rv = client.post( rv = client.post(
'/api/survey/new', '/api/survey/new',
headers={'Accept': ['application/json']}) headers={'Accept': ['application/json']})
assert rv.status_code == 200 assert rv.status_code == 200
result = json.loads(rv.data.decode('utf-8')) result = json.loads(rv.data.decode('utf-8'))
assert result == {'success': True} assert result.get('success')
rv = client.post( rv = client.post(
'/api/survey/open/2023', '/api/survey/open/2023',
...@@ -60,13 +88,25 @@ def test_survey_route_open_close(client, test_survey_data): ...@@ -60,13 +88,25 @@ def test_survey_route_open_close(client, test_survey_data):
assert rv.status_code != 200 assert rv.status_code != 200
def test_survey_route_publish(client, test_survey_data): def test_survey_route_publish(app, client, test_survey_data, mocked_admin_user):
rv = client.post(
'/api/survey/publish/2022',
headers={'Accept': ['application/json']})
assert rv.status_code == 400
result = json.loads(rv.data.decode('utf-8'))
assert not result.get('success')
with app.app_context():
survey = db.session.scalar(Survey.query.filter(Survey.year == 2022))
survey.status = SurveyStatus.closed
db.session.commit()
rv = client.post( rv = client.post(
'/api/survey/publish/2022', '/api/survey/publish/2022',
headers={'Accept': ['application/json']}) headers={'Accept': ['application/json']})
assert rv.status_code == 200 assert rv.status_code == 200
result = json.loads(rv.data.decode('utf-8')) result = json.loads(rv.data.decode('utf-8'))
assert result == {'success': True} assert result.get('success')
def test_survey_route_try_response(client, test_survey_data): def test_survey_route_try_response(client, test_survey_data):
...@@ -89,7 +129,7 @@ def test_survey_route_inspect_response(client, test_survey_data): ...@@ -89,7 +129,7 @@ def test_survey_route_inspect_response(client, test_survey_data):
assert result assert result
def test_survey_route_save_load_response(client, test_survey_data): def test_survey_route_save_load_response(client, test_survey_data, mocked_user):
rv = client.post( rv = client.post(
'/api/survey/save/2021/nren2', '/api/survey/save/2021/nren2',
headers={'Accept': ['application/json']}, headers={'Accept': ['application/json']},
...@@ -98,9 +138,10 @@ def test_survey_route_save_load_response(client, test_survey_data): ...@@ -98,9 +138,10 @@ def test_survey_route_save_load_response(client, test_survey_data):
'page': 3, 'page': 3,
'verification_status': {'q1': VerificationStatus.Verified} 'verification_status': {'q1': VerificationStatus.Verified}
}) })
assert rv.status_code == 200 assert rv.status_code == 400
result = json.loads(rv.data.decode('utf-8')) result = json.loads(rv.data.decode('utf-8'))
assert result == {'success': True} assert not result.get('success')
assert result.get('message') == 'Survey is closed'
rv = client.get( rv = client.get(
'/api/survey/load/2021/nren2', '/api/survey/load/2021/nren2',
...@@ -108,9 +149,21 @@ def test_survey_route_save_load_response(client, test_survey_data): ...@@ -108,9 +149,21 @@ def test_survey_route_save_load_response(client, test_survey_data):
assert rv.status_code == 200 assert rv.status_code == 200
result = json.loads(rv.data.decode('utf-8')) result = json.loads(rv.data.decode('utf-8'))
jsonschema.validate(result, SURVEY_RESPONSE_SCHEMA) jsonschema.validate(result, SURVEY_RESPONSE_SCHEMA)
assert result['page'] == 3 assert result['page'] == 0
assert result['data'] == {'q1': 'yes', 'q2': ['no']} assert result['data'] == {}
assert result['verification_status'] == {'q1': VerificationStatus.Verified} assert result['verification_status'] == {}
rv = client.post(
'/api/survey/save/2022/nren2',
headers={'Accept': ['application/json']},
json={
'data': {'q1': 'yes', 'q2': ['no']},
'page': 3,
'verification_status': {'q1': VerificationStatus.Verified}
})
assert rv.status_code == 200
result = json.loads(rv.data.decode('utf-8'))
assert result.get('success')
rv = client.get( rv = client.get(
'/api/survey/load/2022/nren2', '/api/survey/load/2022/nren2',
...@@ -118,6 +171,6 @@ def test_survey_route_save_load_response(client, test_survey_data): ...@@ -118,6 +171,6 @@ def test_survey_route_save_load_response(client, test_survey_data):
assert rv.status_code == 200 assert rv.status_code == 200
result = json.loads(rv.data.decode('utf-8')) result = json.loads(rv.data.decode('utf-8'))
jsonschema.validate(result, SURVEY_RESPONSE_SCHEMA) jsonschema.validate(result, SURVEY_RESPONSE_SCHEMA)
assert result['page'] == 0 assert result['page'] == 3
assert result['data'] == {'q1': 'yes', 'q2': ['no']} assert result['data'] == {'q1': 'yes', 'q2': ['no']}
assert result['verification_status'] == {'q1': VerificationStatus.Unverified, 'q2': VerificationStatus.Unverified} assert result['verification_status'] == {'q1': VerificationStatus.Verified}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment