From 177fff98736d82033a66c43fdd986462586277cf Mon Sep 17 00:00:00 2001 From: Cloud Han Date: Wed, 19 Jul 2023 21:18:40 +0800 Subject: [PATCH] Fix _wrapper_device_link artifact name conflict --- cuda/private/actions/compile.bzl | 10 ++++++---- cuda/private/actions/dlink.bzl | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cuda/private/actions/compile.bzl b/cuda/private/actions/compile.bzl index 644dbc39..e630ebab 100644 --- a/cuda/private/actions/compile.bzl +++ b/cuda/private/actions/compile.bzl @@ -9,7 +9,8 @@ def compile( srcs, common, pic = False, - rdc = False): + rdc = False, + _prefix = "_objs"): """Perform CUDA compilation, return compiled object files. Notes: @@ -26,6 +27,7 @@ def compile( common: A cuda common object. Can be obtained with `cuda_helper.create_common(ctx)` pic: Whether the `srcs` are compiled for position independent code. rdc: Whether the `srcs` are compiled for relocatable device code. + _prefix: DON'T USE IT! Prefix of the output dir. Exposed for device link to redirect the output. Returns: An compiled object `File`. @@ -53,13 +55,13 @@ def compile( filename = None filename = cuda_helper.get_artifact_name(cuda_toolchain, artifact_category_name, basename) - # Objects are placed in _objs//. + # Objects are placed in <_prefix>//. # For files with the same basename, say srcs = ["kernel.cu", "foo/kernel.cu", "bar/kernel.cu"], we get - # _objs//0/kernel., _objs//1/kernel., _objs//2/kernel.. + # <_prefix>//0/kernel., <_prefix>//1/kernel., <_prefix>//2/kernel.. # Otherwise, the index is not presented. if basename_counter[basename] > 1: filename = "{}/{}".format(basename_index, filename) - obj_file = actions.declare_file("_objs/{}/{}".format(ctx.attr.name, filename)) + obj_file = actions.declare_file("{}/{}/{}".format(_prefix, ctx.attr.name, filename)) ret.append(obj_file) var = cuda_helper.create_compile_variables( diff --git a/cuda/private/actions/dlink.bzl b/cuda/private/actions/dlink.bzl index f6175c04..70a54651 100644 --- a/cuda/private/actions/dlink.bzl +++ b/cuda/private/actions/dlink.bzl @@ -203,5 +203,5 @@ def _wrapper_device_link( # suppress cuda mode as c++ mode host_compile_flags = common.host_compile_flags + ["-x", "c++"], ) - ret = compile(ctx, cuda_toolchain, cc_toolchain, srcs = [fatbin_c], common = compile_common, pic = pic, rdc = rdc) + ret = compile(ctx, cuda_toolchain, cc_toolchain, srcs = [fatbin_c], common = compile_common, pic = pic, rdc = rdc, _prefix = "_objs/_dlink") return ret[0]