Skip to content
Snippets Groups Projects
Commit edd74561 authored by Neda Moeini's avatar Neda Moeini
Browse files

Merge branch 'feature/FINSUP-37' into 'develop'

Refactor CSV file validation and error handling in forms.py and views.py in...

See merge request !12
parents a1bec936 0263d365
Branches
Tags
1 merge request!12Refactor CSV file validation and error handling in forms.py and views.py in...
"""Forms for the file_validator app.""" """Forms for the file_validator app."""
import csv import csv
import io
import re import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import ClassVar from typing import ClassVar, Self
from django import forms from django import forms
from django.core.files.uploadedfile import UploadedFile from django.core.files.uploadedfile import UploadedFile
...@@ -71,29 +72,45 @@ class CSVUploadForm(forms.Form): ...@@ -71,29 +72,45 @@ class CSVUploadForm(forms.Form):
def clean_file(self) -> UploadedFile: def clean_file(self) -> UploadedFile:
"""Validate the uploaded file.""" """Validate the uploaded file."""
file = self.cleaned_data["file"] file = self.cleaned_data["file"]
# Step 1: Validate file type
self._validate_file_type(file) self._validate_file_type(file)
# Step 2: Parse file and validate headers text_stream = io.TextIOWrapper(file, encoding="utf-8-sig")
raw_data = file.read().decode("utf-8-sig") csv_content = text_stream.read().strip()
normalized_data = raw_data.replace("\r\n", "\n").replace("\r", "\n")
csv_file = normalized_data.splitlines() if not csv_content:
reader = csv.DictReader(csv_file, delimiter=",") error_message = "CSV upload failed."
fieldnames = reader.fieldnames if reader.fieldnames is not None else [] raise forms.ValidationError(error_message)
reader = csv.DictReader(io.StringIO(csv_content))
fieldnames = reader.fieldnames or []
self._validate_headers(fieldnames) self._validate_headers(fieldnames)
self._load_reference_data()
error_list = [] errors = []
data = list(reader) for index, row in enumerate(reader, start=1):
error_list.extend(self._validate_source_and_trader_type(data)) errors.extend(self._validate_source_and_trader_type(row, index))
error_list.extend(self._validate_nominal_analysis_account(data)) errors.extend(self._validate_nominal_analysis_account(row, index))
error_list.extend(self._validate_nc_cc_dep_combination_against_meo_sage_account(data)) errors.extend(self._validate_nc_cc_dep_combination_against_meo_sage_account(row, index))
error_list.extend(self._cheque_fields_must_be_empty(data)) errors.extend(self._cheque_fields_must_be_empty(row, index))
if error_list:
raise forms.ValidationError(error_list)
if errors:
raise forms.ValidationError(errors)
self.cleaned_data["csv_data"] = csv_content
return file return file
def _load_reference_data(self: Self) -> None:
self.supplier_map = {
s.supplier_account_number: s.supplier_account_name
for s in MeoValidSuppliers.objects.using("meo").all()
}
self.cost_centre_map = {
cc.cc: cc.cc_type for cc in MeoCostCentres.objects.using("meo").all()
}
self.xx_data_map = {
x.xx_value: (x.project, x.overhead) for x in XxData.objects.using("meo").all()
}
@staticmethod @staticmethod
def _get_max_repeat(fieldnames: Sequence[str], section_prefix: str) -> int: def _get_max_repeat(fieldnames: Sequence[str], section_prefix: str) -> int:
"""Identify the maximum number of repeats for a section.""" """Identify the maximum number of repeats for a section."""
...@@ -132,66 +149,54 @@ class CSVUploadForm(forms.Form): ...@@ -132,66 +149,54 @@ class CSVUploadForm(forms.Form):
msg = f"Missing required columns: {', '.join(missing_columns)}" msg = f"Missing required columns: {', '.join(missing_columns)}"
raise forms.ValidationError(msg) raise forms.ValidationError(msg)
def _validate_source_and_trader_type(self, data: list[dict]) -> list: def _validate_source_and_trader_type(self, row: dict, index: int) -> list:
"""Validate that 'Source' is always 80 and 'SYSTraderTranType' is always 4.""" """Validate that 'Source' is always 80 and 'SYSTraderTranType' is always 4."""
errors = [] errors = []
claimant_name = self.supplier_map.get(row.get("AccountNumber"))
claim_number = row.get("SecondReference")
for index, row in enumerate(data, start=1): if row.get("Source") != "80":
claimant_name = self.get_account_name_from_code(row.get("AccountNumber")) errors.append(
claim_number = row.get("SecondReference") f"Row {index}, claimant: {claimant_name} with claim number: {claim_number}: "
if row.get("Source") != "80": f"'Source' must be 80, but found {row.get('Source')}.")
errors.append(f"Row {index}, claimant: {claimant_name} with claim number: {claim_number}: "
f"'Source' must be 80, but found {row.get('Source')}.")
if row.get("SYSTraderTranType") != "4": if row.get("SYSTraderTranType") != "4":
errors.append(f"Row {index}, claimant: {claimant_name} with claim number: {claim_number}: " errors.append(
f"'SYSTraderTranType' must be 4, but found {row.get('SYSTraderTranType')}.") f"Row {index}, claimant: {claimant_name} with claim number: {claim_number}: "
f"'SYSTraderTranType' must be 4, but found {row.get('SYSTraderTranType')}.")
return errors return errors
@staticmethod def _validate_nominal_analysis_account(self, row: dict, index: int) -> list:
def _validate_nominal_analysis_account(data: list[dict]) -> list[str]:
"""Validate that 'AccountNumber' matches the name in 'NominalAnalysisNominalAnalysisNarrative/1'. """Validate that 'AccountNumber' matches the name in 'NominalAnalysisNominalAnalysisNarrative/1'.
This only checks the first group of NominalAnalysis columns. A list of codes/names This only checks the first group of NominalAnalysis columns. A list of codes/names
is fetched from the database for validation (from the 'PL Account Codes' table). is fetched from the database for validation (from the 'PL Account Codes' table).
Args: Args:
data (list[dict]): The rows of data to validate. row (dict): The row of data to validate.
index (int): The index of the row in the CSV file.
Returns: Returns:
List[str]: A list of error messages, if any. List[str]: A list of error messages, if any.
""" """
errors = [] errors: list[str] = []
account_code = row.get("AccountNumber")
account_code_map = { nominal = row.get("NominalAnalysisNominalAnalysisNarrative/1")
obj.supplier_account_number: obj.supplier_account_name if not account_code or not nominal:
for obj in MeoValidSuppliers.objects.using("meo").all() # type: ignore[attr-defined] return errors
}
pl_account_name = self.supplier_map.get(account_code)
for index, row in enumerate(data, start=1): if pl_account_name is None:
account_code = row.get("AccountNumber") errors.append(f"Row {index}: 'AccountNumber' {account_code} does not exist in PL Account Codes.")
nominal = row.get("NominalAnalysisNominalAnalysisNarrative/1") else:
revised_name = re.sub(r"\bSoldo\b|\s*-\s*", "", pl_account_name, flags=re.IGNORECASE).strip()
# Skip rows without 'AccountNumber' or 'NominalAnalysisNominalAnalysisNarrative/1' if revised_name not in nominal:
if not account_code or not nominal: errors.append(
continue f"Row {index}: 'AccountNumber' must match '{revised_name}' in "
pl_account_name = account_code_map.get(account_code)
if pl_account_name is None:
errors.append(f"Row {index}: 'AccountNumber' {account_code} does not exist in PL Account Codes.")
else:
# Remove 'Soldo' and any hyphens from the PL account name. This is for credit card accounts.
revised_pl_account_name = re.sub(
r"\bSoldo\b|\s*-\s*", "", pl_account_name, flags=re.IGNORECASE).strip()
if revised_pl_account_name not in nominal:
errors.append(
f"Row {index}: 'AccountNumber' must match '{revised_pl_account_name}' in "
f"'NominalAnalysisNominalAnalysisNarrative/1', but found '{nominal}'." f"'NominalAnalysisNominalAnalysisNarrative/1', but found '{nominal}'."
) )
return errors return errors
@staticmethod @staticmethod
...@@ -205,95 +210,82 @@ class CSVUploadForm(forms.Form): ...@@ -205,95 +210,82 @@ class CSVUploadForm(forms.Form):
except MeoValidSuppliers.DoesNotExist: except MeoValidSuppliers.DoesNotExist:
return None return None
def _validate_nc_cc_dep_combination_against_meo_sage_account(self, data: list[dict]) -> list[str]: def _validate_nc_cc_dep_combination_against_meo_sage_account(self, row: dict, index: int) -> list:
"""Validate that all nominal analysis fields exist in MEO. """Validate that all nominal analysis fields exist in MEO.
This includes 'NominalAnalysisNominalCostCentre/{N}', 'NominalAnalysisNominalDepartment/{N}', This includes 'NominalAnalysisNominalCostCentre/{N}', 'NominalAnalysisNominalDepartment/{N}',
and 'NominalAnalysisNominalAccountNumber/{N}'. and 'NominalAnalysisNominalAccountNumber/{N}'.
Args: Args:
data (list[dict]): The rows of data to validate. row (dict): The row of data to validate.
index (int): The index of the row in the CSV file.
Returns: Returns:
List[str]: A list of error messages, if any. List[str]: A list of error messages, if any.
""" """
errors = [] errors = []
fieldnames = list(row.keys())
cost_centre_map = {
obj.cc: obj.cc_type for obj in MeoCostCentres.objects.using("meo").all()
}
xx_data_map = {
obj.xx_value: (obj.project, obj.overhead) for obj in XxData.objects.using("meo").all()
}
fieldnames = list(data[0].keys())
max_repeat = self._get_max_repeat(fieldnames, "NominalAnalysisNominalCostCentre") max_repeat = self._get_max_repeat(fieldnames, "NominalAnalysisNominalCostCentre")
claimant_name = self.get_account_name_from_code(row.get("AccountNumber"))
claim_number = row.get("SecondReference")
for repeat in range(1, max_repeat + 1):
cc_field = f"NominalAnalysisNominalCostCentre/{repeat}"
dep_field = f"NominalAnalysisNominalDepartment/{repeat}"
nom_field = f"NominalAnalysisNominalAccountNumber/{repeat}"
cc = row.get(cc_field)
dep = row.get(dep_field)
nom= row.get(nom_field)
if not cc and not dep and not nom:
continue
if not cc or not dep or not nom:
errors.append(
f"Row {index}: Missing values in '{cc_field}', '{dep_field}', or '{nom_field}'.")
continue
for index, row in enumerate(data, start=1): cc_type = self.cost_centre_map.get(cc)
claimant_name = self.get_account_name_from_code(row.get("AccountNumber")) if not cc_type:
claim_number = row.get("SecondReference") errors.append(f"Row {index}: '{cc_field}' ({cc}) is not a valid cost centre.")
for repeat in range(1, max_repeat + 1): continue
cc_field = f"NominalAnalysisNominalCostCentre/{repeat}"
dep_field = f"NominalAnalysisNominalDepartment/{repeat}"
nominal_account_field = f"NominalAnalysisNominalAccountNumber/{repeat}"
cc = row.get(cc_field)
dep = row.get(dep_field)
nominal_account_name = row.get(nominal_account_field)
if not cc and not dep and not nominal_account_name:
continue
if not cc or not dep or not nominal_account_name:
errors.append(
f"Row {index}: Missing values in '{cc_field}', '{dep_field}', or '{nominal_account_field}'.")
continue
cc_type = cost_centre_map.get(cc)
if not cc_type:
errors.append(f"Row {index}: '{cc_field}' ({cc}) is not a valid cost centre.")
continue
xx_data = xx_data_map.get(nominal_account_name)
if xx_data: xx_data = self.xx_data_map.get(nom)
nc = xx_data[0] if cc_type == "Project" else xx_data[1] if xx_data:
elif MeoNominal.objects.using("meo").filter(nom=nominal_account_name).exists(): nc = xx_data[0] if cc_type == "Project" else xx_data[1]
nc = nominal_account_name elif MeoNominal.objects.using("meo").filter(nom=nom).exists():
else: nc = nom
errors.append(f"Row {index}: '{nominal_account_field}' ({nominal_account_name}) is not valid.") else:
continue errors.append(f"Row {index}: '{nom_field}' ({nom}) is not valid.")
continue
if not MeoValidSageAccounts.objects.using("meo").filter( if not MeoValidSageAccounts.objects.using("meo").filter(
account_cost_centre=cc, account_department=dep, account_number=nc account_cost_centre=cc, account_department=dep, account_number=nc
).exists(): ).exists():
errors.append( errors.append(
f"Row {index}: The combination of '{cc_field}' ({cc}), " f"Row {index}: The combination of '{cc_field}' ({cc}), "
f"'{dep_field}' ({dep}), and '{nominal_account_field}' " f"'{dep_field}' ({dep}), and '{nom_field}' "
f"({nc}) for claimant '{claimant_name}' and claim number '{claim_number}' " f"({nc}) for claimant '{claimant_name}' and claim number '{claim_number}' "
f"does not exist in MEO valid Sage accounts." f"does not exist in MEO valid Sage accounts."
) )
return errors return errors
def _cheque_fields_must_be_empty(self, data: list[dict]) -> list[str]: def _cheque_fields_must_be_empty(self, row: dict, index: int) -> list:
"""Validate that cheque fields are empty. """Validate that cheque fields are empty.
The cheque fields are 'ChequeCurrencyName', 'ChequeToBankExchangeRate', and 'ChequeValueInChequeCurrency'. The cheque fields are 'ChequeCurrencyName', 'ChequeToBankExchangeRate', and 'ChequeValueInChequeCurrency'.
""" """
errors = [] errors = []
for index, row in enumerate(data, start=1): if any([
cheque_currency_name = row.get("ChequeCurrencyName") row.get("ChequeCurrencyName"),
cheque_to_bank_exchange_rate = row.get("ChequeToBankExchangeRate") row.get("ChequeToBankExchangeRate"),
cheque_value_in_cheque_currency = row.get("ChequeValueInChequeCurrency") row.get("ChequeValueInChequeCurrency")
]):
claimant_name = self.get_account_name_from_code(row.get("AccountNumber")) claimant_name = self.get_account_name_from_code(row.get("AccountNumber"))
claim_number = row.get("SecondReference") claim_number = row.get("SecondReference")
if any([cheque_currency_name, cheque_to_bank_exchange_rate, cheque_value_in_cheque_currency]): errors.append(
errors.append( f"Row {index}: Unexpected values in the Cheque columns for {claimant_name} with claim number: "
f"Row {index}: Unexpected values in the Cheque columns for {claimant_name} with claim number: " f"{claim_number}. All cheque columns must be empty."
f"{claim_number}. All cheque columns must be empty." )
)
return errors return errors
...@@ -51,20 +51,14 @@ class CSVUploadAPIView(APIView): ...@@ -51,20 +51,14 @@ class CSVUploadAPIView(APIView):
if not form.is_valid(): if not form.is_valid():
return Response({"status": "error", "errors": form.errors}, status=status.HTTP_400_BAD_REQUEST) return Response({"status": "error", "errors": form.errors}, status=status.HTTP_400_BAD_REQUEST)
csv_file = form.cleaned_data["file"] decoded_file = form.cleaned_data["csv_data"]
csv_file.seek(0)
decoded_file = csv_file.read().decode("utf-8-sig").strip()
if not decoded_file:
return Response({"status": "error", "message": "Uploaded file is empty."},
status=status.HTTP_400_BAD_REQUEST)
reader = csv.DictReader(io.StringIO(decoded_file)) reader = csv.DictReader(io.StringIO(decoded_file))
csv_data: list[dict[str, str]] = list(reader) csv_data: list[dict[str, str]] = list(reader)
updated_data = self.update_fields(csv_data) updated_data = self.update_fields(csv_data)
request.session["validated_csv"] = updated_data request.session["validated_csv"] = updated_data
request.session["input_file_hash"] = UserActivityLog.generate_file_hash(csv_file) request.session["input_file_hash"] = UserActivityLog.generate_file_hash(form.cleaned_data["csv_data"])
request.session.modified = True request.session.modified = True
return Response({ return Response({
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment