Skip to content

Commit

Permalink
Added GPU acceleration support for Apple Silicon (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
NripeshN authored Aug 22, 2023
1 parent 4f5e239 commit 6388a80
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion nbs/homography.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
"error_tol: float = 1e-8 # the optimisation error tolerance\n",
"\n",
"log_interval: int = 100 # print log every N iterations\n",
"device = K.utils.get_cuda_device_if_available()\n",
"device = K.utils.get_cuda_or_mps_device_if_available()\n",
"print(\"Using \", device)"
]
},
Expand Down
2 changes: 1 addition & 1 deletion nbs/image_matching_adalam.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
"import torch\n",
"from kornia_moons.viz import *\n",
"\n",
"# device = K.utils.get_cuda_device_if_available()\n",
"# device = K.utils.get_cuda_or_mps_device_if_available()\n",
"device = torch.device(\"cpu\")"
]
},
Expand Down
3 changes: 1 addition & 2 deletions nbs/image_matching_disk.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@
"from kornia.feature.adalam import AdalamFilter\n",
"from kornia_moons.viz import *\n",
"\n",
"# device = torch.device(\"mps\") # it works also for Apple M1/M2\n",
"device = K.utils.get_cuda_device_if_available()\n",
"device = K.utils.get_cuda_or_mps_device_if_available()\n",
"print(device)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions nbs/image_prompter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
"from kornia.geometry.boxes import Boxes\n",
"from kornia.geometry.keypoints import Keypoints\n",
"from kornia.io import ImageLoadType, load_image\n",
"from kornia.utils import get_cuda_device_if_available, tensor_to_image"
"from kornia.utils import get_cuda_or_mps_device_if_available, tensor_to_image"
]
},
{
Expand All @@ -143,7 +143,7 @@
}
],
"source": [
"device = get_cuda_device_if_available()\n",
"device = get_cuda_or_mps_device_if_available()\n",
"print(device)"
]
},
Expand Down

0 comments on commit 6388a80

Please sign in to comment.