diff --git a/cuda/private/cuda_helper.bzl b/cuda/private/cuda_helper.bzl index b835d806..1d5855c4 100644 --- a/cuda/private/cuda_helper.bzl +++ b/cuda/private/cuda_helper.bzl @@ -3,6 +3,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("@bazel_skylib//lib:types.bzl", "types") load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") +load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") load("//cuda/private:action_names.bzl", "ACTION_NAMES") load("//cuda/private:artifact_categories.bzl", "ARTIFACT_CATEGORIES") load("//cuda/private:providers.bzl", "ArchSpecInfo", "CudaArchsInfo", "CudaInfo", "Stage2ArchInfo", "cuda_archs") @@ -242,6 +243,8 @@ def _create_common(ctx): merged_cc_info = cc_common.merge_cc_infos(cc_infos = [dep[CcInfo] for dep in all_cc_deps]) + cc_toolchain = find_cpp_toolchain(ctx) + # gather include info includes = merged_cc_info.compilation_context.includes.to_list() system_includes = [] @@ -276,6 +279,8 @@ def _create_common(ctx): host_defines = [] host_local_defines = [i for i in attr.host_local_defines] host_compile_flags = attr._default_host_copts[BuildSettingInfo].value + [i for i in attr.host_copts] + if cc_toolchain.sysroot: + host_compile_flags.append("--sysroot={}".format(cc_toolchain.sysroot)) host_link_flags = [] if hasattr(attr, "host_linkopts"): host_link_flags.extend([i for i in attr.host_linkopts])