Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Sep 7, 2024
1 parent 5c55e80 commit e3006bf
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 17 deletions.
6 changes: 4 additions & 2 deletions kornia/filters/motion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import ClassVar

from kornia.core import ImageModule as Module
from kornia.core import Tensor
from kornia.core.check import KORNIA_CHECK
Expand Down Expand Up @@ -84,8 +86,8 @@ class MotionBlur3D(Module):
>>> output = motion_blur(input) # 2x4x5x7x9
"""

ONNX_DEFAULT_INPUTSHAPE: list[int] = [-1, -1, -1, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: list[int] = [-1, -1, -1, -1, -1]
ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, -1, -1, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, -1, -1, -1, -1]

def __init__(
self,
Expand Down
10 changes: 6 additions & 4 deletions kornia/filters/sobel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import ClassVar

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -172,8 +174,8 @@ class SpatialGradient(Module):
>>> output = SpatialGradient()(input) # 1x3x2x4x4
"""

ONNX_DEFAULT_INPUTSHAPE: list[int] = [-1, -1, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: list[int] = [-1, -1, 2, -1, -1]
ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, -1, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, -1, 2, -1, -1]

def __init__(self, mode: str = "sobel", order: int = 1, normalized: bool = True) -> None:
super().__init__()
Expand Down Expand Up @@ -209,8 +211,8 @@ class SpatialGradient3d(Module):
torch.Size([1, 4, 3, 2, 4, 4])
"""

ONNX_DEFAULT_INPUTSHAPE: list[int] = [-1, -1, -1, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: list[int] = [-1, -1, -1, -1, -1, -1]
ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, -1, -1, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, -1, -1, -1, -1, -1]

def __init__(self, mode: str = "diff", order: int = 1) -> None:
super().__init__()
Expand Down
18 changes: 9 additions & 9 deletions kornia/onnx/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class ONNXSequential:

def __init__(
self,
*args: Union[onnx.ModelProto, str], # type:ignore
*args: Union["onnx.ModelProto", str], # type:ignore
providers: Optional[list[str]] = None,
session_options: Optional[ort.SessionOptions] = None, # type:ignore
session_options: Optional["ort.SessionOptions"] = None, # type:ignore
io_maps: Optional[list[tuple[str, str]]] = None,
cache_dir: Optional[str] = None,
) -> None:
Expand All @@ -50,7 +50,7 @@ def __init__(
self._combined_op = self._combine(io_maps)
self._session = self.create_session()

def _load_op(self, arg: Union[onnx.ModelProto, str]) -> onnx.ModelProto: # type:ignore
def _load_op(self, arg: Union["onnx.ModelProto", str]) -> "onnx.ModelProto": # type:ignore
"""Loads an ONNX model, either from a file path or use the provided ONNX ModelProto.
Args:
Expand All @@ -63,7 +63,7 @@ def _load_op(self, arg: Union[onnx.ModelProto, str]) -> onnx.ModelProto: # type
return self.onnx_loader.load_model(arg)
return arg

def _combine(self, io_maps: Optional[list[tuple[str, str]]] = None) -> onnx.ModelProto: # type:ignore
def _combine(self, io_maps: Optional[list[tuple[str, str]]] = None) -> "onnx.ModelProto": # type:ignore
"""Combine the provided ONNX models into a single ONNX graph. Optionally, map inputs and outputs between
operators using the `io_map`.
Expand Down Expand Up @@ -106,8 +106,8 @@ def export(self, file_path: str) -> None:
def create_session(
self,
providers: Optional[list[str]] = None,
session_options: Optional[ort.SessionOptions] = None, # type:ignore
) -> ort.InferenceSession: # type:ignore
session_options: Optional["ort.SessionOptions"] = None, # type:ignore
) -> "ort.InferenceSession": # type:ignore
"""Create an optimized ONNXRuntime InferenceSession for the combined model.
Args:
Expand All @@ -129,7 +129,7 @@ def create_session(
)
return session

def set_session(self, session: ort.InferenceSession) -> None: # type: ignore
def set_session(self, session: "ort.InferenceSession") -> None: # type: ignore
"""Set a custom ONNXRuntime InferenceSession.
Args:
Expand All @@ -138,15 +138,15 @@ def set_session(self, session: ort.InferenceSession) -> None: # type: ignore
"""
self._session = session

def get_session(self) -> ort.InferenceSession: # type: ignore
def get_session(self) -> "ort.InferenceSession": # type: ignore
"""Get the current ONNXRuntime InferenceSession.
Returns:
ort.InferenceSession: The current ONNXRuntime session.
"""
return self._session

def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: # type:ignore
def __call__(self, *inputs: "np.ndarray") -> list["np.ndarray"]: # type:ignore
"""Perform inference using the combined ONNX model.
Args:
Expand Down
7 changes: 5 additions & 2 deletions kornia/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _get_file_path(self, model_name: str, cache_dir: Optional[str]) -> str:
file_path = os.path.join(cache_dir, "/".join(model_name.split("/")[:-1]), file_name)
return file_path

def load_model(self, model_name: str, download: bool = False, **kwargs) -> onnx.ModelProto: # type:ignore
def load_model(self, model_name: str, download: bool = False, **kwargs) -> "onnx.ModelProto": # type:ignore
"""Loads an ONNX model from the local cache or downloads it from Hugging Face if necessary.
Args:
Expand Down Expand Up @@ -108,7 +108,10 @@ def _fetch_repo_contents(folder: str) -> list[dict[str, Any]]:
List[dict]: A list of all files in the repository as dictionaries containing file details.
"""
url = f"https://huggingface.co/api/models/kornia/ONNX_models/tree/main/{folder}"
response = requests.get(url)

if not url.startswith(("http:", "https:")):
raise ValueError("URL must start with 'http:' or 'https:'")
response = requests.get(url, timeout=10)

if response.status_code == 200:
return response.json() # Returns the JSON content of the repo
Expand Down

0 comments on commit e3006bf

Please sign in to comment.