Skip to content
Snippets Groups Projects
Commit 913e1634 authored by Václav Bartoš's avatar Václav Bartoš
Browse files

FreqProcessor: fixed ExpressionLanguage support, error handling

parent 8c643029
No related branches found
No related tags found
No related merge requests found
...@@ -25,10 +25,10 @@ class FreqProcessor(AbstractProcessor): ...@@ -25,10 +25,10 @@ class FreqProcessor(AbstractProcessor):
_record_reader = PropertyDescriptor.Builder().name("record-reader").displayName("Record Reader").description("Specifies the Controller Service to use for reading incoming data").identifiesControllerService(RecordReaderFactory).required(True).build() _record_reader = PropertyDescriptor.Builder().name("record-reader").displayName("Record Reader").description("Specifies the Controller Service to use for reading incoming data").identifiesControllerService(RecordReaderFactory).required(True).build()
_record_writer = PropertyDescriptor.Builder().name("record-writer").displayName("Record Writer").description("Specifies the Controller Service to use for writing out the records").identifiesControllerService(RecordSetWriterFactory).required(True).build() _record_writer = PropertyDescriptor.Builder().name("record-writer").displayName("Record Writer").description("Specifies the Controller Service to use for writing out the records").identifiesControllerService(RecordSetWriterFactory).required(True).build()
# Record field to get the string to be analyzed from # Record field to get the string to be analyzed from
_input_field = PropertyDescriptor.Builder().name("Input Field").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY).required(True).build() _input_field = PropertyDescriptor.Builder().name("Input Field").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES).required(True).build()
# Record fields to store the results (prob1, prob2) into # Record fields to store the results (prob1, prob2) into
_result_field1 = PropertyDescriptor.Builder().name("Result Field 1").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY).required(True).build() _result_field1 = PropertyDescriptor.Builder().name("Result Field 1").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES).required(True).build()
_result_field2 = PropertyDescriptor.Builder().name("Result Field 2").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY).required(True).build() _result_field2 = PropertyDescriptor.Builder().name("Result Field 2").addValidator(StandardValidators.NON_EMPTY_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES).required(True).build()
# File with character frequency table (as created by freq.py) # File with character frequency table (as created by freq.py)
_freq_file = PropertyDescriptor.Builder().name("Frequency File").addValidator(StandardValidators.FILE_EXISTS_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY).required(True).build() _freq_file = PropertyDescriptor.Builder().name("Frequency File").addValidator(StandardValidators.FILE_EXISTS_VALIDATOR).expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY).required(True).build()
...@@ -68,12 +68,6 @@ class FreqProcessor(AbstractProcessor): ...@@ -68,12 +68,6 @@ class FreqProcessor(AbstractProcessor):
self.fc_loaded=True self.fc_loaded=True
self.logger.debug("Sucessfully loaded frequency file") self.logger.debug("Sucessfully loaded frequency file")
# Get field names to work with (set in configuration)
# (remove leading '/' if present to simulate trivial support of RecordPath)
self.input_field_name = context.getProperty("Input Field").evaluateAttributeExpressions().getValue().lstrip('/')
self.result1_field_name = context.getProperty("Result Field 1").evaluateAttributeExpressions().getValue().lstrip('/')
self.result2_field_name = context.getProperty("Result Field 2").evaluateAttributeExpressions().getValue().lstrip('/')
try: try:
flowfile = session.get() flowfile = session.get()
if flowfile is None : if flowfile is None :
...@@ -81,6 +75,13 @@ class FreqProcessor(AbstractProcessor): ...@@ -81,6 +75,13 @@ class FreqProcessor(AbstractProcessor):
self.logger.debug("Processing FlowFile {}".format(flowfile.getAttribute('uuid'))) self.logger.debug("Processing FlowFile {}".format(flowfile.getAttribute('uuid')))
# Get field names to work with (set in configuration)
# (remove leading '/' if present to simulate trivial support of RecordPath)
# TODO: full support for RecordPath?
self.input_field_name = context.getProperty("Input Field").evaluateAttributeExpressions(flowfile).getValue().lstrip('/')
self.result1_field_name = context.getProperty("Result Field 1").evaluateAttributeExpressions(flowfile).getValue().lstrip('/')
self.result2_field_name = context.getProperty("Result Field 2").evaluateAttributeExpressions(flowfile).getValue().lstrip('/')
readerFactory = context.getProperty(self._record_reader).asControllerService(RecordReaderFactory) readerFactory = context.getProperty(self._record_reader).asControllerService(RecordReaderFactory)
writerFactory = context.getProperty(self._record_writer).asControllerService(RecordSetWriterFactory) writerFactory = context.getProperty(self._record_writer).asControllerService(RecordSetWriterFactory)
originalAttributes = flowfile.attributes originalAttributes = flowfile.attributes
...@@ -137,14 +138,16 @@ class FreqProcessor(AbstractProcessor): ...@@ -137,14 +138,16 @@ class FreqProcessor(AbstractProcessor):
# There are some records to process ... # There are some records to process ...
# Add new fields to the schema # Add new fields to the schema
#self.logger.debug("origAttributes: " + str(originalAttributes)) #self.logger.debug("origAttributes: " + str(originalAttributes))
# ref: https://github.com/apache/nifi/blob/main/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/RecordSchema.java
oldSchema = writerFactory.getSchema(originalAttributes, record.schema) oldSchema = writerFactory.getSchema(originalAttributes, record.schema)
fields = oldSchema.getFields() fields = oldSchema.getFields() # type should be List<RecordField>, but is UnmodifiableRandomAccessList.
fields = list(fields) # convert to Python list, so we can add items to it
field_names = [f.getFieldName() for f in fields] field_names = [f.getFieldName() for f in fields]
# ref: https://github.com/apache/nifi/blob/master/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/RecordField.java # ref: https://github.com/apache/nifi/blob/master/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/RecordField.java
if self.result1_field_name not in field_names: if self.result1_field_name not in field_names:
fields.append(RecordField(self.result1_field_name, RecordFieldType.FLOAT.getDataType(), False)) fields.append(RecordField(self.result1_field_name, RecordFieldType.FLOAT.getDataType()))
if self.result2_field_name not in field_names: if self.result2_field_name not in field_names:
fields.append(RecordField(self.result2_field_name, RecordFieldType.FLOAT.getDataType(), False)) fields.append(RecordField(self.result2_field_name, RecordFieldType.FLOAT.getDataType()))
newSchema = SimpleRecordSchema(fields, oldSchema.getIdentifier()) newSchema = SimpleRecordSchema(fields, oldSchema.getIdentifier())
# Create writer # Create writer
...@@ -170,8 +173,12 @@ class FreqProcessor(AbstractProcessor): ...@@ -170,8 +173,12 @@ class FreqProcessor(AbstractProcessor):
attributes['record.count'] = str(writeResult.recordCount) attributes['record.count'] = str(writeResult.recordCount)
attributes.update(writeResult.attributes) attributes.update(writeResult.attributes)
recordCount = writeResult.recordCount recordCount = writeResult.recordCount
except:
raise # This shouldn't be needed, but it is, otherwise the exception isn't propagated up (and results in uncaught exception on Java level)
finally: finally:
writer.close() writer.close()
except:
raise # This shouldn't be needed, but it is, see above
finally: finally:
reader.close() reader.close()
input_stream.close() input_stream.close()
...@@ -203,10 +210,15 @@ class FreqProcessor(AbstractProcessor): ...@@ -203,10 +210,15 @@ class FreqProcessor(AbstractProcessor):
""" """
text = record.getValue(self.input_field_name) text = record.getValue(self.input_field_name)
if text is None: if text is None:
raise ValueError("Can't get value of '{}' field".format(self.input_field_name)) # FlowFile attribute tells we should take the domain from the field
# named $input_field_name, but such a field is not present in the
# record.
# Don't set the result fields
self.logger.warn("The 'Input field' attribute points to '{}' field which is not present in the record.".format(self.input_field_name))
return
prob1, prob2 = self.fc.probability(text) prob1, prob2 = self.fc.probability(text)
record.setValue(self.result1_field_name, prob1) record.setValue(self.result1_field_name, prob1)
record.setValue(self.result2_field_name, prob2) record.setValue(self.result2_field_name, prob2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment