diff --git a/scripts/freq/freqProcessor.py b/scripts/freq/freqProcessor.py index 6304edcb200c4b8978c4c065986af900818436ce..13599ad67ecd91995babe27cb2c32edb997d2b16 100755 --- a/scripts/freq/freqProcessor.py +++ b/scripts/freq/freqProcessor.py @@ -5,6 +5,7 @@ import traceback from org.apache.nifi.processor import AbstractProcessor from org.apache.nifi.processor import Relationship from org.apache.nifi.components import PropertyDescriptor +from org.apache.nifi.expression import ExpressionLanguageScope from org.apache.nifi.processor.util import StandardValidators from org.apache.nifi.serialization import RecordReaderFactory, RecordSetWriterFactory, SimpleRecordSchema from org.apache.nifi.serialization.record import RecordField, RecordFieldType @@ -23,12 +24,12 @@ class FreqProcessor(AbstractProcessor): _record_reader = PropertyDescriptor.Builder().name("record-reader").displayName("Record Reader").description("Specifies the Controller Service to use for reading incoming data").identifiesControllerService(RecordReaderFactory).required(True).build() _record_writer = PropertyDescriptor.Builder().name("record-writer").displayName("Record Writer").description("Specifies the Controller Service to use for writing out the records").identifiesControllerService(RecordSetWriterFactory).required(True).build() # Record field to get the string to be analyzed from - _input_field = PropertyDescriptor.Builder().name("Input Field").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).required(True).build() + _input_field = PropertyDescriptor.Builder().name("Input Field").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY).required(True).build() # Record fields to store the results (prob1, prob2) into - _result_field1 = PropertyDescriptor.Builder().name("Result Field 1").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).required(True).build() - _result_field2 = PropertyDescriptor.Builder().name("Result Field 2").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).required(True).build() + _result_field1 = PropertyDescriptor.Builder().name("Result Field 1").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY).required(True).build() + _result_field2 = PropertyDescriptor.Builder().name("Result Field 2").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY).required(True).build() # File with character frequency table (as created by freq.py) - _freq_file = PropertyDescriptor.Builder().name("Frequency File").addValidator(StandardValidators.FILE_EXISTS_VALIDATOR).required(True).build() + _freq_file = PropertyDescriptor.Builder().name("Frequency File").addValidator(StandardValidators.FILE_EXISTS_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY).required(True).build() def __init__(self): self.fc = FreqCounter() @@ -59,14 +60,15 @@ class FreqProcessor(AbstractProcessor): if not self.fc_loaded: # Load character frequency table from given file # (Note: this cannot be done in initialize() since the context there doesn't contain the getProperty() method) - self.fc.load(str(context.getProperty("Frequency File").getValue())) + self.fc.load(str(context.getProperty("Frequency File").evaluateAttributeExpressions().getValue())) self.fc_loaded=True self.logger.info("Sucessfully loaded frequency file") # Get field names to work with (set in configuration) - self.input_field_name = context.getProperty("Input Field").getValue()[1:] - self.result1_field_name = context.getProperty("Result Field 1").getValue()[1:] - self.result2_field_name = context.getProperty("Result Field 2").getValue()[1:] + # (remove leading '/' if present to simulate trivial support of RecordPath) + self.input_field_name = context.getProperty("Input Field").evaluateAttributeExpressions().getValue().lstrip('/') + self.result1_field_name = context.getProperty("Result Field 1").evaluateAttributeExpressions().getValue().lstrip('/') + self.result2_field_name = context.getProperty("Result Field 2").evaluateAttributeExpressions().getValue().lstrip('/') try: flowfile = session.get()