diff --git a/cuda/private/repositories.bzl b/cuda/private/repositories.bzl index 9a29f4c7..27265722 100644 --- a/cuda/private/repositories.bzl +++ b/cuda/private/repositories.bzl @@ -50,7 +50,7 @@ def detect_cuda_toolkit(repository_ctx): # Some distributions instead put CUDA binaries in a seperate path # Manually check and redirect there when necessary - alternative = repository_ctx.path('/usr/lib/nvidia-cuda-toolkit/bin/nvcc') + alternative = repository_ctx.path("/usr/lib/nvidia-cuda-toolkit/bin/nvcc") if str(ptxas_path) == "/usr/bin/ptxas" and alternative.exists: ptxas_path = alternative cuda_path = str(ptxas_path.dirname.dirname) @@ -206,8 +206,12 @@ local_cuda = repository_rule( # remotable = True, ) -def rules_cuda_dependencies(): - """Populate the dependencies for rules_cuda. This will setup workspace dependencies (other bazel rules) and local toolchains.""" +def rules_cuda_dependencies(toolkit_path = None): + """Populate the dependencies for rules_cuda. This will setup workspace dependencies (other bazel rules) and local toolchains. + + Args: + toolkit_path: Optionally specify the path to CUDA toolkit. If not specified, it will be detected automatically. + """ maybe( name = "bazel_skylib", repo_rule = http_archive, @@ -228,4 +232,4 @@ def rules_cuda_dependencies(): ], ) - local_cuda(name = "local_cuda") + local_cuda(name = "local_cuda", toolkit_path = toolkit_path)