diff --git a/cuda/defs.bzl b/cuda/defs.bzl index e7a948f4..80f3b7df 100644 --- a/cuda/defs.bzl +++ b/cuda/defs.bzl @@ -16,6 +16,7 @@ load( _use_cuda_toolchain = "use_cuda_toolchain", ) load("//cuda/private:toolchain_configs/clang.bzl", _cuda_toolchain_config_clang = "cuda_toolchain_config") +load("//cuda/private:toolchain_configs/dummy.bzl", _cuda_toolchain_config_dummy = "dummy_toolchain_config") load("//cuda/private:toolchain_configs/nvcc.bzl", _cuda_toolchain_config_nvcc = "cuda_toolchain_config") load("//cuda/private:toolchain_configs/nvcc_msvc.bzl", _cuda_toolchain_config_nvcc_msvc = "cuda_toolchain_config") @@ -24,6 +25,7 @@ cuda_toolchain = _cuda_toolchain find_cuda_toolchain = _find_cuda_toolchain use_cuda_toolchain = _use_cuda_toolchain cuda_toolchain_config_clang = _cuda_toolchain_config_clang +cuda_toolchain_config_dummy = _cuda_toolchain_config_dummy cuda_toolchain_config_nvcc_msvc = _cuda_toolchain_config_nvcc_msvc cuda_toolchain_config_nvcc = _cuda_toolchain_config_nvcc diff --git a/cuda/private/repositories.bzl b/cuda/private/repositories.bzl index 2dd95e3b..8720b887 100644 --- a/cuda/private/repositories.bzl +++ b/cuda/private/repositories.bzl @@ -190,6 +190,10 @@ def config_clang(repository_ctx, cuda, clang_path): } repository_ctx.template("toolchain/clang/BUILD", tpl_label, substitutions = substitutions, executable = False) +def config_none(repository_ctx): + tpl_label = Label("//cuda:templates/BUILD.local_toolchain_none") + repository_ctx.template("toolchain/none/BUILD", tpl_label, executable = False) + def _local_cuda_impl(repository_ctx): cuda = detect_cuda_toolkit(repository_ctx) config_cuda_toolkit_and_nvcc(repository_ctx, cuda) @@ -197,6 +201,8 @@ def _local_cuda_impl(repository_ctx): clang_path = detect_clang(repository_ctx) config_clang(repository_ctx, cuda, clang_path) + config_none(repository_ctx) + local_cuda = repository_rule( implementation = _local_cuda_impl, attrs = {"toolkit_path": attr.string(mandatory = False)}, diff --git a/cuda/private/toolchain.bzl b/cuda/private/toolchain.bzl index 132d3f22..89948852 100644 --- a/cuda/private/toolchain.bzl +++ b/cuda/private/toolchain.bzl @@ -86,4 +86,5 @@ User can setup their own toolchain if needed and ignore the detected ones by not native.register_toolchains( "@local_cuda//toolchain:nvcc-local-toolchain", "@local_cuda//toolchain/clang:clang-local-toolchain", + "@local_cuda//toolchain/none:none-local-toolchain", ) diff --git a/cuda/private/toolchain_configs/dummy.bzl b/cuda/private/toolchain_configs/dummy.bzl new file mode 100644 index 00000000..d2659230 --- /dev/null +++ b/cuda/private/toolchain_configs/dummy.bzl @@ -0,0 +1,4 @@ +def _dummy_toolchain_config_impl(_ctx): + return [platform_common.ToolchainInfo()] + +dummy_toolchain_config = rule(_dummy_toolchain_config_impl, attrs = {}) diff --git a/cuda/templates/BUILD.local_toolchain_none b/cuda/templates/BUILD.local_toolchain_none new file mode 100644 index 00000000..1948210c --- /dev/null +++ b/cuda/templates/BUILD.local_toolchain_none @@ -0,0 +1,16 @@ +load("@rules_cuda//cuda:defs.bzl", "cuda_toolchain_config_dummy") + +config_setting( + name = "cuda_is_disabled", + flag_values = {"@rules_cuda//cuda:enable": "False"}, +) + +cuda_toolchain_config_dummy(name = "dummy-local") + +toolchain( + name = "none-local-toolchain", + target_settings = [":cuda_is_disabled"], + toolchain = ":dummy-local", + toolchain_type = "@rules_cuda//cuda:toolchain_type", + visibility = ["//visibility:public"], +) diff --git a/examples/if_cuda/BUILD.bazel b/examples/if_cuda/BUILD.bazel index bbe313fb..923ec9e9 100644 --- a/examples/if_cuda/BUILD.bazel +++ b/examples/if_cuda/BUILD.bazel @@ -4,6 +4,10 @@ cuda_library( name = "kernel", srcs = ["kernel.cu"], hdrs = ["kernel.h"], + target_compatible_with = select({ + "@rules_cuda//cuda:is_enabled": [], + "//conditions:default": ["@platforms//:incompatible"], + }), ) cc_binary(