forked from kornia/kornia

Commit 6584aec

1 parent a48ed83

Showing 8 changed files with 284 additions and 8 deletions.
@@ -69,6 +69,7 @@ Join the community
metrics
morphology
nerf
onnx
tracking
testing
utils
@@ -0,0 +1,117 @@
ONNXSequential: Chain Multiple ONNX Models with Ease
=====================================================

The `ONNXSequential` class is a powerful new feature that allows users to effortlessly combine and chain multiple ONNX models together. This is especially useful when you have several pre-trained models or custom ONNX operators that you want to execute sequentially as part of a larger pipeline.

Whether you're working with models for inference, experimentation, or optimization, `ONNXSequential` makes it easier to manage, combine, and run ONNX models in a streamlined manner. It also supports flexible execution environments through ONNXRuntime's execution providers (CPU, CUDA, etc.).

Key Features
------------

- **Seamless Model Chaining**: Combine multiple ONNX models into a single computational graph.
- **Flexible Input/Output Mapping**: Control how the outputs of one model are passed as inputs to the next.
- **Optimized Execution**: Automatically create optimized `ONNXRuntime` sessions to speed up inference.
- **Export to ONNX**: Save the combined model into a single ONNX file for easy deployment and sharing.
- **Execution Providers Support**: Utilize ONNXRuntime's execution providers (e.g., `CUDAExecutionProvider`, `CPUExecutionProvider`) for accelerated inference on different hardware.
- **PyTorch-like Interface**: Use the `ONNXSequential` class like a PyTorch `nn.Sequential` model, including calling it directly for inference.

Quickstart Guide
----------------

Here's how you can quickly get started with `ONNXSequential`:

1. **Install ONNX and ONNXRuntime**

   If you haven't already installed `onnx` and `onnxruntime`, you can install them using `pip`:

   .. code-block:: bash

      pip install onnx onnxruntime
2. **Combining ONNX Models**

   You can initialize `ONNXSequential` with a list of ONNX models or file paths. Models will be automatically chained together and optimized for inference.

   .. code-block:: python

      import numpy as np

      from kornia.onnx import ONNXSequential

      # Initialize ONNXSequential with two models
      onnx_seq = ONNXSequential("model1.onnx", "model2.onnx")

      # Prepare some input data
      input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

      # Perform inference
      outputs = onnx_seq(input_data)

      # Print the model outputs
      print(outputs)

   .. note::

      By default, we assume each ONNX model contains only one input node named "input" and one output node named "output". For complex models, you may need to pass an `io_maps` argument.

3. **Input/Output Mapping Between Models**

   When combining models, you can specify how the outputs of one model are mapped to the inputs of the next. This allows you to chain models in custom ways. Each entry in `io_maps` holds the mapping for one pair of adjacent models:

   .. code-block:: python

      io_maps = [[("model1_output_0", "model2_input_0"), ("model1_output_1", "model2_input_1")]]
      onnx_seq = ONNXSequential("model1.onnx", "model2.onnx", io_maps=io_maps)
4. **Exporting the Combined Model**

   You can easily export the combined model to an ONNX file:

   .. code-block:: python

      # Export the combined model to a file
      onnx_seq.export("combined_model.onnx")
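
   If you want to sanity-check the exported file, one option (a minimal sketch using the standard `onnx` API; the file name matches the example above) is to load it back and run the ONNX checker:

   .. code-block:: python

      import onnx

      # Load the exported model back and validate its graph structure.
      model = onnx.load("combined_model.onnx")
      onnx.checker.check_model(model)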
5. **Optimizing with Execution Providers**

   Leverage ONNXRuntime's execution providers for optimized inference. For example, to run the model on a GPU:

   .. code-block:: python

      # Initialize with the CUDA execution provider
      onnx_seq = ONNXSequential("model1.onnx", "model2.onnx", providers=["CUDAExecutionProvider"])

      # Run inference
      outputs = onnx_seq(input_data)
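
   .. note::

      Using `CUDAExecutionProvider` requires the GPU build of ONNXRuntime (the `onnxruntime-gpu` package).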
Frequently Asked Questions (FAQ)
--------------------------------

**1. Can I chain models from different sources?**

Yes! You can chain models from different ONNX files or directly from `onnx.ModelProto` objects. `ONNXSequential` handles the integration and merging of their graphs, as in the sketch below.
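
A minimal sketch (assuming two hypothetical model files on disk) that mixes a loaded `onnx.ModelProto` with a file path:

.. code-block:: python

   import onnx

   from kornia.onnx import ONNXSequential

   # Load one model into memory as an onnx.ModelProto...
   proto = onnx.load("model1.onnx")

   # ...and chain it with a second model given as a file path.
   onnx_seq = ONNXSequential(proto, "model2.onnx")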

**2. What happens if the input/output sizes of models don't match?**

You can use the `io_maps` parameter to control how outputs of one model are mapped to the inputs of the next. This allows for greater flexibility when chaining models with different architectures.

**3. Can I use custom ONNXRuntime session options?**

Absolutely! You can pass your own session options to the `create_session` method to fine-tune performance, memory usage, or logging, as in the sketch below.
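
For example, a minimal sketch (the option values here are illustrative, not recommendations):

.. code-block:: python

   import onnxruntime as ort

   # Configure a custom session: thread count and logging level are examples.
   sess_options = ort.SessionOptions()
   sess_options.intra_op_num_threads = 4
   sess_options.log_severity_level = 0  # 0 = verbose

   # Build a new session with the custom options and install it.
   onnx_seq.set_session(onnx_seq.create_session(session_options=sess_options))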

Why Choose ONNXSequential?
--------------------------

With the increasing adoption of ONNX for model interoperability and deployment, `ONNXSequential` provides a simple yet powerful interface for combining models and operators. By leveraging ONNXRuntime's optimization and execution provider capabilities, it gives you the flexibility to:

- Deploy on different hardware (CPU, GPU).
- Run complex pipelines in production environments.
- Combine and experiment with models effortlessly.

Whether you're building an advanced deep learning pipeline or simply trying to chain pre-trained models, `ONNXSequential` makes it easy to manage, optimize, and execute ONNX models at scale.

Get started today and streamline your ONNX workflows!

API Documentation
-----------------

.. autoclass:: kornia.onnx.sequential.ONNXSequential
   :members:
@@ -16,6 +16,7 @@
losses,
metrics,
morphology,
onnx,
tracking,
utils,
x,
@@ -0,0 +1 @@
from .sequential import *
@@ -0,0 +1,152 @@
from typing import Optional, Union

from kornia.core.external import numpy as np
from kornia.core.external import onnx
from kornia.core.external import onnxruntime as ort

__all__ = ["ONNXSequential"]

class ONNXSequential:
    """ONNXSequential to chain multiple ONNX operators together.

    Args:
        *args:
            A variable number of ONNX models (either ONNX ModelProto objects or file paths).
        providers:
            A list of execution providers for ONNXRuntime (e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']).
        session_options:
            Optional ONNXRuntime session options for optimizing the session.
        io_maps:
            An optional list of lists of tuples specifying input-output mappings for combining models.
            If None, we assume the default input and output names are "input" and "output" respectively,
            and that each graph has only one input and one output node.
            If not None, `io_maps[0]` represents the `io_map` for combining the first and second ONNX models.
    """

    def __init__(
        self,
        *args: Union[onnx.ModelProto, str],  # type:ignore
        providers: Optional[list[str]] = None,
        session_options: Optional[ort.SessionOptions] = None,  # type:ignore
        io_maps: Optional[list[list[tuple[str, str]]]] = None,
    ) -> None:
        self.operators = args
        self._combined_op = self._combine(io_maps)
        # Forward the user-supplied providers and session options to the session.
        self._session = self.create_session(providers=providers, session_options=session_options)

    def _load_op(self, arg: Union[onnx.ModelProto, str]) -> onnx.ModelProto:  # type:ignore
        """Load an ONNX model from a file path, or pass through a provided ONNX ModelProto.

        Args:
            arg: Either an ONNX ModelProto object or a file path to an ONNX model.

        Returns:
            onnx.ModelProto: The loaded ONNX model.
        """
        if isinstance(arg, str):
            return onnx.load(arg)  # type:ignore
        return arg

    def _combine(self, io_maps: Optional[list[list[tuple[str, str]]]] = None) -> onnx.ModelProto:  # type:ignore
        """Combine the provided ONNX models into a single ONNX graph.

        Optionally, map inputs and outputs between operators using `io_maps`.

        Args:
            io_maps:
                A list of lists of tuples representing input-output mappings for combining the models.
                Example: [[(model1_output_name, model2_input_name)], [(model2_output_name, model3_input_name)]].

        Returns:
            onnx.ModelProto: The combined ONNX model as a single ONNX graph.

        Raises:
            ValueError: If no operators are provided for combination.
        """
        if len(self.operators) == 0:
            raise ValueError("No operators found.")

        # Prefix each graph (K00-, K01-, ...) to avoid name collisions when merging.
        combined_op = self._load_op(self.operators[0])
        combined_op = onnx.compose.add_prefix(combined_op, prefix=f"K{str(0).zfill(2)}-")

        for i, op in enumerate(self.operators[1:]):
            next_op = onnx.compose.add_prefix(self._load_op(op), prefix=f"K{str(i + 1).zfill(2)}-")
            if io_maps is None:
                # Default wiring: the "output" of model i feeds the "input" of model i + 1.
                io_map = [(f"K{str(i).zfill(2)}-output", f"K{str(i + 1).zfill(2)}-input")]
            else:
                io_map = [(f"K{str(i).zfill(2)}-{it[0]}", f"K{str(i + 1).zfill(2)}-{it[1]}") for it in io_maps[i]]
            combined_op = onnx.compose.merge_models(combined_op, next_op, io_map=io_map)

        return combined_op

    def export(self, file_path: str) -> None:
        """Export the combined ONNX model to a file.

        Args:
            file_path:
                The file path to export the combined ONNX model.
        """
        onnx.save(self._combined_op, file_path)

    def create_session(
        self,
        providers: Optional[list[str]] = None,
        session_options: Optional[ort.SessionOptions] = None,  # type:ignore
    ) -> ort.InferenceSession:  # type:ignore
        """Create an optimized ONNXRuntime InferenceSession for the combined model.

        Args:
            providers:
                Execution providers for ONNXRuntime (e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']).
            session_options:
                Optional ONNXRuntime session options for session configuration and optimizations.

        Returns:
            ort.InferenceSession: The ONNXRuntime session optimized for inference.
        """
        if session_options is None:
            # Default to extended graph optimizations when no options are supplied.
            session_options = ort.SessionOptions()  # type:ignore
            session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED  # type:ignore
        session = ort.InferenceSession(
            self._combined_op.SerializeToString(),
            sess_options=session_options,
            providers=providers or ["CPUExecutionProvider"],
        )
        return session

    def set_session(self, session: ort.InferenceSession) -> None:  # type: ignore
        """Set a custom ONNXRuntime InferenceSession.

        Args:
            session:
                The custom ONNXRuntime session to be set for inference.
        """
        self._session = session

    def get_session(self) -> ort.InferenceSession:  # type: ignore
        """Get the current ONNXRuntime InferenceSession.

        Returns:
            ort.InferenceSession: The current ONNXRuntime session.
        """
        return self._session

    def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]:  # type:ignore
        """Perform inference using the combined ONNX model.

        Args:
            *inputs: Inputs to the ONNX model. The number of inputs must match the expected inputs of the session.

        Returns:
            list: The outputs from the ONNX model inference.
        """
        ort_inputs = self._session.get_inputs()
        if len(ort_inputs) != len(inputs):
            raise ValueError(f"Expected {len(ort_inputs)} inputs for the session, but got {len(inputs)}.")

        ort_input_values = {ort_inputs[i].name: inputs[i] for i in range(len(ort_inputs))}
        outputs = self._session.run(None, ort_input_values)

        return outputs