Skip to content

Commit

Permalink
Support for tflite_runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Dec 11, 2023
1 parent f1501aa commit f193e98
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions 426_YOLOX-Body-Head-Hand/demo/demo_yolox_onnx_tfite.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,29 @@ def __init__(
self._h_index = 2
self._w_index = 3

elif self._runtime == 'tflite':
elif self._runtime == 'tflite_runtime':
from tflite_runtime.interpreter import Interpreter # type: ignore
self._interpreter = Interpreter(model_path=model_path)
self._input_details = self._interpreter.get_input_details()
self._output_details = self._interpreter.get_output_details()
self._input_shapes = [
input.get('shape', None) for input in self._input_details
]
self._input_names = [
input.get('name', None) for input in self._input_details
]
self._output_shapes = [
output.get('shape', None) for output in self._output_details
]
self._output_names = [
output.get('name', None) for output in self._output_details
]
self._model = self._interpreter.get_signature_runner()
self._swap = (0, 1, 2)
self._h_index = 1
self._w_index = 2

elif self._runtime == 'tensorflow':
import tensorflow as tf # type: ignore
self._interpreter = tf.lite.Interpreter(model_path=model_path)
self._input_details = self._interpreter.get_input_details()
Expand Down Expand Up @@ -163,7 +185,7 @@ def __call__(
)
]
return outputs
elif self._runtime == 'tflite':
elif self._runtime in ['tflite_runtime', 'tensorflow']:
outputs = [
output for output in \
self._model(
Expand Down Expand Up @@ -421,11 +443,17 @@ def main():
if not is_package_installed('onnxruntime'):
print(Color.RED('ERROR: onnxruntime is not installed. pip install onnxruntime or pip install onnxruntime-gpu'))
sys.exit(0)
runtime = 'onnx'
elif model_ext == 'tflite':
if not is_package_installed('tensorflow'):
print(Color.RED('ERROR: tensorflow is not installed. pip install tensorflow'))
if is_package_installed('tflite_runtime'):
runtime = 'tflite_runtime'
elif is_package_installed('tensorflow'):
runtime = 'tensorflow'
else:
print(Color.RED('ERROR: tflite_runtime or tensorflow is not installed.'))
print(Color.RED('ERROR: https://github.com/PINTO0309/TensorflowLite-bin'))
print(Color.RED('ERROR: https://github.com/tensorflow/tensorflow'))
sys.exit(0)
runtime = model_ext
video: str = args.video
execution_provider: str = args.execution_provider
providers: List[Tuple[str, Dict] | str] = None
Expand Down

0 comments on commit f193e98

Please sign in to comment.