Skip to content
Snippets Groups Projects
forms.py 11.56 KiB
"""Forms for the file_validator app."""

import csv
import functools
import io
import re
from collections.abc import Sequence
from typing import ClassVar

from django import forms
from django.core.files.uploadedfile import UploadedFile

from sage_validation.file_validator.models import (
    MeoCostCentres,
    MeoNominal,
    MeoValidSageAccounts,
    MeoValidSuppliers,
    XxData,
)


class CSVUploadForm(forms.Form):
    """Form for uploading and validating a CSV file.

    ``clean_file`` rejects non-``.csv`` uploads, checks the header row for the
    required and repeating columns, and then runs a series of row-level checks
    (several of which query the "meo" database). All row-level problems are
    collected and raised together as a single ``forms.ValidationError``.
    """

    file = forms.FileField(label="Select a CSV file")

    # Columns that must appear, exactly as named, in the header row.
    required_columns: ClassVar[list] = [
        "AccountNumber",
        "CBAccountNumber",
        "DaysDiscountValid",
        "DiscountValue",
        "DiscountPercentage",
        "DueDate",
        "GoodsValueInAccountCurrency",
        "PurControlValueInBaseCurrency",
        "DocumentToBaseCurrencyRate",
        "DocumentToAccountCurrencyRate",
        "PostedDate",
        "QueryCode",
        "TransactionReference",
        "SecondReference",
        "Source",
        "SYSTraderTranType",
        "TransactionDate",
        "UniqueReferenceNumber",
        "UserNumber",
        "TaxValue",
        "SYSTraderGenerationReasonType",
        "GoodsValueInBaseCurrency",
        "ChequeCurrencyName",
        "ChequeToBankExchangeRate",
        "ChequeValueInChequeCurrency",
    ]

    # Column groups that repeat with a "/N" suffix (e.g. "NominalAnalysisTransactionValue/2").
    # Each key is the prefix shared by every column in its group.
    repeating_columns: ClassVar[dict] = {
        "NominalAnalysis": [
            "NominalAnalysisTransactionValue",
            "NominalAnalysisNominalAccountNumber",
            "NominalAnalysisNominalCostCentre",
            "NominalAnalysisNominalDepartment",
            "NominalAnalysisNominalAnalysisNarrative",
            "NominalAnalysisTransactionAnalysisCode",
        ],
        "TaxAnalysis": [
            "TaxAnalysisTaxRate",
            "TaxAnalysisGoodsValueBeforeDiscount",
            "TaxAnalysisDiscountValue",
            "TaxAnalysisDiscountPercentage",
            "TaxAnalysisTaxOnGoodsValue",
        ],
    }

    def clean_file(self) -> UploadedFile:
        """Validate the uploaded file.

        Returns:
            UploadedFile: The uploaded file, rewound to its start so it can be
            read again after validation.

        Raises:
            forms.ValidationError: If the file is not a ``.csv``, cannot be
            decoded as UTF-8 or parsed as CSV, lacks required columns, or any
            row fails a row-level check.

        """
        file = self.cleaned_data["file"]

        # Step 1: Validate file type
        self._validate_file_type(file)

        # Step 2: Parse file and validate headers
        fieldnames, data = self._read_csv(file)
        self._validate_headers(fieldnames)

        # Step 3: Row-level checks. Every check runs so the user sees all
        # problems in one go rather than one per upload attempt.
        error_list: list[str] = []
        error_list.extend(self._validate_source_and_trader_type(data))
        error_list.extend(self._validate_nominal_analysis_account(data))
        error_list.extend(self._validate_nc_cc_dep_combination_against_meo_sage_account(data))
        error_list.extend(self._cheque_fields_must_be_empty(data))
        if error_list:
            raise forms.ValidationError(error_list)

        return file

    @staticmethod
    def _read_csv(file: UploadedFile) -> tuple[list[str], list[dict]]:
        """Decode and parse the uploaded CSV into its header and data rows.

        The file is rewound after reading, so code that reads the returned
        file afterwards does not get an empty read.

        Returns:
            tuple[list[str], list[dict]]: The header column names (empty if the
            file is empty) and one ``csv.DictReader`` dict per data row.

        Raises:
            forms.ValidationError: If the content is not UTF-8 text or cannot
            be parsed as CSV.

        """
        raw = file.read()
        file.seek(0)

        try:
            # "utf-8-sig" strips a leading byte-order mark (as written by e.g.
            # Excel's "CSV UTF-8" export), which would otherwise be glued onto
            # the first header name. BOM-less UTF-8 decodes identically.
            text = raw.decode("utf-8-sig")
        except UnicodeDecodeError as err:
            msg = "File must be UTF-8 encoded."
            raise forms.ValidationError(msg) from err

        # newline="" lets the csv module handle line endings itself, so quoted
        # fields that contain line breaks keep them.
        reader = csv.DictReader(io.StringIO(text, newline=""), delimiter=",")
        try:
            fieldnames = list(reader.fieldnames or [])
            data = list(reader)
        except csv.Error as err:
            msg = f"File could not be parsed as CSV: {err}"
            raise forms.ValidationError(msg) from err
        return fieldnames, data

    @staticmethod
    def _get_max_repeat(fieldnames: Sequence[str], section_prefix: str) -> int:
        """Return the highest "/N" suffix among fields that start with ``section_prefix``.

        A field counts only if the text after its last "/" parses as an integer.
        Non-string entries are skipped: ``csv.DictReader`` stores surplus values
        of an over-long row under the key ``None``. Returns 0 when nothing matches.
        """
        max_repeat = 0
        for field in fieldnames:
            if not isinstance(field, str) or not field.startswith(section_prefix):
                continue
            try:
                repeat_number = int(field.rsplit("/", 1)[-1])
            except ValueError:
                continue
            max_repeat = max(max_repeat, repeat_number)
        return max_repeat

    @staticmethod
    def _validate_file_type(file: UploadedFile) -> None:
        """Validate that the uploaded file has a ``.csv`` extension (case-insensitive).

        Raises:
            forms.ValidationError: If the file name does not end in ``.csv``.

        """
        if not (file.name or "").lower().endswith(".csv"):
            msg = "File must be in CSV format."
            raise forms.ValidationError(msg)

    def _validate_headers(self, fieldnames: Sequence[str]) -> None:
        """Validate required and repeating columns in the headers.

        For each repeating section, every column of every group from "/1" up to
        the highest "/N" present in the header must exist. A section with no
        numbered columns at all is reported as missing its whole "/1" group.

        Raises:
            forms.ValidationError: Listing every missing column.

        """
        present = set(fieldnames)
        missing_columns = [col for col in self.required_columns if col not in present]

        for section_name, column_list in self.repeating_columns.items():
            # Group 1 is always expected, even if the section is absent entirely.
            max_repeat = max(self._get_max_repeat(fieldnames, section_name), 1)
            for repeat in range(1, max_repeat + 1):
                missing_columns.extend(
                    f"{base_col}/{repeat}" for base_col in column_list if f"{base_col}/{repeat}" not in present
                )

        if missing_columns:
            msg = f"Missing required columns: {', '.join(missing_columns)}"
            raise forms.ValidationError(msg)

    def _validate_source_and_trader_type(self, data: list[dict]) -> list[str]:
        """Validate that 'Source' is always 80 and 'SYSTraderTranType' is always 4.

        Args:
            data (list[dict]): The rows of data to validate.

        Returns:
            list[str]: One error message per failing field, per row.

        """
        errors: list[str] = []

        for index, row in enumerate(data, start=1):
            source = row.get("Source")
            trader_type = row.get("SYSTraderTranType")
            if source == "80" and trader_type == "4":
                continue

            # The claimant lookup queries the database, so only do it for rows being reported.
            claimant_name = self.get_account_name_from_code(row.get("AccountNumber"))
            claim_number = row.get("SecondReference")
            prefix = f"Row {index}, claimant: {claimant_name} with claim number: {claim_number}: "
            if source != "80":
                errors.append(f"{prefix}'Source' must be 80, but found {source}.")
            if trader_type != "4":
                errors.append(f"{prefix}'SYSTraderTranType' must be 4, but found {trader_type}.")

        return errors

    @staticmethod
    def _validate_nominal_analysis_account(data: list[dict]) -> list[str]:
        """Validate that the supplier name for 'AccountNumber' appears in 'NominalAnalysisNominalAnalysisNarrative/1'.

        This only checks the first group of NominalAnalysis columns. Supplier
        codes/names are fetched once from ``MeoValidSuppliers`` in the "meo"
        database (referred to as 'PL Account Codes' in the error messages).
        The check is a case-sensitive substring test on the narrative, after
        removing the word "Soldo" and hyphens from the supplier name.

        Args:
            data (list[dict]): The rows of data to validate.

        Returns:
            list[str]: A list of error messages, if any.

        """
        errors: list[str] = []

        account_code_map = {
            obj.supplier_account_number: obj.supplier_account_name
            for obj in MeoValidSuppliers.objects.using("meo").all()  # type: ignore[attr-defined]
        }

        for index, row in enumerate(data, start=1):
            account_code = row.get("AccountNumber")
            nominal = row.get("NominalAnalysisNominalAnalysisNarrative/1")

            # Skip rows without 'AccountNumber' or 'NominalAnalysisNominalAnalysisNarrative/1'
            if not account_code or not nominal:
                continue

            pl_account_name = account_code_map.get(account_code)

            if pl_account_name is None:
                errors.append(f"Row {index}: 'AccountNumber' {account_code} does not exist in PL Account Codes.")
                continue

            # Remove 'Soldo' and any hyphens (with surrounding spaces) from the PL account
            # name. This is for credit card accounts.
            revised_pl_account_name = re.sub(
                r"\bSoldo\b|\s*-\s*", "", pl_account_name, flags=re.IGNORECASE).strip()
            if revised_pl_account_name not in nominal:
                errors.append(
                    f"Row {index}: 'AccountNumber' must match '{revised_pl_account_name}' in "
                    f"'NominalAnalysisNominalAnalysisNarrative/1', but found '{nominal}'."
                )

        return errors

    @staticmethod
    def get_account_name_from_code(account_code: str | None) -> str | None:
        """Get the account name from the PL Account Codes table.

        Looks ``account_code`` up in ``MeoValidSuppliers`` in the "meo" database.

        Args:
            account_code (str | None): The supplier account number to look up.

        Returns:
            str | None: The matching ``supplier_account_name``. Returns ``None``
            when ``account_code`` is ``None`` or no supplier matches. If several
            suppliers share the code, the one returned by ``QuerySet.first()``
            is used, so this never raises ``MultipleObjectsReturned``.

        """
        if account_code is None:
            return None
        supplier = (
            MeoValidSuppliers.objects.using("meo")
            .filter(supplier_account_number=account_code)
            .first()
        )
        return supplier.supplier_account_name if supplier is not None else None

    def _validate_nc_cc_dep_combination_against_meo_sage_account(self, data: list[dict]) -> list[str]:
        """Validate that all nominal analysis fields exist in MEO.

        This includes 'NominalAnalysisNominalCostCentre/{N}', 'NominalAnalysisNominalDepartment/{N}',
        and 'NominalAnalysisNominalAccountNumber/{N}'. Groups where any of the
        three values is empty are skipped.

        For each group:
        - the cost centre must be in ``MeoCostCentres``;
        - the nominal account is mapped through ``XxData`` (project or overhead
          code, chosen by the cost centre's type), or otherwise must exist in
          ``MeoNominal``;
        - the resulting (cost centre, department, nominal) combination must
          exist in ``MeoValidSageAccounts``.

        Args:
            data (list[dict]): The rows of data to validate.

        Returns:
            list[str]: A list of error messages, if any.

        """
        # A header-only upload has no rows to check and no first row to read keys from.
        if not data:
            return []

        errors: list[str] = []

        cost_centre_map = {
            obj.cc: obj.cc_type for obj in MeoCostCentres.objects.using("meo").all()
        }

        xx_data_map = {
            obj.xx_value: (obj.project, obj.overhead) for obj in XxData.objects.using("meo").all()
        }

        # Memoised per call: rows commonly repeat the same nominal codes,
        # combinations and claimants, so each distinct value hits the database once.
        @functools.cache
        def nominal_exists(nominal: str) -> bool:
            return MeoNominal.objects.using("meo").filter(nom=nominal).exists()

        @functools.cache
        def sage_account_exists(cost_centre: str, department: str, nominal: str) -> bool:
            return MeoValidSageAccounts.objects.using("meo").filter(
                account_cost_centre=cost_centre, account_department=department, account_number=nominal
            ).exists()

        claimant_name_for = functools.cache(self.get_account_name_from_code)

        # Rows are assumed to share the header's keys (as csv.DictReader rows do).
        fieldnames = list(data[0])
        max_repeat = self._get_max_repeat(fieldnames, "NominalAnalysisNominalCostCentre")

        for index, row in enumerate(data, start=1):
            for repeat in range(1, max_repeat + 1):
                cc_field = f"NominalAnalysisNominalCostCentre/{repeat}"
                dep_field = f"NominalAnalysisNominalDepartment/{repeat}"
                nominal_account_field = f"NominalAnalysisNominalAccountNumber/{repeat}"

                cc = row.get(cc_field)
                dep = row.get(dep_field)
                nominal_account_name = row.get(nominal_account_field)

                if not cc or not dep or not nominal_account_name:
                    continue

                cc_type = cost_centre_map.get(cc)
                if not cc_type:
                    errors.append(f"Row {index}: '{cc_field}' ({cc}) is not a valid cost centre.")
                    continue

                xx_data = xx_data_map.get(nominal_account_name)

                if xx_data:
                    # XxData maps a nominal to (project code, overhead code).
                    nc = xx_data[0] if cc_type == "Project" else xx_data[1]
                elif nominal_exists(nominal_account_name):
                    nc = nominal_account_name
                else:
                    errors.append(f"Row {index}: '{nominal_account_field}' ({nominal_account_name}) is not valid.")
                    continue

                if not sage_account_exists(cc, dep, nc):
                    claimant_name = claimant_name_for(row.get("AccountNumber"))
                    claim_number = row.get("SecondReference")
                    errors.append(
                        f"Row {index}: The combination of '{cc_field}' ({cc}), "
                        f"'{dep_field}' ({dep}), and '{nominal_account_field}' "
                        f"({nc}) for claimant '{claimant_name}' and claim number '{claim_number}' "
                        f"does not exist in MEO valid Sage accounts."
                    )

        return errors

    def _cheque_fields_must_be_empty(self, data: list[dict]) -> list[str]:
        """Validate that cheque fields are empty.

        The cheque fields are 'ChequeCurrencyName', 'ChequeToBankExchangeRate', and 'ChequeValueInChequeCurrency'.
        Any non-empty value (including whitespace) in one of them produces one error for that row.

        Args:
            data (list[dict]): The rows of data to validate.

        Returns:
            list[str]: A list of error messages, if any.

        """
        cheque_columns = ("ChequeCurrencyName", "ChequeToBankExchangeRate", "ChequeValueInChequeCurrency")
        errors: list[str] = []
        for index, row in enumerate(data, start=1):
            if not any(row.get(column) for column in cheque_columns):
                continue
            # The claimant lookup queries the database, so only do it for rows being reported.
            claimant_name = self.get_account_name_from_code(row.get("AccountNumber"))
            claim_number = row.get("SecondReference")
            errors.append(
                f"Row {index}: Unexpected values in the Cheque columns for {claimant_name} with claim number: "
                f"{claim_number}. All cheque columns must be empty."
            )

        return errors