Skip to content

Commit

Permalink
Add cuda_binary macro (#186)
Browse files Browse the repository at this point in the history
* Add macro for cuda_binary

* Add example for cuda_binary and cuda_test macros
  • Loading branch information
cloudhan authored Nov 3, 2023
1 parent 5b9f263 commit 902f979
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 1 deletion.
5 changes: 5 additions & 0 deletions cuda/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ load("//cuda/private:providers.bzl", _CudaArchsInfo = "CudaArchsInfo", _cuda_arc
load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows")
load("//cuda/private:rules/cuda_objects.bzl", _cuda_objects = "cuda_objects")
load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library")
load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary")
load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test")
load("//cuda/private:rules/cuda_toolkit.bzl", _cuda_toolkit = "cuda_toolkit")
load(
Expand All @@ -29,8 +30,12 @@ cuda_toolchain_config_nvcc = _cuda_toolchain_config_nvcc
cuda_archs = _cuda_archs
CudaArchsInfo = _CudaArchsInfo

# rules
cuda_objects = _cuda_objects
cuda_library = _cuda_library

# macros
cuda_binary = _cuda_binary
cuda_test = _cuda_test

if_linux = _if_linux
Expand Down
21 changes: 21 additions & 0 deletions cuda/private/macros/cuda_binary.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library")

def cuda_binary(name, **attrs):
"""Wrapper to ensure the binary is compiled with the CUDA compiler."""
cuda_library_only_attrs = ["deps", "srcs", "hdrs"]

# https://bazel.build/reference/be/common-definitions?hl=en#common-attributes-binaries
cc_binary_only_attrs = ["args", "env", "output_licenses"]

cuda_library_name = "_" + name

_cuda_library(
name = cuda_library_name,
**{k: v for k, v in attrs.items() if k not in cc_binary_only_attrs}
)

native.cc_binary(
name = name,
deps = [cuda_library_name],
**{k: v for k, v in attrs.items() if k not in cuda_library_only_attrs}
)
4 changes: 3 additions & 1 deletion docs/user_docs.bzl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
load("@rules_cuda//cuda:defs.bzl", _cuda_library = "cuda_library", _cuda_objects = "cuda_objects", _cuda_test = "cuda_test")
load("@rules_cuda//cuda:defs.bzl", _cuda_binary = "cuda_binary", _cuda_library = "cuda_library", _cuda_objects = "cuda_objects", _cuda_test = "cuda_test")
load("@rules_cuda//cuda:repositories.bzl", _register_detected_cuda_toolchains = "register_detected_cuda_toolchains", _rules_cuda_dependencies = "rules_cuda_dependencies")
load("@rules_cuda//cuda/private:rules/flags.bzl", _cuda_archs_flag = "cuda_archs_flag")

cuda_library = _cuda_library
cuda_objects = _cuda_objects

cuda_binary = _cuda_binary
cuda_test = _cuda_test

cuda_archs = _cuda_archs_flag
Expand Down
13 changes: 13 additions & 0 deletions examples/basic_macros/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
load("@rules_cuda//cuda:defs.bzl", "cuda_binary", "cuda_test")

package(default_visibility = ["//visibility:public"])

cuda_binary(
name = "main",
srcs = ["main.cu"],
)

cuda_test(
name = "test",
srcs = ["main.cu"],
)
26 changes: 26 additions & 0 deletions examples/basic_macros/main.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include <cstdio>

#define CUDA_CHECK(expr) \
do { \
cudaError_t err = (expr); \
if (err != cudaSuccess) { \
fprintf(stderr, "CUDA Error Code : %d\n Error String: %s\n", \
err, cudaGetErrorString(err)); \
exit(err); \
} \
} while (0)

__global__ void kernel() {
printf("cuda kernel called!\n");
}

void launch() {
kernel<<<1, 1>>>();
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}

int main() {
launch();
return 0;
}

0 comments on commit 902f979

Please sign in to comment.