
Commit

feat(cli): Support GPU acceleration via ONNX Runtime DirectML on Windows (#176)

* feat(cli): Support GPU acceleration via ONNX Runtime DirectML on Windows

This commit adds support for GPU acceleration in the CLI using ONNX Runtime DirectML, allowing users with compatible GPUs on Windows to leverage their hardware for faster inference (a short provider-selection sketch is included after the changed-file summary below).

* fix small bug
ccddos authored May 12, 2024
1 parent 24c9fe5 commit 01430ed
Showing 2 changed files with 27 additions and 6 deletions.
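
For readers unfamiliar with how ONNX Runtime chooses a backend, here is a minimal sketch (not the project's code, and the model path is a placeholder): InferenceSession tries execution providers in the order given, so putting DmlExecutionProvider first makes DirectML the primary backend with CPUExecutionProvider as the fallback, which is the mechanism this commit relies on.

import os

import onnxruntime as ort

# Build the provider list in priority order; ONNX Runtime tries them left to right.
providers = []
if os.name == "nt" and "DmlExecutionProvider" in ort.get_available_providers():
    # DirectML is only reported when the onnxruntime-directml package is installed.
    providers.append("DmlExecutionProvider")
providers.append("CPUExecutionProvider")

session = ort.InferenceSession("model.onnx", providers=providers)  # "model.onnx" is a placeholder path
print(session.get_providers())  # effective providers, in the order ONNX Runtime will use them

If DirectML cannot be initialized, ONNX Runtime falls back to the next provider in the list; the warning added in utils.py below checks get_providers()[0] to detect exactly that case.
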
29 changes: 24 additions & 5 deletions python/rapidocr_onnxruntime/utils.py
@@ -57,18 +57,37 @@ def __init__(self, config):
         }

         EP_list = []
-        if (
-            config["use_cuda"]
-            and get_device() == "GPU"
-            and cuda_ep in get_available_providers()
-        ):
+        is_use_cuda = config["use_cuda"] and get_device() == "GPU" and cuda_ep in get_available_providers()
+        if is_use_cuda:
             EP_list = [(cuda_ep, cuda_provider_options)]
         EP_list.append((cpu_ep, cpu_provider_options))

+        # If the platform is Windows, try DirectML as the primary provider.
+        if os.name == "nt":
+            directml_ep = "DmlExecutionProvider"
+            # print(get_available_providers())
+            if directml_ep in get_available_providers():
+                print("Windows platform detected, trying to use DirectML as the primary provider")
+                EP_list.insert(0, (directml_ep,
+                    cuda_provider_options if is_use_cuda else cpu_provider_options
+                ))
+
+
         self._verify_model(config["model_path"])
         self.session = InferenceSession(
             config["model_path"], sess_options=sess_opt, providers=EP_list
         )

+        # TODO: verify this is correct for detecting current_provider
+        current_provider = self.session.get_providers()[0]
+
+        # Verify that the DirectML provider is actually in use.
+        if os.name == "nt":
+            if current_provider != directml_ep:
+                warnings.warn(
+                    "DirectML is not available in the current environment; inference will fall back to another execution provider.\n"
+                )
+
+
if config["use_cuda"] and cuda_ep not in self.session.get_providers():
warnings.warn(
4 changes: 3 additions & 1 deletion python/requirements_ort.txt
@@ -1,8 +1,10 @@
 pyclipper>=1.2.0
-onnxruntime>=1.7.0
 opencv_python>=4.5.1.48
 numpy>=1.19.5
 six>=1.15.0
 Shapely>=1.7.1
 PyYAML
 Pillow
+# Install onnxruntime-directml on Windows. It conflicts with onnxruntime, so only one of the two can be installed.
+onnxruntime-directml;platform_system=="Windows"
+onnxruntime>=1.7.0;platform_system!="Windows"
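
Since both packages import as the same onnxruntime module, a quick runtime check (a sketch, not part of this change) can confirm whether the DirectML-capable build is the one installed:

import onnxruntime as ort

print(ort.__version__)
# True only when onnxruntime-directml (or another DirectML-enabled build) is installed.
print("DmlExecutionProvider" in ort.get_available_providers())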
