From 478e413040d9681fecf90033fa2561aa3a679dcd Mon Sep 17 00:00:00 2001 From: shijianjian Date: Fri, 13 Sep 2024 18:08:04 +0300 Subject: [PATCH] update --- kornia/onnx/sequential.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/kornia/onnx/sequential.py b/kornia/onnx/sequential.py index 672870129d..c46b56c206 100644 --- a/kornia/onnx/sequential.py +++ b/kornia/onnx/sequential.py @@ -6,7 +6,7 @@ from .utils import ONNXLoader -__all__ = ["ONNXSequential"] +__all__ = ["ONNXSequential", "load"] class ONNXSequential: @@ -14,6 +14,8 @@ class ONNXSequential: Args: *args: A variable number of ONNX models (either ONNX ModelProto objects or file paths). + For Hugging Face-hosted models, use the format 'hf://model_name'. Valid `model_name` can be found on + https://huggingface.co/kornia/ONNX_models. providers: A list of execution providers for ONNXRuntime (e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']). session_options: Optional ONNXRuntime session options for optimizing the session. @@ -36,7 +38,7 @@ def __init__( self.onnx_loader = ONNXLoader(cache_dir) self.operators = args self._combined_op = self._combine(io_maps) - self._session = self.create_session() + self._session = self.create_session(providers=providers, session_options=session_options) def _load_op(self, arg: Union["onnx.ModelProto", str]) -> "onnx.ModelProto": # type:ignore """Loads an ONNX model, either from a file path or use the provided ONNX ModelProto. @@ -148,3 +150,17 @@ def __call__(self, *inputs: "np.ndarray") -> list["np.ndarray"]: # type:ignore outputs = self._session.run(None, ort_input_values) return outputs + + +def load(model_name: str) -> "ONNXSequential": + """Load an ONNX model from either a file path or HuggingFace. + + The loaded model is an ONNXSequential object, of which you may run the model with + the `__call__` method, with less boilerplate. + + Args: + model_name: The name of the model to load. For Hugging Face-hosted models, + use the format 'hf://model_name'. Valid `model_name` can be found on + https://huggingface.co/kornia/ONNX_models. + """ + return ONNXSequential(model_name) \ No newline at end of file