diff --git a/sage_validation/file_validator/forms.py b/sage_validation/file_validator/forms.py index 63123197e33c6bb5bddfe49e101b1ce454eab868..0442892bb4842d37a09e17ef4693ef7b6e9bacbc 100644 --- a/sage_validation/file_validator/forms.py +++ b/sage_validation/file_validator/forms.py @@ -7,6 +7,8 @@ from typing import ClassVar from django import forms from django.core.files.uploadedfile import UploadedFile +from sage_validation.file_validator.models import PlAccountCodes + class CSVUploadForm(forms.Form): """Form for uploading CSV files only.""" @@ -69,8 +71,13 @@ class CSVUploadForm(forms.Form): fieldnames = reader.fieldnames if reader.fieldnames is not None else [] self._validate_headers(fieldnames) + error_list = [] # Step 3: Validate 'Source' and 'SYSTraderTranType' values - self._validate_source_and_trader_type(reader) + data = list(reader) + error_list.extend(self._validate_source_and_trader_type(data)) + error_list.extend(self._validate_nominal_analysis_account(data)) + if error_list: + raise forms.ValidationError(error_list) return file @@ -113,7 +120,7 @@ class CSVUploadForm(forms.Form): raise forms.ValidationError(msg) @staticmethod - def _validate_source_and_trader_type(data: Iterable[dict]) -> None: + def _validate_source_and_trader_type(data: Iterable[dict]) -> list: """Validate that 'Source' is always 80 and 'SYSTraderTranType' is always 4.""" errors = [] @@ -124,5 +131,44 @@ class CSVUploadForm(forms.Form): if row.get("SYSTraderTranType") != "4": errors.append(f"Row {index}: 'SYSTraderTranType' must be 4, but found {row.get('SYSTraderTranType')}.") - if errors: - raise forms.ValidationError(errors) + return errors + + @staticmethod + def _validate_nominal_analysis_account(data: Iterable[dict]) -> list[str]: + """ + Validate that 'AccountNumber' matches the name in 'NominalAnalysisNominalAnalysisNarrative/1'. + + This only checks the first group of NominalAnalysis columns. A list of codes/names + is fetched from the database for validation (from the 'PL Account Codes' table). + + Args: + data (Iterable[dict]): The rows of data to validate. + + Returns: + List[str]: A list of error messages, if any. + """ + errors = [] + + account_code_map = { + obj.pl_account_code: obj.pl_account_name + for obj in PlAccountCodes.objects.using("meo").all() + } + + for index, row in enumerate(data, start=1): + account_code = row.get("AccountNumber") + nominal = row.get("NominalAnalysisNominalAnalysisNarrative/1") + + # Skip rows without 'AccountNumber' or 'NominalAnalysisNominalAnalysisNarrative/1' + if not account_code or not nominal: + continue + + pl_account_name = account_code_map.get(account_code) + if pl_account_name is None: + errors.append(f"Row {index}: 'AccountNumber' {account_code} does not exist in PL Account Codes.") + elif pl_account_name not in nominal: + errors.append( + f"Row {index}: 'AccountNumber' must match '{pl_account_name}' in " + f"'NominalAnalysisNominalAnalysisNarrative/1', but found '{nominal}'." + ) + + return errors \ No newline at end of file