diff --git a/cuda/private/macros/cuda_binary.bzl b/cuda/private/macros/cuda_binary.bzl index 20a158ce..91911b2f 100644 --- a/cuda/private/macros/cuda_binary.bzl +++ b/cuda/private/macros/cuda_binary.bzl @@ -1,17 +1,29 @@ load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") def cuda_binary(name, **attrs): - """Wrapper to ensure the binary is compiled with the CUDA compiler.""" - cuda_library_only_attrs = ["deps", "srcs", "hdrs"] + """A macro wraps cuda_library and cc_binary to ensure the binary is compiled with the CUDA compiler. + + Args: + name: A unique name for this target (cc_binary). + **attrs: attrs of cc_binary and cuda_library. + """ + cuda_library_only_attrs = ["deps", "srcs", "hdrs", "alwayslink"] + cuda_library_only_attrs_defaults = { + "alwayslink": True, + } # https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-binaries cc_binary_only_attrs = ["args", "env", "output_licenses"] - cuda_library_name = "_" + name + cuda_library_attrs = {k: v for k, v in attrs.items() if k not in cc_binary_only_attrs} + for attr in cuda_library_only_attrs_defaults: + if attr not in cuda_library_attrs: + cuda_library_attrs[attr] = cuda_library_only_attrs_defaults[attr] + cuda_library_name = "_" + name _cuda_library( name = cuda_library_name, - **{k: v for k, v in attrs.items() if k not in cc_binary_only_attrs} + **cuda_library_attrs ) native.cc_binary( diff --git a/cuda/private/macros/cuda_test.bzl b/cuda/private/macros/cuda_test.bzl index b69060be..99442fea 100644 --- a/cuda/private/macros/cuda_test.bzl +++ b/cuda/private/macros/cuda_test.bzl @@ -1,18 +1,30 @@ load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") def cuda_test(name, **attrs): - """Wrapper to ensure the test is compiled with the CUDA compiler.""" - cuda_library_only_attrs = ["deps", "srcs", "hdrs"] + """A macro wraps cuda_library and cc_test to ensure the test is compiled with the CUDA compiler. + + Args: + name: A unique name for this target (cc_test). + **attrs: attrs of cc_test and cuda_library. + """ + cuda_library_only_attrs = ["deps", "srcs", "hdrs", "testonly", "alwayslink"] + cuda_library_only_attrs_defaults = { + "testonly": True, + "alwayslink": True, + } # https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-tests cc_test_only_attrs = ["args", "env", "env_inherit", "size", "timeout", "flaky", "shard_count", "local"] - cuda_library_name = "_" + name + cuda_library_attrs = {k: v for k, v in attrs.items() if k not in cc_test_only_attrs} + for attr in cuda_library_only_attrs_defaults: + if attr not in cuda_library_attrs: + cuda_library_attrs[attr] = cuda_library_only_attrs_defaults[attr] + cuda_library_name = "_" + name _cuda_library( name = cuda_library_name, - testonly = True, - **{k: v for k, v in attrs.items() if k not in cc_test_only_attrs} + **cuda_library_attrs ) native.cc_test(