diff --git a/cuda/private/macros/cuda_binary.bzl b/cuda/private/macros/cuda_binary.bzl index 2fb1c024..91911b2f 100644 --- a/cuda/private/macros/cuda_binary.bzl +++ b/cuda/private/macros/cuda_binary.bzl @@ -1,24 +1,29 @@ load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") -def cuda_binary(name, alwayslink = True, **attrs): +def cuda_binary(name, **attrs): """A macro wraps cuda_library and cc_binary to ensure the binary is compiled with the CUDA compiler. Args: name: A unique name for this target (cc_binary). - alwayslink: pass to the hidden cuda_library target. **attrs: attrs of cc_binary and cuda_library. """ - cuda_library_only_attrs = ["deps", "srcs", "hdrs"] + cuda_library_only_attrs = ["deps", "srcs", "hdrs", "alwayslink"] + cuda_library_only_attrs_defaults = { + "alwayslink": True, + } # https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-binaries cc_binary_only_attrs = ["args", "env", "output_licenses"] - cuda_library_name = "_" + name + cuda_library_attrs = {k: v for k, v in attrs.items() if k not in cc_binary_only_attrs} + for attr in cuda_library_only_attrs_defaults: + if attr not in cuda_library_attrs: + cuda_library_attrs[attr] = cuda_library_only_attrs_defaults[attr] + cuda_library_name = "_" + name _cuda_library( name = cuda_library_name, - alwayslink = alwayslink, - **{k: v for k, v in attrs.items() if k not in cc_binary_only_attrs} + **cuda_library_attrs ) native.cc_binary( diff --git a/cuda/private/macros/cuda_test.bzl b/cuda/private/macros/cuda_test.bzl index caa8e3f6..cc5255f5 100644 --- a/cuda/private/macros/cuda_test.bzl +++ b/cuda/private/macros/cuda_test.bzl @@ -1,25 +1,30 @@ load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") -def cuda_test(name, alwayslink = True, **attrs): +def cuda_test(name, **attrs): """A macro wraps cuda_library and cc_test to ensure the test is compiled with the CUDA compiler. Args: name: A unique name for this target (cc_test). - alwayslink: pass to the hidden cuda_library target. **attrs: attrs of cc_test and cuda_library. """ - cuda_library_only_attrs = ["deps", "srcs", "hdrs"] + cuda_library_only_attrs = ["deps", "srcs", "hdrs", "testonly", "alwayslink"] + cuda_library_only_attrs_defaults = { + "testonly": True, + "alwayslink": True, + } # https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-tests cc_test_only_attrs = ["args", "env", "env_inherit", "size", "timeout", "flaky", "shard_count", "local"] - cuda_library_name = "_" + name + cuda_library_attrs = {k: v for k, v in attrs.items() if k not in cc_test_only_attrs} + for attr in cuda_library_only_attrs_defaults: + if attr not in cuda_library_attrs: + cuda_library_attrs[attr] = cuda_library_only_attrs_defaults[attr] + cuda_library_name = "_" + name _cuda_library( name = cuda_library_name, - alwayslink = alwayslink, - testonly = True, - **{k: v for k, v in attrs.items() if k not in cc_test_only_attrs} + **cuda_library_attrs, ) native.cc_test(