Skip to content

Commit

Permalink
fix(bzlmod): allow both root module and our module to call cuda.local…
Browse files Browse the repository at this point in the history
…_toolchain
  • Loading branch information
cloudhan committed Aug 7, 2024
1 parent b82e0d9 commit b5f4392
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
1 change: 1 addition & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ bazel_dep(name = "bazel_skylib", version = "1.4.2")
bazel_dep(name = "platforms", version = "0.0.6")

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(name = "local_cuda", toolkit_path = "")
use_repo(cuda, "local_cuda")

register_toolchains(
Expand Down
38 changes: 27 additions & 11 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,35 @@ cuda_toolkit = tag_class(attrs = {
"toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."),
})

def _find_modules(module_ctx):
root = None
our_module = None
for mod in module_ctx.modules:
if mod.is_root:
root = mod
if mod.name == "rules_cuda":
our_module = mod
if root == None:
root = our_module
if our_module == None:
fail("Unable to find rules_cuda module")

return root, our_module

def _init(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
root, rules_cuda = _find_modules(module_ctx)
toolchains = root.tags.local_toolchain or rules_cuda.tags.local_toolchain

registrations = {}
for mod in module_ctx.modules:
for toolchain in mod.tags.local_toolchain:
if not mod.is_root:
fail("Only the root module may override the path for the local cuda toolchain")
if toolchain.name in registrations.keys():
if toolchain.toolkit_path == registrations[toolchain.name]:
# No problem to register a matching toolchain twice
continue
fail("Multiple conflicting toolchains declared for name {} ({} and {}".format(toolchain.name, toolchain.toolkit_path, registrations[toolchain.name]))
else:
registrations[toolchain.name] = toolchain.toolkit_path
for toolchain in toolchains:
if toolchain.name in registrations.keys():
if toolchain.toolkit_path == registrations[toolchain.name]:
# No problem to register a matching toolchain twice
continue
fail("Multiple conflicting toolchains declared for name {} ({} and {}".format(toolchain.name, toolchain.toolkit_path, registrations[toolchain.name]))
else:
registrations[toolchain.name] = toolchain.toolkit_path
for name, toolkit_path in registrations.items():
local_cuda(name = name, toolkit_path = toolkit_path)

Expand Down

0 comments on commit b5f4392

Please sign in to comment.