From 902f979102f00da9cce72a59dee4b50a9327854a Mon Sep 17 00:00:00 2001 From: cloudhan Date: Fri, 3 Nov 2023 23:25:29 +0800 Subject: [PATCH] Add cuda_binary macro (#186) * Add macro for cuda_binary * Add example for cuda_binary and cuda_test macros --- cuda/defs.bzl | 5 +++++ cuda/private/macros/cuda_binary.bzl | 21 +++++++++++++++++++++ docs/user_docs.bzl | 4 +++- examples/basic_macros/BUILD.bazel | 13 +++++++++++++ examples/basic_macros/main.cu | 26 ++++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 cuda/private/macros/cuda_binary.bzl create mode 100644 examples/basic_macros/BUILD.bazel create mode 100644 examples/basic_macros/main.cu diff --git a/cuda/defs.bzl b/cuda/defs.bzl index 6419497d..e7a948f4 100644 --- a/cuda/defs.bzl +++ b/cuda/defs.bzl @@ -6,6 +6,7 @@ load("//cuda/private:providers.bzl", _CudaArchsInfo = "CudaArchsInfo", _cuda_arc load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows") load("//cuda/private:rules/cuda_objects.bzl", _cuda_objects = "cuda_objects") load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") +load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary") load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test") load("//cuda/private:rules/cuda_toolkit.bzl", _cuda_toolkit = "cuda_toolkit") load( @@ -29,8 +30,12 @@ cuda_toolchain_config_nvcc = _cuda_toolchain_config_nvcc cuda_archs = _cuda_archs CudaArchsInfo = _CudaArchsInfo +# rules cuda_objects = _cuda_objects cuda_library = _cuda_library + +# macros +cuda_binary = _cuda_binary cuda_test = _cuda_test if_linux = _if_linux diff --git a/cuda/private/macros/cuda_binary.bzl b/cuda/private/macros/cuda_binary.bzl new file mode 100644 index 00000000..20a158ce --- /dev/null +++ b/cuda/private/macros/cuda_binary.bzl @@ -0,0 +1,21 @@ +load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") + +def cuda_binary(name, **attrs): + """Wrapper to ensure the binary is compiled with the CUDA compiler.""" + cuda_library_only_attrs = ["deps", "srcs", "hdrs"] + + # https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-binaries + cc_binary_only_attrs = ["args", "env", "output_licenses"] + + cuda_library_name = "_" + name + + _cuda_library( + name = cuda_library_name, + **{k: v for k, v in attrs.items() if k not in cc_binary_only_attrs} + ) + + native.cc_binary( + name = name, + deps = [cuda_library_name], + **{k: v for k, v in attrs.items() if k not in cuda_library_only_attrs} + ) diff --git a/docs/user_docs.bzl b/docs/user_docs.bzl index fe181eee..e02bd9e8 100644 --- a/docs/user_docs.bzl +++ b/docs/user_docs.bzl @@ -1,9 +1,11 @@ -load("@rules_cuda//cuda:defs.bzl", _cuda_library = "cuda_library", _cuda_objects = "cuda_objects", _cuda_test = "cuda_test") +load("@rules_cuda//cuda:defs.bzl", _cuda_binary = "cuda_binary", _cuda_library = "cuda_library", _cuda_objects = "cuda_objects", _cuda_test = "cuda_test") load("@rules_cuda//cuda:repositories.bzl", _register_detected_cuda_toolchains = "register_detected_cuda_toolchains", _rules_cuda_dependencies = "rules_cuda_dependencies") load("@rules_cuda//cuda/private:rules/flags.bzl", _cuda_archs_flag = "cuda_archs_flag") cuda_library = _cuda_library cuda_objects = _cuda_objects + +cuda_binary = _cuda_binary cuda_test = _cuda_test cuda_archs = _cuda_archs_flag diff --git a/examples/basic_macros/BUILD.bazel b/examples/basic_macros/BUILD.bazel new file mode 100644 index 00000000..23b588e6 --- /dev/null +++ b/examples/basic_macros/BUILD.bazel @@ -0,0 +1,13 @@ +load("@rules_cuda//cuda:defs.bzl", "cuda_binary", "cuda_test") + +package(default_visibility = ["//visibility:public"]) + +cuda_binary( + name = "main", + srcs = ["main.cu"], +) + +cuda_test( + name = "test", + srcs = ["main.cu"], +) diff --git a/examples/basic_macros/main.cu b/examples/basic_macros/main.cu new file mode 100644 index 00000000..9812984d --- /dev/null +++ b/examples/basic_macros/main.cu @@ -0,0 +1,26 @@ +#include + +#define CUDA_CHECK(expr) \ + do { \ + cudaError_t err = (expr); \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA Error Code : %d\n Error String: %s\n", \ + err, cudaGetErrorString(err)); \ + exit(err); \ + } \ + } while (0) + +__global__ void kernel() { + printf("cuda kernel called!\n"); +} + +void launch() { + kernel<<<1, 1>>>(); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +int main() { + launch(); + return 0; +}