From 17e2f68adc20b254f21ad47d6867562aeacf0892 Mon Sep 17 00:00:00 2001
From: Neda Moeini <neda.moeini@geant.org>
Date: Mon, 27 Jan 2025 09:30:45 +0000
Subject: [PATCH] Clean up the structure to make more room for adding new
 functionality

---
 sage_validation/file_validator/apps.py  |   1 +
 sage_validation/file_validator/forms.py | 118 ++++++++++++++----------
 sage_validation/file_validator/urls.py  |   1 +
 sage_validation/file_validator/views.py |   1 +
 4 files changed, 74 insertions(+), 47 deletions(-)

diff --git a/sage_validation/file_validator/apps.py b/sage_validation/file_validator/apps.py
index bd972d5..0356f13 100644
--- a/sage_validation/file_validator/apps.py
+++ b/sage_validation/file_validator/apps.py
@@ -1,4 +1,5 @@
 """App configuration for file_validator app."""
+
 from django.apps import AppConfig
 
 
diff --git a/sage_validation/file_validator/forms.py b/sage_validation/file_validator/forms.py
index e3a2929..9eac16e 100644
--- a/sage_validation/file_validator/forms.py
+++ b/sage_validation/file_validator/forms.py
@@ -1,9 +1,11 @@
 """Forms for the file_validator app."""
+
 import csv
 from collections.abc import Sequence
 from typing import ClassVar
 
 from django import forms
+from django.core.files.uploadedfile import UploadedFile
 
 
 class CSVUploadForm(forms.Form):
@@ -12,68 +14,65 @@ class CSVUploadForm(forms.Form):
     file = forms.FileField(label="Select a CSV file")
 
     required_columns: ClassVar[list] = [
-        "AccountNumber", "CBAccountNumber", "DaysDiscountValid", "DiscountValue",
-        "DiscountPercentage", "DueDate", "GoodsValueInAccountCurrency",
-        "PurControlValueInBaseCurrency", "DocumentToBaseCurrencyRate",
-        "DocumentToAccountCurrencyRate", "PostedDate", "QueryCode",
-        "TransactionReference", "SecondReference", "Source",
-        "SYSTraderTranType", "TransactionDate", "UniqueReferenceNumber",
-        "UserNumber", "TaxValue", "SYSTraderGenerationReasonType",
-        "GoodsValueInBaseCurrency"
+        "AccountNumber",
+        "CBAccountNumber",
+        "DaysDiscountValid",
+        "DiscountValue",
+        "DiscountPercentage",
+        "DueDate",
+        "GoodsValueInAccountCurrency",
+        "PurControlValueInBaseCurrency",
+        "DocumentToBaseCurrencyRate",
+        "DocumentToAccountCurrencyRate",
+        "PostedDate",
+        "QueryCode",
+        "TransactionReference",
+        "SecondReference",
+        "Source",
+        "SYSTraderTranType",
+        "TransactionDate",
+        "UniqueReferenceNumber",
+        "UserNumber",
+        "TaxValue",
+        "SYSTraderGenerationReasonType",
+        "GoodsValueInBaseCurrency",
     ]
 
     repeating_columns: ClassVar[dict] = {
         "NominalAnalysis": [
-            "NominalAnalysisTransactionValue", "NominalAnalysisNominalAccountNumber",
-            "NominalAnalysisNominalCostCentre", "NominalAnalysisNominalDepartment",
-            "NominalAnalysisNominalAnalysisNarrative", "NominalAnalysisTransactionAnalysisCode"
+            "NominalAnalysisTransactionValue",
+            "NominalAnalysisNominalAccountNumber",
+            "NominalAnalysisNominalCostCentre",
+            "NominalAnalysisNominalDepartment",
+            "NominalAnalysisNominalAnalysisNarrative",
+            "NominalAnalysisTransactionAnalysisCode",
         ],
         "TaxAnalysis": [
-            "TaxAnalysisTaxRate", "TaxAnalysisGoodsValueBeforeDiscount",
-            "TaxAnalysisDiscountValue", "TaxAnalysisDiscountPercentage",
-            "TaxAnalysisTaxOnGoodsValue"
-        ]
+            "TaxAnalysisTaxRate",
+            "TaxAnalysisGoodsValueBeforeDiscount",
+            "TaxAnalysisDiscountValue",
+            "TaxAnalysisDiscountPercentage",
+            "TaxAnalysisTaxOnGoodsValue",
+        ],
     }
 
-    def clean_file(self) -> str:
-        """Validate the uploaded file format and contents."""
+    def clean_file(self) -> UploadedFile:
+        """Validate the uploaded file."""
         file = self.cleaned_data["file"]
 
-        if not file.name.endswith(".csv"):
-            err_msg = "File must be in CSV format"
-            raise forms.ValidationError(err_msg)
-
-        try:
-            csv_file = file.read().decode("utf-8").splitlines()
-            reader = csv.DictReader(csv_file)
-            fieldnames = reader.fieldnames if reader.fieldnames is not None else []
-
-            missing_columns = [col for col in self.required_columns if col not in fieldnames]
-
-            for section_name, column_list in self.repeating_columns.items():
-                max_repeat = self.get_max_repeat(fieldnames, section_name)
+        # Step 1: Validate file type
+        self._validate_file_type(file)
 
-                if max_repeat == 0:
-                    missing_columns.extend([f"{base_col}/1" for base_col in column_list])
-                else:
-                    for repeat in range(1, max_repeat + 1):
-                        missing_columns.extend(
-                            [f"{base_col}/{repeat}" for base_col in column_list if
-                             f"{base_col}/{repeat}" not in fieldnames]
-                        )
-
-            if missing_columns:
-                err_msg = f"Missing required columns: {', '.join(missing_columns)}"
-                raise forms.ValidationError(err_msg)
-
-        except (UnicodeDecodeError, csv.Error) as e:
-            err_msg = f"File could not be processed: {e!s}"
-            raise forms.ValidationError(err_msg) from e
+        # Step 2: Parse file and validate headers
+        csv_file = file.read().decode("utf-8").splitlines()
+        reader = csv.DictReader(csv_file, delimiter=";")
+        fieldnames = reader.fieldnames if reader.fieldnames is not None else []
+        self._validate_headers(fieldnames)
 
         return file
 
     @staticmethod
-    def get_max_repeat(fieldnames: Sequence[str], section_prefix: str) -> int:
+    def _get_max_repeat(fieldnames: Sequence[str], section_prefix: str) -> int:
         """Identify the maximum number of repeats for a section."""
         max_repeat = 0
         for field in fieldnames:
@@ -84,3 +83,28 @@ class CSVUploadForm(forms.Form):
                 except ValueError:
                     continue
         return max_repeat
+
+    @staticmethod
+    def _validate_file_type(file: UploadedFile) -> None:
+        """Validate that the uploaded file is a CSV."""
+        if not file.name.endswith(".csv"):
+            msg = "File must be in CSV format."
+            raise forms.ValidationError(msg)
+
+    def _validate_headers(self, fieldnames: Sequence[str]) -> None:
+        """Validate required and repeating columns in the headers."""
+        missing_columns = [col for col in self.required_columns if col not in fieldnames]
+
+        for section_name, column_list in self.repeating_columns.items():
+            max_repeat = self._get_max_repeat(fieldnames, section_name)
+            if max_repeat == 0:
+                missing_columns.extend([f"{base_col}/1" for base_col in column_list])
+            else:
+                for repeat in range(1, max_repeat + 1):
+                    missing_columns.extend([
+                        f"{base_col}/{repeat}" for base_col in column_list if f"{base_col}/{repeat}" not in fieldnames
+                    ])
+
+        if missing_columns:
+            msg = f"Missing required columns: {', '.join(missing_columns)}"
+            raise forms.ValidationError(msg)
diff --git a/sage_validation/file_validator/urls.py b/sage_validation/file_validator/urls.py
index 78e88de..12e1520 100644
--- a/sage_validation/file_validator/urls.py
+++ b/sage_validation/file_validator/urls.py
@@ -1,4 +1,5 @@
 """Urls for the file_validator app."""
+
 from django.urls import path
 
 from sage_validation.file_validator.views import CSVUploadView
diff --git a/sage_validation/file_validator/views.py b/sage_validation/file_validator/views.py
index 3f75b92..33ec72f 100644
--- a/sage_validation/file_validator/views.py
+++ b/sage_validation/file_validator/views.py
@@ -1,4 +1,5 @@
 """Views for the file_validator app."""
+
 from django.http import HttpRequest, HttpResponse, JsonResponse
 from django.shortcuts import render
 from django.urls import reverse_lazy
-- 
GitLab