init
shijianjian committed Sep 7, 2024
1 parent a48ed83 commit 6584aec
Showing 8 changed files with 284 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -69,6 +69,7 @@ Join the community
metrics
morphology
nerf
onnx
tracking
testing
utils
117 changes: 117 additions & 0 deletions docs/source/onnx.rst
@@ -0,0 +1,117 @@
ONNXSequential: Chain Multiple ONNX Models with Ease
=====================================================

The `ONNXSequential` class is a powerful new feature that allows users to effortlessly combine and chain multiple ONNX models together. This is especially useful when you have several pre-trained models or custom ONNX operators that you want to execute sequentially as part of a larger pipeline.

Whether you're working with models for inference, experimentation, or optimization, `ONNXSequential` makes it easier to manage, combine, and run ONNX models in a streamlined manner. It also supports flexibility in execution environments with ONNXRuntime’s execution providers (CPU, CUDA, etc.).

Key Features
------------

- **Seamless Model Chaining**: Combine multiple ONNX models into a single computational graph.
- **Flexible Input/Output Mapping**: Control how the outputs of one model are passed as inputs to the next.
- **Optimized Execution**: Automatically create optimized `ONNXRuntime` sessions to speed up inference.
- **Export to ONNX**: Save the combined model into a single ONNX file for easy deployment and sharing.
- **Execution Providers Support**: Utilize ONNXRuntime's execution providers (e.g., `CUDAExecutionProvider`, `CPUExecutionProvider`) for accelerated inference on different hardware.
- **PyTorch-like Interface**: Use the `ONNXSequential` class like a PyTorch `nn.Sequential` model, including calling it directly for inference.

Quickstart Guide
----------------

Here's how you can quickly get started with `ONNXSequential`:

1. **Install ONNX and ONNXRuntime**

   If you haven't already installed `onnx` and `onnxruntime`, you can install them using `pip`:

   .. code-block:: bash

      pip install onnx onnxruntime

2. **Combining ONNX Models**

   You can initialize `ONNXSequential` with a list of ONNX models or file paths. The models are automatically chained together and optimized for inference.

   .. code-block:: python

      import numpy as np

      from kornia.onnx import ONNXSequential

      # Initialize ONNXSequential with two models
      onnx_seq = ONNXSequential("model1.onnx", "model2.onnx")

      # Prepare some input data
      input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

      # Perform inference
      outputs = onnx_seq(input_data)

      # Print the model outputs
      print(outputs)

   .. note::
      By default, we assume each ONNX model contains only one input node named "input" and one output node named "output". For more complex models, you may need to pass an `io_maps` argument.

3. **Input/Output Mapping Between Models**

   When combining models, you can specify how the outputs of one model are mapped to the inputs of the next. Pass one list of `(output_name, input_name)` pairs per junction between consecutive models:

   .. code-block:: python

      io_maps = [[("model1_output_0", "model2_input_0"), ("model1_output_1", "model2_input_1")]]
      onnx_seq = ONNXSequential("model1.onnx", "model2.onnx", io_maps=io_maps)

4. **Exporting the Combined Model**

   You can easily export the combined model to an ONNX file:

   .. code-block:: python

      # Export the combined model to a file
      onnx_seq.export("combined_model.onnx")

5. **Optimizing with Execution Providers**

   Leverage ONNXRuntime's execution providers for optimized inference. For example, to run the model on a GPU:

   .. code-block:: python

      # Initialize with the CUDA execution provider
      onnx_seq = ONNXSequential("model1.onnx", "model2.onnx", providers=["CUDAExecutionProvider"])

      # Run inference
      outputs = onnx_seq(input_data)

Frequently Asked Questions (FAQ)
--------------------------------

**1. Can I chain models from different sources?**

Yes! You can chain models from different ONNX files or directly from `onnx.ModelProto` objects. `ONNXSequential` handles the integration and merging of their graphs.
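
For instance, here's a minimal sketch that mixes an in-memory `onnx.ModelProto` with a model loaded from disk (the file names are placeholders):

.. code-block:: python

   import onnx

   from kornia.onnx import ONNXSequential

   # Load one model into memory; leave the other as a file path
   proto = onnx.load("model1.onnx")
   onnx_seq = ONNXSequential(proto, "model2.onnx")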

**2. What happens if the input/output sizes of models don't match?**

You can use the `io_maps` parameter to control how outputs of one model are mapped to the inputs of the next. This allows for greater flexibility when chaining models with different architectures.
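
For example, here's a sketch chaining three models, with one mapping list per junction (the tensor names are placeholders for your graphs' actual input/output names):

.. code-block:: python

   from kornia.onnx import ONNXSequential

   io_maps = [
       [("feat", "feat")],      # model1's "feat" output feeds model2's "feat" input
       [("logits", "logits")],  # model2's "logits" output feeds model3's "logits" input
   ]
   onnx_seq = ONNXSequential("model1.onnx", "model2.onnx", "model3.onnx", io_maps=io_maps)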

**3. Can I use custom ONNXRuntime session options?**

Absolutely! You can pass your own session options to the `create_session` method to fine-tune performance, memory usage, or logging.
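
For example, here's a minimal sketch that creates a session with custom options and installs it on the pipeline:

.. code-block:: python

   import onnxruntime as ort

   from kornia.onnx import ONNXSequential

   onnx_seq = ONNXSequential("model1.onnx", "model2.onnx")

   # Build a session with custom options, then swap it in
   sess_options = ort.SessionOptions()
   sess_options.intra_op_num_threads = 2
   session = onnx_seq.create_session(session_options=sess_options)
   onnx_seq.set_session(session)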

Why Choose ONNXSequential?
---------------------------

With the increasing adoption of ONNX for model interoperability and deployment, `ONNXSequential` provides a simple yet powerful interface for combining models and operators. By leveraging ONNXRuntime’s optimization and execution provider capabilities, it gives you the flexibility to:

- Deploy on different hardware (CPU, GPU).
- Run complex pipelines in production environments.
- Combine and experiment with models effortlessly.

Whether you're building an advanced deep learning pipeline or simply trying to chain pre-trained models, `ONNXSequential` makes it easy to manage, optimize, and execute ONNX models at scale.

Get started today and streamline your ONNX workflows!


API Documentation
-----------------
.. autoclass:: kornia.onnx.sequential.ONNXSequential
   :members:
1 change: 1 addition & 0 deletions kornia/__init__.py
@@ -16,6 +16,7 @@
losses,
metrics,
morphology,
onnx,
tracking,
utils,
x,
2 changes: 2 additions & 0 deletions kornia/core/external.py
@@ -70,3 +70,5 @@ def __dir__(self) -> List[str]:
numpy = LazyLoader("numpy")
PILImage = LazyLoader("PIL.Image")
diffusers = LazyLoader("diffusers")
onnx = LazyLoader("onnx")
onnxruntime = LazyLoader("onnxruntime")
15 changes: 8 additions & 7 deletions kornia/geometry/transform/affwarp.py
@@ -4,6 +4,7 @@
from torch import nn

from kornia.core import ones, ones_like, zeros
from kornia.core import ImageModule as Module
from kornia.filters import gaussian_blur2d
from kornia.utils import _extract_device_dtype
from kornia.utils.image import perform_keep_shape_image
@@ -643,7 +644,7 @@ def rescale(
return resize(input, size, interpolation=interpolation, align_corners=align_corners, antialias=antialias)


class Resize(nn.Module):
class Resize(Module):
r"""Resize the input torch.Tensor to the given size.
Args:
@@ -699,7 +700,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
)


class Affine(nn.Module):
class Affine(Module):
r"""Apply multiple elementary affine transforms simultaneously.
Args:
@@ -795,7 +796,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return affine(input, matrix[..., :2, :3], self.mode, self.padding_mode, self.align_corners)


class Rescale(nn.Module):
class Rescale(Module):
r"""Rescale the input torch.Tensor with the given factor.
Args:
@@ -838,7 +839,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
)


class Rotate(nn.Module):
class Rotate(Module):
r"""Rotate the tensor anti-clockwise about the centre.
Args:
@@ -883,7 +884,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return rotate(input, self.angle, self.center, self.mode, self.padding_mode, self.align_corners)


class Translate(nn.Module):
class Translate(Module):
r"""Translate the tensor in pixel units.
Args:
@@ -920,7 +921,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return translate(input, self.translation, self.mode, self.padding_mode, self.align_corners)


class Scale(nn.Module):
class Scale(Module):
r"""Scale the tensor by a factor.
Args:
@@ -967,7 +968,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return scale(input, self.scale_factor, self.center, self.mode, self.padding_mode, self.align_corners)


class Shear(nn.Module):
class Shear(Module):
r"""Shear the tensor.
Args:
3 changes: 2 additions & 1 deletion kornia/geometry/transform/flips.py
@@ -1,6 +1,7 @@
import torch

from kornia.core import Module, Tensor
from kornia.core import Tensor
from kornia.core import ImageModule as Module

__all__ = ["Vflip", "Hflip", "Rot180", "rot180", "hflip", "vflip"]

1 change: 1 addition & 0 deletions kornia/onnx/__init__.py
@@ -0,0 +1 @@
from .sequential import *
152 changes: 152 additions & 0 deletions kornia/onnx/sequential.py
@@ -0,0 +1,152 @@
from typing import Optional, Union

from kornia.core.external import numpy as np
from kornia.core.external import onnx
from kornia.core.external import onnxruntime as ort

__all__ = ["ONNXSequential"]


class ONNXSequential:
    """ONNXSequential chains multiple ONNX operators together.

    Args:
        *args:
            A variable number of ONNX models (either ONNX ModelProto objects or file paths).
        providers:
            A list of execution providers for ONNXRuntime (e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']).
        session_options:
            Optional ONNXRuntime session options for optimizing the session.
        io_maps:
            An optional list of lists of tuples specifying input-output mappings for combining models.
            If None, we assume the default input and output names are "input" and "output" respectively,
            and that each graph has only one input and one output node.
            If not None, `io_maps[0]` shall represent the `io_map` for combining the first and second ONNX models.
    """

    def __init__(
        self,
        *args: Union[onnx.ModelProto, str],  # type:ignore
        providers: Optional[list[str]] = None,
        session_options: Optional[ort.SessionOptions] = None,  # type:ignore
        io_maps: Optional[list[list[tuple[str, str]]]] = None,
    ) -> None:
        self.operators = args
        self._combined_op = self._combine(io_maps)
        # Build the default session from the requested providers and options.
        self._session = self.create_session(providers=providers, session_options=session_options)

    def _load_op(self, arg: Union[onnx.ModelProto, str]) -> onnx.ModelProto:  # type:ignore
        """Load an ONNX model from a file path, or pass through the provided ONNX ModelProto.

        Args:
            arg: Either an ONNX ModelProto object or a file path to an ONNX model.

        Returns:
            onnx.ModelProto: The loaded ONNX model.
        """
        if isinstance(arg, str):
            return onnx.load(arg)  # type:ignore
        return arg

    def _combine(self, io_maps: Optional[list[list[tuple[str, str]]]] = None) -> onnx.ModelProto:  # type:ignore
        """Combine the provided ONNX models into a single ONNX graph.

        Optionally, map inputs and outputs between operators using `io_maps`.

        Args:
            io_maps:
                A list of lists of tuples representing input-output mappings for combining the models.
                Example: [[(model1_output_name, model2_input_name)], [(model2_output_name, model3_input_name)]].

        Returns:
            onnx.ModelProto: The combined ONNX model as a single ONNX graph.

        Raises:
            ValueError: If no operators are provided for combination.
        """
        if len(self.operators) == 0:
            raise ValueError("No operators found.")

        # Prefix each graph ("K00-", "K01-", ...) so node and tensor names stay unique after merging.
        combined_op = self._load_op(self.operators[0])
        combined_op = onnx.compose.add_prefix(combined_op, prefix=f"K{str(0).zfill(2)}-")

        for i, op in enumerate(self.operators[1:]):
            next_op = onnx.compose.add_prefix(self._load_op(op), prefix=f"K{str(i + 1).zfill(2)}-")
            if io_maps is None:
                # Default wiring: the previous model's "output" feeds the next model's "input".
                io_map = [(f"K{str(i).zfill(2)}-output", f"K{str(i + 1).zfill(2)}-input")]
            else:
                # Translate user-provided tensor names into their prefixed counterparts.
                io_map = [(f"K{str(i).zfill(2)}-{it[0]}", f"K{str(i + 1).zfill(2)}-{it[1]}") for it in io_maps[i]]
            combined_op = onnx.compose.merge_models(combined_op, next_op, io_map=io_map)

        return combined_op

    def export(self, file_path: str) -> None:
        """Export the combined ONNX model to a file.

        Args:
            file_path: The file path to export the combined ONNX model.
        """
        onnx.save(self._combined_op, file_path)

    def create_session(
        self,
        providers: Optional[list[str]] = None,
        session_options: Optional[ort.SessionOptions] = None,  # type:ignore
    ) -> ort.InferenceSession:  # type:ignore
        """Create an optimized ONNXRuntime InferenceSession for the combined model.

        Args:
            providers:
                Execution providers for ONNXRuntime (e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']).
            session_options:
                Optional ONNXRuntime session options for session configuration and optimizations.

        Returns:
            ort.InferenceSession: The ONNXRuntime session optimized for inference.
        """
        if session_options is None:
            # Default to extended graph optimizations when the caller supplies no options.
            session_options = ort.SessionOptions()  # type:ignore
            session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED  # type:ignore
        session = ort.InferenceSession(
            self._combined_op.SerializeToString(),
            sess_options=session_options,
            providers=providers or ['CPUExecutionProvider'],
        )
        return session

    def set_session(self, session: ort.InferenceSession) -> None:  # type: ignore
        """Set a custom ONNXRuntime InferenceSession.

        Args:
            session: The custom ONNXRuntime session to be set for inference.
        """
        self._session = session

    def get_session(self) -> ort.InferenceSession:  # type: ignore
        """Get the current ONNXRuntime InferenceSession.

        Returns:
            ort.InferenceSession: The current ONNXRuntime session.
        """
        return self._session

    def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]:  # type:ignore
        """Perform inference using the combined ONNX model.

        Args:
            *inputs: Inputs to the ONNX model. The number of inputs must match the expected inputs of the session.

        Returns:
            list: The outputs from the ONNX model inference.
        """
        ort_inputs = self._session.get_inputs()
        if len(ort_inputs) != len(inputs):
            raise ValueError(f"Expected {len(ort_inputs)} inputs for the session, but received {len(inputs)}.")

        # Match each positional input to the session's named inputs, in order.
        ort_input_values = {ort_input.name: inp for ort_input, inp in zip(ort_inputs, inputs)}
        outputs = self._session.run(None, ort_input_values)

        return outputs
