Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Sep 13, 2024
1 parent 56f8190 commit 14c277c
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions kornia/onnx/sequential.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import List, Optional, Tuple, Union

from kornia.core.external import numpy as np
from kornia.core.external import onnx
Expand Down Expand Up @@ -30,9 +30,9 @@ class ONNXSequential:
def __init__(
self,
*args: Union["onnx.ModelProto", str], # type:ignore
providers: Optional[list[str]] = None,
providers: Optional[List[str]] = None,
session_options: Optional["ort.SessionOptions"] = None, # type:ignore
io_maps: Optional[list[tuple[str, str]]] = None,
io_maps: Optional[List[Tuple[str, str]]] = None,
cache_dir: Optional[str] = None,
) -> None:
self.onnx_loader = ONNXLoader(cache_dir)
Expand All @@ -53,7 +53,7 @@ def _load_op(self, arg: Union["onnx.ModelProto", str]) -> "onnx.ModelProto": #
return self.onnx_loader.load_model(arg)
return arg

def _combine(self, io_maps: Optional[list[tuple[str, str]]] = None) -> "onnx.ModelProto": # type:ignore
def _combine(self, io_maps: Optional[List[Tuple[str, str]]] = None) -> "onnx.ModelProto": # type:ignore
"""Combine the provided ONNX models into a single ONNX graph. Optionally, map inputs and outputs between
operators using the `io_map`.
Expand Down Expand Up @@ -92,7 +92,7 @@ def export(self, file_path: str) -> None:

def create_session(
self,
providers: Optional[list[str]] = None,
providers: Optional[List[str]] = None,
session_options: Optional["ort.SessionOptions"] = None, # type:ignore
) -> "ort.InferenceSession": # type:ignore
"""Create an optimized ONNXRuntime InferenceSession for the combined model.
Expand Down Expand Up @@ -133,7 +133,7 @@ def get_session(self) -> "ort.InferenceSession": # type: ignore
"""
return self._session

def __call__(self, *inputs: "np.ndarray") -> list["np.ndarray"]: # type:ignore
def __call__(self, *inputs: "np.ndarray") -> List["np.ndarray"]: # type:ignore
"""Perform inference using the combined ONNX model.
Args:
Expand Down

0 comments on commit 14c277c

Please sign in to comment.