From f193e98405f6d96adbe4c02195ff575fbbcf3705 Mon Sep 17 00:00:00 2001 From: pinto0309 Date: Tue, 12 Dec 2023 00:07:42 +0900 Subject: [PATCH] Support for tflite_runtime --- .../demo/demo_yolox_onnx_tfite.py | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/426_YOLOX-Body-Head-Hand/demo/demo_yolox_onnx_tfite.py b/426_YOLOX-Body-Head-Hand/demo/demo_yolox_onnx_tfite.py index 1e33b9510c..d54f038716 100644 --- a/426_YOLOX-Body-Head-Hand/demo/demo_yolox_onnx_tfite.py +++ b/426_YOLOX-Body-Head-Hand/demo/demo_yolox_onnx_tfite.py @@ -122,7 +122,29 @@ def __init__( self._h_index = 2 self._w_index = 3 - elif self._runtime == 'tflite': + elif self._runtime == 'tflite_runtime': + from tflite_runtime.interpreter import Interpreter # type: ignore + self._interpreter = Interpreter(model_path=model_path) + self._input_details = self._interpreter.get_input_details() + self._output_details = self._interpreter.get_output_details() + self._input_shapes = [ + input.get('shape', None) for input in self._input_details + ] + self._input_names = [ + input.get('name', None) for input in self._input_details + ] + self._output_shapes = [ + output.get('shape', None) for output in self._output_details + ] + self._output_names = [ + output.get('name', None) for output in self._output_details + ] + self._model = self._interpreter.get_signature_runner() + self._swap = (0, 1, 2) + self._h_index = 1 + self._w_index = 2 + + elif self._runtime == 'tensorflow': import tensorflow as tf # type: ignore self._interpreter = tf.lite.Interpreter(model_path=model_path) self._input_details = self._interpreter.get_input_details() @@ -163,7 +185,7 @@ def __call__( ) ] return outputs - elif self._runtime == 'tflite': + elif self._runtime in ['tflite_runtime', 'tensorflow']: outputs = [ output for output in \ self._model( @@ -421,11 +443,17 @@ def main(): if not is_package_installed('onnxruntime'): print(Color.RED('ERROR: onnxruntime is not installed. pip install onnxruntime or pip install onnxruntime-gpu')) sys.exit(0) + runtime = 'onnx' elif model_ext == 'tflite': - if not is_package_installed('tensorflow'): - print(Color.RED('ERROR: tensorflow is not installed. pip install tensorflow')) + if is_package_installed('tflite_runtime'): + runtime = 'tflite_runtime' + elif is_package_installed('tensorflow'): + runtime = 'tensorflow' + else: + print(Color.RED('ERROR: tflite_runtime or tensorflow is not installed.')) + print(Color.RED('ERROR: https://github.com/PINTO0309/TensorflowLite-bin')) + print(Color.RED('ERROR: https://github.com/tensorflow/tensorflow')) sys.exit(0) - runtime = model_ext video: str = args.video execution_provider: str = args.execution_provider providers: List[Tuple[str, Dict] | str] = None