From b670047441da80360649c53ed401f6fb173a951c Mon Sep 17 00:00:00 2001 From: Neda Moeini <neda.moeini@geant.org> Date: Mon, 27 Jan 2025 09:56:58 +0000 Subject: [PATCH] Validate Source and TraderType values. --- sage_validation/file_validator/forms.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/sage_validation/file_validator/forms.py b/sage_validation/file_validator/forms.py index 9eac16e..6312319 100644 --- a/sage_validation/file_validator/forms.py +++ b/sage_validation/file_validator/forms.py @@ -1,7 +1,7 @@ """Forms for the file_validator app.""" import csv -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from typing import ClassVar from django import forms @@ -69,6 +69,9 @@ class CSVUploadForm(forms.Form): fieldnames = reader.fieldnames if reader.fieldnames is not None else [] self._validate_headers(fieldnames) + # Step 3: Validate 'Source' and 'SYSTraderTranType' values + self._validate_source_and_trader_type(reader) + return file @staticmethod @@ -108,3 +111,18 @@ class CSVUploadForm(forms.Form): if missing_columns: msg = f"Missing required columns: {', '.join(missing_columns)}" raise forms.ValidationError(msg) + + @staticmethod + def _validate_source_and_trader_type(data: Iterable[dict]) -> None: + """Validate that 'Source' is always 80 and 'SYSTraderTranType' is always 4.""" + errors = [] + + for index, row in enumerate(data, start=1): + if row.get("Source") != "80": + errors.append(f"Row {index}: 'Source' must be 80, but found {row.get('Source')}.") + + if row.get("SYSTraderTranType") != "4": + errors.append(f"Row {index}: 'SYSTraderTranType' must be 4, but found {row.get('SYSTraderTranType')}.") + + if errors: + raise forms.ValidationError(errors) -- GitLab