Skip to content

Commit

Permalink
Support toolkit_path in rules_cuda_dependencies (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
garymm authored Jul 6, 2023
1 parent 22578d7 commit 33c3843
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def detect_cuda_toolkit(repository_ctx):

# Some distributions instead put CUDA binaries in a seperate path
# Manually check and redirect there when necessary
alternative = repository_ctx.path('/usr/lib/nvidia-cuda-toolkit/bin/nvcc')
alternative = repository_ctx.path("/usr/lib/nvidia-cuda-toolkit/bin/nvcc")
if str(ptxas_path) == "/usr/bin/ptxas" and alternative.exists:
ptxas_path = alternative
cuda_path = str(ptxas_path.dirname.dirname)
Expand Down Expand Up @@ -206,8 +206,12 @@ local_cuda = repository_rule(
# remotable = True,
)

def rules_cuda_dependencies():
"""Populate the dependencies for rules_cuda. This will setup workspace dependencies (other bazel rules) and local toolchains."""
def rules_cuda_dependencies(toolkit_path = None):
"""Populate the dependencies for rules_cuda. This will setup workspace dependencies (other bazel rules) and local toolchains.
Args:
toolkit_path: Optionally specify the path to CUDA toolkit. If not specified, it will be detected automatically.
"""
maybe(
name = "bazel_skylib",
repo_rule = http_archive,
Expand All @@ -228,4 +232,4 @@ def rules_cuda_dependencies():
],
)

local_cuda(name = "local_cuda")
local_cuda(name = "local_cuda", toolkit_path = toolkit_path)

0 comments on commit 33c3843

Please sign in to comment.