Skip to content

Commit

Permalink
Add a _post_process_result to tflite_predict_extractor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 620862012
  • Loading branch information
tf-model-analysis-team authored and tfx-copybara committed Apr 1, 2024
1 parent d216096 commit 9be6b09
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tensorflow_model_analysis/extractors/tflite_predict_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def setup(self):
def _make_interpreter(self, **kwargs) -> tf.lite.Interpreter:
return tf.lite.Interpreter(**kwargs)

def _post_process_result(self, input_tensor: np.ndarray) -> np.ndarray:
"""Custom post processor for TFLite predictions, default is no-op."""
return input_tensor

def _get_input_name_from_input_detail(self, input_detail):
"""Get input name from input detail.
Expand All @@ -85,7 +89,7 @@ def _get_input_name_from_input_detail(self, input_detail):
# of the input names. TFLite rewriter assumes that the default signature key
# ('serving_default') will be used as an exported name when saving.
if input_name.startswith('serving_default_'):
input_name = input_name[len('serving_default_'):]
input_name = input_name[len('serving_default_') :]
# Remove argument that starts with ':'.
input_name = input_name.split(':')[0]
return input_name
Expand Down Expand Up @@ -187,10 +191,12 @@ def _batch_reducible_process(
for o in output_details:
tensor = interpreter.get_tensor(o[_INDEX])
params = o[_QUANTIZATION_PARAMETERS]
outputs[o['name']] = self._dequantize(
dequantized_tensor = self._dequantize(
tensor, params[_SCALES], params[_ZERO_POINTS]
)

outputs[o['name']] = self._post_process_result(dequantized_tensor)

for v in outputs.values():
if len(v) != batch_size:
raise ValueError('Did not get the expected number of results.')
Expand Down

0 comments on commit 9be6b09

Please sign in to comment.