Replace deprecated torch.cuda.amp decorators for torch 2.4 (kornia#2967)
* Replace deprecated torch.cuda.amp decorators for torch 2.4

* Fix custom_fwd typing

* Try to fix mypy warning on new custom_fwd import

* Import torch.amp decorators from _compat module
loichuder authored Aug 2, 2024
1 parent b338aa0 commit ef82436
Showing 3 changed files with 20 additions and 3 deletions.
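
For context, the commit works around the torch 2.4 deprecation of the torch.cuda.amp decorators in favor of their torch.amp counterparts, which take an explicit device_type. Below is a minimal sketch of the two spellings (assuming torch >= 2.4 is installed; the old forms still run but emit FutureWarning; the function name scaled is illustrative only):

import torch

# Deprecated since torch 2.4 (still functional, but warns):
#   @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
#   with torch.cuda.amp.autocast():
#       ...

# Replacement API: the device type is passed explicitly.
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
def scaled(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

with torch.amp.autocast("cuda", dtype=torch.float16, enabled=torch.cuda.is_available()):
    y = scaled(torch.ones(3, device="cuda" if torch.cuda.is_available() else "cpu"))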
kornia/contrib/models/efficient_vit/nn/ops.py (2 changes: 1 addition & 1 deletion)
@@ -8,11 +8,11 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.cuda.amp import autocast
 
 from kornia.contrib.models.efficient_vit.nn.act import build_act
 from kornia.contrib.models.efficient_vit.nn.norm import build_norm
 from kornia.contrib.models.efficient_vit.utils import get_same_padding, val2tuple
+from kornia.utils._compat import autocast
 
 __all__ = [
     "ConvLayer",
kornia/feature/lightglue.py (3 changes: 2 additions & 1 deletion)
@@ -25,6 +25,7 @@
 )
 from kornia.core.check import KORNIA_CHECK
 from kornia.feature.laf import laf_to_three_points, scale_laf
+from kornia.utils._compat import custom_fwd
 
 try:
     from flash_attn.modules.mha import FlashCrossAttention
@@ -41,7 +42,7 @@ def math_clamp(x, min_, max_): # type: ignore
     return max(min(x, min_), min_)
 
 
-@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+@custom_fwd(cast_inputs=torch.float32)
 def normalize_keypoints(kpts: Tensor, size: Tensor) -> Tensor:
     if isinstance(size, torch.Size):
         size = Tensor(size)[None]
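
As an aside (not part of the diff), a small sketch of what cast_inputs=torch.float32 buys here: when the decorated function is called inside an active CUDA autocast region, floating-point CUDA tensor arguments are cast to float32 and autocast is disabled in the function body, so keypoint normalization runs in full precision. The helper name report_dtype below is illustrative only:

import torch
from kornia.utils._compat import custom_fwd

@custom_fwd(cast_inputs=torch.float32)
def report_dtype(x: torch.Tensor) -> torch.dtype:
    # Runs with autocast disabled; inputs arrive already cast to float32.
    return x.dtype

if torch.cuda.is_available():
    x = torch.randn(4, device="cuda", dtype=torch.float16)
    with torch.autocast("cuda", dtype=torch.float16):
        print(report_dtype(x))  # torch.float32 inside autocast
    print(report_dtype(x))      # torch.float16 outside autocast (no cast applied)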
kornia/utils/_compat.py (18 changes: 17 additions & 1 deletion)
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Callable, ContextManager, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, List, Optional, Tuple, TypeVar
 
 import torch
 from packaging import version
@@ -52,3 +52,19 @@ def torch_meshgrid(tensors: List[Tensor], indexing: str):
 else:
     # TODO: remove this branch when kornia relies on torch >= 1.10.0
     torch_inference_mode = torch.no_grad
+
+if TYPE_CHECKING:  # TODO (@johnnv1): remove this branch when bump the pytorch CI to support torch 2.4
+    custom_fwd: Callable[..., Any]
+    autocast: Callable[..., Any]
+elif torch_version_ge(2, 4):
+    from functools import partial
+
+    from torch.amp import autocast as _autocast
+    from torch.amp import custom_fwd as _custom_fwd
+
+    custom_fwd = partial(_custom_fwd, device_type="cuda")
+    autocast = partial(_autocast, "cuda")
+
+else:
+    custom_fwd = torch.cuda.amp.custom_fwd
+    autocast = torch.cuda.amp.autocast
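
A brief usage sketch of the resulting shim (the normalize function below is illustrative, not kornia code): call sites keep the pre-2.4 call signatures, while _compat binds device_type="cuda" via functools.partial on torch >= 2.4 and falls back to the torch.cuda.amp originals on older releases.

import torch
from kornia.utils._compat import autocast, custom_fwd

# On torch >= 2.4 this expands to torch.amp.custom_fwd(device_type="cuda", cast_inputs=...);
# on older torch it is torch.cuda.amp.custom_fwd(cast_inputs=...).
@custom_fwd(cast_inputs=torch.float32)
def normalize(x: torch.Tensor) -> torch.Tensor:
    return x / x.abs().max().clamp(min=1e-6)

# Likewise, autocast(enabled=False) resolves to torch.amp.autocast("cuda", enabled=False)
# or torch.cuda.amp.autocast(enabled=False), depending on the installed torch.
with autocast(enabled=False):
    out = normalize(torch.randn(8))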
