Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Sep 10, 2024
1 parent 62c4295 commit 0a15677
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 24 deletions.
16 changes: 14 additions & 2 deletions docs/source/onnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ Here's how you can quickly get started with `ONNXSequential`:
# Initialize with CUDA execution provider
onnx_seq = ONNXSequential(
"hf://operators/kornia.color.gray.RgbToGrayscale",
"hf://operators/kornia.geometry.transform.affwarp.Resize_512x512",
"hf://operators/kornia.geometry.transform.flips.Hflip",
# Or you may use a local model with either a filepath "YOUR_OWN_MODEL.onnx" or a loaded ONNX model.
"hf://models/kornia.models.detection.rtdetr_r18vd_640x640",
providers=['CUDAExecutionProvider']
)
Expand Down Expand Up @@ -122,3 +123,14 @@ API Documentation
-----------------
.. autoclass:: kornia.onnx.sequential.ONNXSequential
:members:

.. autoclass:: kornia.onnx.utils.ONNXLoader

.. code-block:: python
onnx_loader = ONNXLoader()
# Load a HuggingFace operator
onnx_loader.load_model("hf://operators/kornia.color.gray.GrayscaleToRgb")
# Load a local converted/downloaded operator
onnx_loader.load_model("operators/kornia.color.gray.GrayscaleToRgb")
:members:
12 changes: 6 additions & 6 deletions kornia/color/gray.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ class GrayscaleToRgb(Module):
>>> output = rgb(input) # 2x3x4x5
"""

ONNX_DEFAULT_INPUTSHAPE: ClassVar[List[int]] = [-1, 1, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[List[int]] = [-1, 3, -1, -1]
ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 1, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]

def forward(self, image: Tensor) -> Tensor:
return grayscale_to_rgb(image)
Expand All @@ -150,8 +150,8 @@ class RgbToGrayscale(Module):
>>> output = gray(input) # 2x1x4x5
"""

ONNX_DEFAULT_INPUTSHAPE: ClassVar[List[int]] = [-1, 3, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[List[int]] = [-1, 1, -1, -1]
ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, 1, -1, -1]

def __init__(self, rgb_weights: Optional[Tensor] = None) -> None:
super().__init__()
Expand Down Expand Up @@ -181,8 +181,8 @@ class BgrToGrayscale(Module):
>>> output = gray(input) # 2x1x4x5
"""

ONNX_DEFAULT_INPUTSHAPE: ClassVar[List[int]] = [-1, 3, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[List[int]] = [-1, 1, -1, -1]
ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, 1, -1, -1]

def forward(self, image: Tensor) -> Tensor:
return bgr_to_grayscale(image)
8 changes: 0 additions & 8 deletions kornia/onnx/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ class ONNXSequential:
cache_dir:
cache_dir: The directory where ONNX models are cached locally (only for downloading from HuggingFace).
Defaults to None, which will use a default `.kornia_onnx_models` directory.
.. code-block:: python
# Load ops from HuggingFace repos then chain to your own model!
model = kornia.onnx.ONNXSequential(
"hf://operators/kornia.color.gray.RgbToGrayscale",
"hf://operators/kornia.geometry.transform.affwarp.Resize_512x512",
"MY_OTHER_MODEL.onnx"
)
"""

def __init__(
Expand Down
9 changes: 1 addition & 8 deletions kornia/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ class ONNXLoader:
Attributes:
cache_dir: The directory where ONNX models are cached locally.
Defaults to None, which will use a default `.kornia_hub/onnx_models` directory.
.. code-block:: python
onnx_loader = ONNXLoader()
# Load a HuggingFace operator
onnx_loader.load_model("hf://operators/kornia.color.gray.GrayscaleToRgb")
# Load a local converted/downloaded operator
onnx_loader.load_model("operators/kornia.color.gray.GrayscaleToRgb")
"""

def __init__(self, cache_dir: Optional[str] = None):
Expand Down Expand Up @@ -103,7 +96,7 @@ def download(
if url.startswith(("http:", "https:")):
try:
logger.info(f"Downloading `{url}` to `{file_path}`.")
urllib.request.urlretrieve(url, file_path)
urllib.request.urlretrieve(url, file_path) # noqa: S310
except urllib.error.HTTPError as e:
raise ValueError(f"Error in resolving `{url}`. {e}.")
else:
Expand Down

0 comments on commit 0a15677

Please sign in to comment.