Skip to content

Commit

Permalink
Merge pull request #2415 from AllanZyne/review/yang/fix_metadata_assert
Browse files Browse the repository at this point in the history
[DeviceASAN] Fix ASAN with kernel assert
  • Loading branch information
martygrant authored Dec 18, 2024
2 parents c45de9a + 05f94a8 commit d18d523
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 58 deletions.
1 change: 1 addition & 0 deletions source/loader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ if(UR_ENABLE_SANITIZER)
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_buffer.hpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_ddi.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_ddi.hpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_interceptor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_interceptor.hpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_libdevice.hpp
Expand Down
47 changes: 13 additions & 34 deletions source/loader/layers/sanitizer/asan/asan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
return UR_RESULT_SUCCESS;
}

bool isInstrumentedKernel(ur_kernel_handle_t hKernel) {
auto hProgram = GetProgram(hKernel);
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
if (PI == nullptr) {
return false;
}
return PI->isKernelInstrumented(hKernel);
}

} // namespace

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -470,15 +461,10 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(

getContext()->logger.debug("==== urEnqueueKernelLaunch");

if (!isInstrumentedKernel(hKernel)) {
return pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, phEventWaitList, phEvent);
}

LaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
workDim);
UR_CALL(LaunchInfo.Data.syncToDevice(hQueue));

UR_CALL(getAsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));

Expand Down Expand Up @@ -1366,9 +1352,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreate(
getContext()->logger.debug("==== urKernelCreate");

UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
if (isInstrumentedKernel(*phKernel)) {
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));
}
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));

return UR_RESULT_SUCCESS;
}
Expand All @@ -1389,9 +1373,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
UR_CALL(pfnRetain(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
KernelInfo->RefCount++;
}
KernelInfo->RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -1411,10 +1393,8 @@ __urdlllocal ur_result_t urKernelRelease(
UR_CALL(pfnRelease(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
if (--KernelInfo->RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
}
if (--KernelInfo->RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
}

return UR_RESULT_SUCCESS;
Expand All @@ -1440,11 +1420,10 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue(
getContext()->logger.debug("==== urKernelSetArgValue");

std::shared_ptr<MemBuffer> MemBuffer;
std::shared_ptr<KernelInfo> KernelInfo;
if (argSize == sizeof(ur_mem_handle_t) &&
(MemBuffer = getAsanInterceptor()->getMemBuffer(
*ur_cast<const ur_mem_handle_t *>(pArgValue))) &&
(KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel))) {
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
} else {
Expand Down Expand Up @@ -1473,9 +1452,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
getContext()->logger.debug("==== urKernelSetArgMemObj");

std::shared_ptr<MemBuffer> MemBuffer;
std::shared_ptr<KernelInfo> KernelInfo;
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue)) &&
(KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel))) {
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
} else {
Expand Down Expand Up @@ -1505,7 +1483,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal(
"==== urKernelSetArgLocal (argIndex={}, argSize={})", argIndex,
argSize);

if (auto KI = getAsanInterceptor()->getKernelInfo(hKernel)) {
{
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
// TODO: get local variable alignment
auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal(
Expand Down Expand Up @@ -1542,8 +1521,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer(
pArgValue);

std::shared_ptr<KernelInfo> KI;
if (getAsanInterceptor()->getOptions().DetectKernelArguments &&
(KI = getAsanInterceptor()->getKernelInfo(hKernel))) {
if (getAsanInterceptor()->getOptions().DetectKernelArguments) {
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
}
Expand Down
51 changes: 33 additions & 18 deletions source/loader/layers/sanitizer/asan/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,13 @@ ur_result_t AsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) {
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
return UR_RESULT_SUCCESS;
}
m_KernelMap.emplace(Kernel, std::make_shared<KernelInfo>(Kernel));

auto hProgram = GetProgram(Kernel);
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
bool IsInstrumented = PI->isKernelInstrumented(Kernel);

m_KernelMap.emplace(Kernel,
std::make_shared<KernelInfo>(Kernel, IsInstrumented));
return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -685,9 +691,19 @@ ur_result_t AsanInterceptor::prepareLaunch(
std::shared_ptr<ContextInfo> &ContextInfo,
std::shared_ptr<DeviceInfo> &DeviceInfo, ur_queue_handle_t Queue,
ur_kernel_handle_t Kernel, LaunchInfo &LaunchInfo) {

auto KernelInfo = getKernelInfo(Kernel);
assert(KernelInfo && "Kernel should be instrumented");

auto ArgNums = GetKernelNumArgs(Kernel);
auto LocalMemoryUsage =
GetKernelLocalMemorySize(Kernel, DeviceInfo->Handle);
auto PrivateMemoryUsage =
GetKernelPrivateMemorySize(Kernel, DeviceInfo->Handle);

getContext()->logger.info(
"KernelInfo {} (Name={}, ArgNums={}, IsInstrumented={}, "
"LocalMemory={}, PrivateMemory={})",
(void *)Kernel, GetKernelName(Kernel), ArgNums,
KernelInfo->IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);

// Validate pointer arguments
if (getOptions().DetectKernelArguments) {
Expand Down Expand Up @@ -719,11 +735,17 @@ ur_result_t AsanInterceptor::prepareLaunch(
}
}

auto ArgNums = GetKernelNumArgs(Kernel);
if (!KernelInfo->IsInstrumented) {
return UR_RESULT_SUCCESS;
}

// We must prepare all kernel args before call
// urKernelGetSuggestedLocalWorkSize, otherwise the call will fail on
// CPU device.
if (ArgNums) {
{
assert(ArgNums >= 1 &&
"Sanitized Kernel should have at least one argument");

ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
Kernel, ArgNums - 1, nullptr, LaunchInfo.Data.getDevicePtr());
if (URes != UR_RESULT_SUCCESS) {
Expand Down Expand Up @@ -763,15 +785,6 @@ ur_result_t AsanInterceptor::prepareLaunch(
LaunchInfo.Data.Host.DeviceTy = DeviceInfo->Type;
LaunchInfo.Data.Host.Debug = getOptions().Debug ? 1 : 0;

auto LocalMemoryUsage =
GetKernelLocalMemorySize(Kernel, DeviceInfo->Handle);
auto PrivateMemoryUsage =
GetKernelPrivateMemorySize(Kernel, DeviceInfo->Handle);

getContext()->logger.info(
"KernelInfo {} (LocalMemory={}, PrivateMemory={})", (void *)Kernel,
LocalMemoryUsage, PrivateMemoryUsage);

// Write shadow memory offset for local memory
if (getOptions().DetectLocals) {
if (DeviceInfo->Shadow->AllocLocalShadow(
Expand Down Expand Up @@ -831,10 +844,12 @@ ur_result_t AsanInterceptor::prepareLaunch(
// sync asan runtime data to device side
UR_CALL(LaunchInfo.Data.syncToDevice(Queue));

getContext()->logger.debug("launch_info {} (numLocalArgs={}, localArgs={})",
(void *)LaunchInfo.Data.getDevicePtr(),
LaunchInfo.Data.Host.NumLocalArgs,
(void *)LaunchInfo.Data.Host.LocalArgs);
getContext()->logger.info(
"LaunchInfo {} (device={}, debug={}, numLocalArgs={}, localArgs={})",
(void *)LaunchInfo.Data.getDevicePtr(),
ToString(LaunchInfo.Data.Host.DeviceTy), LaunchInfo.Data.Host.Debug,
LaunchInfo.Data.Host.NumLocalArgs,
(void *)LaunchInfo.Data.Host.LocalArgs);

return UR_RESULT_SUCCESS;
}
Expand Down
12 changes: 7 additions & 5 deletions source/loader/layers/sanitizer/asan/asan_interceptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ struct KernelInfo {
ur_kernel_handle_t Handle;
std::atomic<int32_t> RefCount = 1;

// sanitized kernel
bool IsInstrumented = false;

// lock this mutex if following fields are accessed
ur_shared_mutex Mutex;
std::unordered_map<uint32_t, std::shared_ptr<MemBuffer>> BufferArgs;
Expand All @@ -94,7 +97,8 @@ struct KernelInfo {
// Need preserve the order of local arguments
std::map<uint32_t, LocalArgsInfo> LocalArgs;

explicit KernelInfo(ur_kernel_handle_t Kernel) : Handle(Kernel) {
explicit KernelInfo(ur_kernel_handle_t Kernel, bool IsInstrumented)
: Handle(Kernel), IsInstrumented(IsInstrumented) {
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Kernel.pfnRetain(Kernel);
assert(Result == UR_RESULT_SUCCESS);
Expand Down Expand Up @@ -348,10 +352,8 @@ class AsanInterceptor {

std::shared_ptr<KernelInfo> getKernelInfo(ur_kernel_handle_t Kernel) {
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
return m_KernelMap[Kernel];
}
return nullptr;
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
return m_KernelMap[Kernel];
}

const AsanOptions &getOptions() { return m_Options; }
Expand Down
2 changes: 1 addition & 1 deletion source/loader/layers/sanitizer/asan/asan_libdevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct AsanRuntimeData {
uint32_t Debug = 0;

int ReportFlag = 0;
AsanErrorReport Report[ASAN_MAX_NUM_REPORTS];
AsanErrorReport Report[ASAN_MAX_NUM_REPORTS] = {};
};

constexpr unsigned ASAN_SHADOW_SCALE = 4;
Expand Down

0 comments on commit d18d523

Please sign in to comment.