diff --git a/MODULE.bazel b/MODULE.bazel index 7eaba5b8..07c72ff8 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -8,6 +8,7 @@ bazel_dep(name = "bazel_skylib", version = "1.4.2") bazel_dep(name = "platforms", version = "0.0.6") cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") +cuda.local_toolchain(name = "local_cuda", toolkit_path = "") use_repo(cuda, "local_cuda") register_toolchains( diff --git a/cuda/extensions.bzl b/cuda/extensions.bzl index b6014dd8..e320ae7b 100644 --- a/cuda/extensions.bzl +++ b/cuda/extensions.bzl @@ -7,19 +7,35 @@ cuda_toolkit = tag_class(attrs = { "toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."), }) +def _find_modules(module_ctx): + root = None + our_module = None + for mod in module_ctx.modules: + if mod.is_root: + root = mod + if mod.name == "rules_cuda": + our_module = mod + if root == None: + root = our_module + if our_module == None: + fail("Unable to find rules_cuda module") + + return root, our_module + def _init(module_ctx): + # Toolchain configuration is only allowed in the root module, or in rules_cuda. + root, rules_cuda = _find_modules(module_ctx) + toolchains = root.tags.local_toolchain or rules_cuda.tags.local_toolchain + registrations = {} - for mod in module_ctx.modules: - for toolchain in mod.tags.local_toolchain: - if not mod.is_root: - fail("Only the root module may override the path for the local cuda toolchain") - if toolchain.name in registrations.keys(): - if toolchain.toolkit_path == registrations[toolchain.name]: - # No problem to register a matching toolchain twice - continue - fail("Multiple conflicting toolchains declared for name {} ({} and {}".format(toolchain.name, toolchain.toolkit_path, registrations[toolchain.name])) - else: - registrations[toolchain.name] = toolchain.toolkit_path + for toolchain in toolchains: + if toolchain.name in registrations.keys(): + if toolchain.toolkit_path == registrations[toolchain.name]: + # No problem to register a matching toolchain twice + continue + fail("Multiple conflicting toolchains declared for name {} ({} and {}".format(toolchain.name, toolchain.toolkit_path, registrations[toolchain.name])) + else: + registrations[toolchain.name] = toolchain.toolkit_path for name, toolkit_path in registrations.items(): local_cuda(name = name, toolkit_path = toolkit_path)