-
Notifications
You must be signed in to change notification settings - Fork 278
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move Torch-TRT install to file (#2092)
Summary: - Add custom installer support for userbenchmark testing - Add support for installing Torch-TRT outside of main container installs for nightly runs - Add necessary hooks and subprocess commands in code Pull Request resolved: #2092 Reviewed By: aaronenyeshi Differential Revision: D52266771 Pulled By: xuzhao9 fbshipit-source-id: 369009a5bd1d5681a8aa7f72f736592da952ffda
- Loading branch information
1 parent
b599ae4
commit 11376ae
Showing
3 changed files
with
54 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import subprocess | ||
|
||
import torch | ||
|
||
|
||
def install_torch_tensorrt(): | ||
# Install Torch-TensorRT with validation | ||
uninstall_torchtrt_cmd = ["pip", "uninstall", "-y", "torch_tensorrt"] | ||
subprocess.check_call(uninstall_torchtrt_cmd) | ||
|
||
if torch.version.cuda.startswith("12"): | ||
cuda_index_modifier = "cu121" | ||
elif torch.version.cuda.startswith("11"): | ||
cuda_index_modifier = "cu118" | ||
else: | ||
raise AssertionError( | ||
f"Detected Torch-TRT unsupported CUDA version {torch.version.cuda}" | ||
) | ||
|
||
pytorch_nightly_url = ( | ||
f"https://download.pytorch.org/whl/nightly/{cuda_index_modifier}" | ||
) | ||
install_torchtrt_cmd = [ | ||
"pip", | ||
"install", | ||
"--pre", | ||
"--no-cache-dir", | ||
"torch_tensorrt", | ||
"--extra-index-url", | ||
pytorch_nightly_url, | ||
] | ||
validate_torchtrt_cmd = ["python", "-c", "'import torch_tensorrt'"] | ||
|
||
# Install and validate Torch-TensorRT | ||
try: | ||
subprocess.check_call(install_torchtrt_cmd) | ||
subprocess.check_call(validate_torchtrt_cmd) | ||
except subprocess.CalledProcessError: | ||
subprocess.check_call(uninstall_torchtrt_cmd) | ||
print("Failed to install torch-tensorrt, skipping install") | ||
|
||
|
||
if __name__ == "__main__": | ||
install_torch_tensorrt() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters