Skip to content

Commit

Permalink
Make the change more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Jan 7, 2024
1 parent 29aa9aa commit d15b965
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
17 changes: 11 additions & 6 deletions cuda/private/macros/cuda_binary.bzl
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library")

def cuda_binary(name, alwayslink = True, **attrs):
def cuda_binary(name, **attrs):
"""A macro wraps cuda_library and cc_binary to ensure the binary is compiled with the CUDA compiler.
Args:
name: A unique name for this target (cc_binary).
alwayslink: pass to the hidden cuda_library target.
**attrs: attrs of cc_binary and cuda_library.
"""
cuda_library_only_attrs = ["deps", "srcs", "hdrs"]
cuda_library_only_attrs = ["deps", "srcs", "hdrs", "alwayslink"]
cuda_library_only_attrs_defaults = {
"alwayslink": True,
}

# https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-binaries
cc_binary_only_attrs = ["args", "env", "output_licenses"]

cuda_library_name = "_" + name
cuda_library_attrs = {k: v for k, v in attrs.items() if k not in cc_binary_only_attrs}
for attr in cuda_library_only_attrs_defaults:
if attr not in cuda_library_attrs:
cuda_library_attrs[attr] = cuda_library_only_attrs_defaults[attr]

cuda_library_name = "_" + name
_cuda_library(
name = cuda_library_name,
alwayslink = alwayslink,
**{k: v for k, v in attrs.items() if k not in cc_binary_only_attrs}
**cuda_library_attrs
)

native.cc_binary(
Expand Down
19 changes: 12 additions & 7 deletions cuda/private/macros/cuda_test.bzl
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library")

def cuda_test(name, alwayslink = True, **attrs):
def cuda_test(name, **attrs):
"""A macro wraps cuda_library and cc_test to ensure the test is compiled with the CUDA compiler.
Args:
name: A unique name for this target (cc_test).
alwayslink: pass to the hidden cuda_library target.
**attrs: attrs of cc_test and cuda_library.
"""
cuda_library_only_attrs = ["deps", "srcs", "hdrs"]
cuda_library_only_attrs = ["deps", "srcs", "hdrs", "testonly", "alwayslink"]
cuda_library_only_attrs_defaults = {
"testonly": True,
"alwayslink": True,
}

# https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-tests
cc_test_only_attrs = ["args", "env", "env_inherit", "size", "timeout", "flaky", "shard_count", "local"]

cuda_library_name = "_" + name
cuda_library_attrs = {k: v for k, v in attrs.items() if k not in cc_test_only_attrs}
for attr in cuda_library_only_attrs_defaults:
if attr not in cuda_library_attrs:
cuda_library_attrs[attr] = cuda_library_only_attrs_defaults[attr]

cuda_library_name = "_" + name
_cuda_library(
name = cuda_library_name,
alwayslink = alwayslink,
testonly = True,
**{k: v for k, v in attrs.items() if k not in cc_test_only_attrs}
**cuda_library_attrs,
)

native.cc_test(
Expand Down

0 comments on commit d15b965

Please sign in to comment.