Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Sep 13, 2024
1 parent d883dd3 commit 478e413
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions kornia/onnx/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@

from .utils import ONNXLoader

__all__ = ["ONNXSequential"]
__all__ = ["ONNXSequential", "load"]


class ONNXSequential:
"""ONNXSequential to chain multiple ONNX operators together.
Args:
*args: A variable number of ONNX models (either ONNX ModelProto objects or file paths).
For Hugging Face-hosted models, use the format 'hf://model_name'. Valid `model_name` can be found on
https://huggingface.co/kornia/ONNX_models.
providers: A list of execution providers for ONNXRuntime
(e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']).
session_options: Optional ONNXRuntime session options for optimizing the session.
Expand All @@ -36,7 +38,7 @@ def __init__(
self.onnx_loader = ONNXLoader(cache_dir)
self.operators = args
self._combined_op = self._combine(io_maps)
self._session = self.create_session()
self._session = self.create_session(providers=providers, session_options=session_options)

def _load_op(self, arg: Union["onnx.ModelProto", str]) -> "onnx.ModelProto": # type:ignore
"""Loads an ONNX model, either from a file path or use the provided ONNX ModelProto.
Expand Down Expand Up @@ -148,3 +150,17 @@ def __call__(self, *inputs: "np.ndarray") -> list["np.ndarray"]: # type:ignore
outputs = self._session.run(None, ort_input_values)

return outputs


def load(model_name: str) -> "ONNXSequential":
"""Load an ONNX model from either a file path or HuggingFace.
The loaded model is an ONNXSequential object, of which you may run the model with
the `__call__` method, with less boilerplate.
Args:
model_name: The name of the model to load. For Hugging Face-hosted models,
use the format 'hf://model_name'. Valid `model_name` can be found on
https://huggingface.co/kornia/ONNX_models.
"""
return ONNXSequential(model_name)

0 comments on commit 478e413

Please sign in to comment.