Skip to content

Commit

Permalink
Fix cquery-ing with cuda targets
Browse files Browse the repository at this point in the history
  • Loading branch information
mvukov committed Jan 3, 2024
1 parent c595fd1 commit 4d86e4c
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cuda/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ load(
_use_cuda_toolchain = "use_cuda_toolchain",
)
load("//cuda/private:toolchain_configs/clang.bzl", _cuda_toolchain_config_clang = "cuda_toolchain_config")
load("//cuda/private:toolchain_configs/dummy.bzl", _cuda_toolchain_config_dummy = "dummy_toolchain_config")
load("//cuda/private:toolchain_configs/nvcc.bzl", _cuda_toolchain_config_nvcc = "cuda_toolchain_config")
load("//cuda/private:toolchain_configs/nvcc_msvc.bzl", _cuda_toolchain_config_nvcc_msvc = "cuda_toolchain_config")

Expand All @@ -24,6 +25,7 @@ cuda_toolchain = _cuda_toolchain
find_cuda_toolchain = _find_cuda_toolchain
use_cuda_toolchain = _use_cuda_toolchain
cuda_toolchain_config_clang = _cuda_toolchain_config_clang
cuda_toolchain_config_dummy = _cuda_toolchain_config_dummy
cuda_toolchain_config_nvcc_msvc = _cuda_toolchain_config_nvcc_msvc
cuda_toolchain_config_nvcc = _cuda_toolchain_config_nvcc

Expand Down
6 changes: 6 additions & 0 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,19 @@ def config_clang(repository_ctx, cuda, clang_path):
}
repository_ctx.template("toolchain/clang/BUILD", tpl_label, substitutions = substitutions, executable = False)

def config_none(repository_ctx):
tpl_label = Label("//cuda:templates/BUILD.local_toolchain_none")
repository_ctx.template("toolchain/none/BUILD", tpl_label, executable = False)

def _local_cuda_impl(repository_ctx):
cuda = detect_cuda_toolkit(repository_ctx)
config_cuda_toolkit_and_nvcc(repository_ctx, cuda)

clang_path = detect_clang(repository_ctx)
config_clang(repository_ctx, cuda, clang_path)

config_none(repository_ctx)

local_cuda = repository_rule(
implementation = _local_cuda_impl,
attrs = {"toolkit_path": attr.string(mandatory = False)},
Expand Down
1 change: 1 addition & 0 deletions cuda/private/toolchain.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ User can setup their own toolchain if needed and ignore the detected ones by not
native.register_toolchains(
"@local_cuda//toolchain:nvcc-local-toolchain",
"@local_cuda//toolchain/clang:clang-local-toolchain",
"@local_cuda//toolchain/none:none-local-toolchain",
)
4 changes: 4 additions & 0 deletions cuda/private/toolchain_configs/dummy.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def _dummy_toolchain_config_impl(_ctx):
return [platform_common.ToolchainInfo()]

dummy_toolchain_config = rule(_dummy_toolchain_config_impl, attrs = {})
16 changes: 16 additions & 0 deletions cuda/templates/BUILD.local_toolchain_none
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
load("@rules_cuda//cuda:defs.bzl", "cuda_toolchain_config_dummy")

config_setting(
name = "cuda_is_disabled",
flag_values = {"@rules_cuda//cuda:enable": "False"},
)

cuda_toolchain_config_dummy(name = "dummy-local")

toolchain(
name = "none-local-toolchain",
target_settings = [":cuda_is_disabled"],
toolchain = ":dummy-local",
toolchain_type = "@rules_cuda//cuda:toolchain_type",
visibility = ["//visibility:public"],
)
4 changes: 4 additions & 0 deletions examples/if_cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ cuda_library(
name = "kernel",
srcs = ["kernel.cu"],
hdrs = ["kernel.h"],
target_compatible_with = select({
"@rules_cuda//cuda:is_enabled": [],
"//conditions:default": ["@platforms//:incompatible"],
}),
)

cc_binary(
Expand Down

0 comments on commit 4d86e4c

Please sign in to comment.