Apply extended robust access on global buffer access
This change provides defined behavior for global buffer accesses in the case of a null descriptor or an
out-of-bounds access when the application requires the extended `robustnessAccess`.
NOTE: we use dword2 of the buffer descriptor to check for a null descriptor.
xuechen417 committed Dec 20, 2023
1 parent abe957e commit 217a8e3
Showing 6 changed files with 210 additions and 143 deletions.
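
In concrete terms, with both the null-descriptor and extended robust access options enabled, the new lowering
guards every divergent-descriptor global access: a load or atomic through a null descriptor or past the
descriptor's byte bound yields zero, and a store is dropped. The following is a minimal C++ analogy of that
behavior, not the pass itself (the pass emits this as LLVM IR via the new createGlobalPointerAccess helper
added below); the BufferDesc struct and its field names are illustrative only.

#include <cstdint>
#include <cstring>

// Illustrative descriptor layout (hypothetical field names); the real descriptor is a
// <4 x i32> vector whose dword2 holds NUM_RECORDS, the byte bound.
struct BufferDesc {
  uint64_t baseAddress; // dword0/dword1: base address
  uint32_t numRecords;  // dword2: byte bound, also used for the null-descriptor check
  uint32_t flags;       // dword3
};

// Conceptual behavior of a lowered dword load when both allowNullDescriptor and
// enableExtendedRobustBufferAccess are set.
uint32_t loadDwordRobust(const BufferDesc &desc, uint32_t offset) {
  const bool isNonNullDesc = desc.numRecords != 0; // null-descriptor check on dword2
  const bool isInBound = offset < desc.numRecords; // extended robust bound check
  if (!(isNonNullDesc && isInBound))
    return 0; // defined result: the access is skipped and the load returns zero
  uint32_t value;
  std::memcpy(&value, reinterpret_cast<const char *>(desc.baseAddress) + offset, sizeof(value));
  return value;
}
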
2 changes: 2 additions & 0 deletions lgc/include/lgc/patch/PatchBufferOp.h
@@ -118,6 +118,8 @@ class BufferOpLowering {
llvm::Value *replaceLoadStore(llvm::Instruction &inst);
llvm::Instruction *makeLoop(llvm::Value *const loopStart, llvm::Value *const loopEnd, llvm::Value *const loopStride,
llvm::Instruction *const insertPos);
llvm::Value *createGlobalPointerAccess(llvm::Value *const bufferDesc, llvm::Value *const offset, llvm::Type *const type,
llvm::Instruction &inst, const llvm::function_ref<llvm::Value *(llvm::Value *)> callback);

TypeLowering &m_typeLowering;
llvm::IRBuilder<> m_builder;
1 change: 1 addition & 0 deletions lgc/interface/lgc/Pipeline.h
@@ -187,6 +187,7 @@ union Options {
unsigned rtStaticPipelineFlags; // Ray tracing static pipeline flags
unsigned rtTriCompressMode; // Ray tracing triangle compression mode
bool useGpurt; // Whether GPURT is used
bool enableExtendedRobustBufferAccess; // Enable the extended robust buffer access
};
};
static_assert(sizeof(Options) == sizeof(Options::u32All));
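
For context, a client opts in through this new field when filling in the pipeline options. The snippet below is
a hedged sketch of that step; how the populated lgc::Options object is handed to the pipeline (for example via
the pipeline's option-setting entry point) depends on the driver integration and is not part of this diff.

#include "lgc/Pipeline.h"

// Sketch only: request defined behavior for null descriptors and out-of-bounds
// global buffer accesses using the option added by this change.
void requestExtendedRobustness(lgc::Options &options) {
  options.allowNullDescriptor = true;              // existing option, checked together with the new one
  options.enableExtendedRobustBufferAccess = true; // new option: bound-check divergent global buffer accesses
}
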
174 changes: 101 additions & 73 deletions lgc/patch/PatchBufferOp.cpp
Expand Up @@ -353,34 +353,27 @@ void BufferOpLowering::visitAtomicCmpXchgInst(AtomicCmpXchgInst &atomicCmpXchgIn

// If our buffer descriptor is divergent, need to handle it differently.
if (getDescriptorInfo(bufferDesc).divergent.value()) {
Value *const baseAddr = getBaseAddressFromBufferDesc(bufferDesc);

// The 2nd element in the buffer descriptor is the byte bound, we do this to support robust buffer access.
Value *const bound = m_builder.CreateExtractElement(bufferDesc, 2);
Value *const inBound = m_builder.CreateICmpULT(baseIndex, bound);
Value *const newBaseIndex = m_builder.CreateSelect(inBound, baseIndex, m_builder.getInt32(0));

// Add on the index to the address.
Value *atomicPointer = m_builder.CreateGEP(m_builder.getInt8Ty(), baseAddr, newBaseIndex);

atomicPointer = m_builder.CreateBitCast(atomicPointer, storeType->getPointerTo(ADDR_SPACE_GLOBAL));

const AtomicOrdering successOrdering = atomicCmpXchgInst.getSuccessOrdering();
const AtomicOrdering failureOrdering = atomicCmpXchgInst.getFailureOrdering();

Value *const compareValue = atomicCmpXchgInst.getCompareOperand();
Value *const newValue = atomicCmpXchgInst.getNewValOperand();
AtomicCmpXchgInst *const newAtomicCmpXchg = m_builder.CreateAtomicCmpXchg(
atomicPointer, compareValue, newValue, MaybeAlign(), successOrdering, failureOrdering);
newAtomicCmpXchg->setVolatile(atomicCmpXchgInst.isVolatile());
newAtomicCmpXchg->setSyncScopeID(atomicCmpXchgInst.getSyncScopeID());
newAtomicCmpXchg->setWeak(atomicCmpXchgInst.isWeak());
copyMetadata(newAtomicCmpXchg, &atomicCmpXchgInst);
auto createAtomicCmpXchgFunc = [&](Value *pointer) {
const AtomicOrdering successOrdering = atomicCmpXchgInst.getSuccessOrdering();
const AtomicOrdering failureOrdering = atomicCmpXchgInst.getFailureOrdering();

Value *const compareValue = atomicCmpXchgInst.getCompareOperand();
Value *const newValue = atomicCmpXchgInst.getNewValOperand();
AtomicCmpXchgInst *const newAtomicCmpXchg = m_builder.CreateAtomicCmpXchg(
pointer, compareValue, newValue, MaybeAlign(), successOrdering, failureOrdering);
newAtomicCmpXchg->setVolatile(atomicCmpXchgInst.isVolatile());
newAtomicCmpXchg->setSyncScopeID(atomicCmpXchgInst.getSyncScopeID());
newAtomicCmpXchg->setWeak(atomicCmpXchgInst.isWeak());
copyMetadata(newAtomicCmpXchg, &atomicCmpXchgInst);
return newAtomicCmpXchg;
};
Value *result =
createGlobalPointerAccess(bufferDesc, baseIndex, storeType, atomicCmpXchgInst, createAtomicCmpXchgFunc);

// Record the atomic instruction so we remember to delete it later.
m_typeLowering.eraseInstruction(&atomicCmpXchgInst);

atomicCmpXchgInst.replaceAllUsesWith(newAtomicCmpXchg);
atomicCmpXchgInst.replaceAllUsesWith(result);
} else {
switch (atomicCmpXchgInst.getSuccessOrdering()) {
case AtomicOrdering::Release:
@@ -459,29 +452,21 @@ void BufferOpLowering::visitAtomicRMWInst(AtomicRMWInst &atomicRmwInst) {

// If our buffer descriptor is divergent, need to handle it differently.
if (getDescriptorInfo(bufferDesc).divergent.value()) {
Value *const baseAddr = getBaseAddressFromBufferDesc(bufferDesc);

// The 2nd element in the buffer descriptor is the byte bound, we do this to support robust buffer access.
Value *const bound = m_builder.CreateExtractElement(bufferDesc, 2);
Value *const inBound = m_builder.CreateICmpULT(baseIndex, bound);
Value *const newBaseIndex = m_builder.CreateSelect(inBound, baseIndex, m_builder.getInt32(0));

// Add on the index to the address.
Value *atomicPointer = m_builder.CreateGEP(m_builder.getInt8Ty(), baseAddr, newBaseIndex);

atomicPointer = m_builder.CreateBitCast(atomicPointer, storeType->getPointerTo(ADDR_SPACE_GLOBAL));

AtomicRMWInst *const newAtomicRmw =
m_builder.CreateAtomicRMW(atomicRmwInst.getOperation(), atomicPointer, atomicRmwInst.getValOperand(),
atomicRmwInst.getAlign(), atomicRmwInst.getOrdering());
newAtomicRmw->setVolatile(atomicRmwInst.isVolatile());
newAtomicRmw->setSyncScopeID(atomicRmwInst.getSyncScopeID());
copyMetadata(newAtomicRmw, &atomicRmwInst);
auto createAtomicRmwFunc = [&](Value *pointer) {
AtomicRMWInst *const newAtomicRmw =
m_builder.CreateAtomicRMW(atomicRmwInst.getOperation(), pointer, atomicRmwInst.getValOperand(),
atomicRmwInst.getAlign(), atomicRmwInst.getOrdering());
newAtomicRmw->setVolatile(atomicRmwInst.isVolatile());
newAtomicRmw->setSyncScopeID(atomicRmwInst.getSyncScopeID());
copyMetadata(newAtomicRmw, &atomicRmwInst);
return newAtomicRmw;
};
Value *result = createGlobalPointerAccess(bufferDesc, baseIndex, storeType, atomicRmwInst, createAtomicRmwFunc);

// Record the atomic instruction so we remember to delete it later.
m_typeLowering.eraseInstruction(&atomicRmwInst);

atomicRmwInst.replaceAllUsesWith(newAtomicRmw);
atomicRmwInst.replaceAllUsesWith(result);
} else {
switch (atomicRmwInst.getOrdering()) {
case AtomicOrdering::Release:
@@ -1292,36 +1277,28 @@ Value *BufferOpLowering::replaceLoadStore(Instruction &inst) {

// If our buffer descriptor is divergent, need to handle that differently.
if (getDescriptorInfo(bufferDesc).divergent.value()) {
Value *const baseAddr = getBaseAddressFromBufferDesc(bufferDesc);

// The 2nd element in the buffer descriptor is the byte bound, we do this to support robust buffer access.
Value *const bound = m_builder.CreateExtractElement(bufferDesc, 2);
Value *const inBound = m_builder.CreateICmpULT(baseIndex, bound);
Value *const newBaseIndex = m_builder.CreateSelect(inBound, baseIndex, m_builder.getInt32(0));

// Add on the index to the address.
Value *pointer = m_builder.CreateGEP(m_builder.getInt8Ty(), baseAddr, newBaseIndex);

pointer = m_builder.CreateBitCast(pointer, type->getPointerTo(ADDR_SPACE_GLOBAL));

if (isLoad) {
LoadInst *const newLoad = m_builder.CreateAlignedLoad(type, pointer, alignment, loadInst->isVolatile());
newLoad->setOrdering(ordering);
newLoad->setSyncScopeID(syncScopeID);
copyMetadata(newLoad, loadInst);

if (isInvariant)
newLoad->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(m_builder.getContext(), {}));

return newLoad;
}
StoreInst *const newStore =
m_builder.CreateAlignedStore(storeInst->getValueOperand(), pointer, alignment, storeInst->isVolatile());
newStore->setOrdering(ordering);
newStore->setSyncScopeID(syncScopeID);
copyMetadata(newStore, storeInst);

return newStore;
auto createLoadStoreFunc = [&](Value *pointer) {
Value *result = nullptr;
if (isLoad) {
LoadInst *const newLoad = m_builder.CreateAlignedLoad(type, pointer, alignment, loadInst->isVolatile());
newLoad->setOrdering(ordering);
newLoad->setSyncScopeID(syncScopeID);
copyMetadata(newLoad, loadInst);

if (isInvariant)
newLoad->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(m_builder.getContext(), {}));
result = newLoad;
} else {
StoreInst *const newStore =
m_builder.CreateAlignedStore(storeInst->getValueOperand(), pointer, alignment, storeInst->isVolatile());
newStore->setOrdering(ordering);
newStore->setSyncScopeID(syncScopeID);
copyMetadata(newStore, storeInst);
result = newStore;
}
return result;
};
return createGlobalPointerAccess(bufferDesc, baseIndex, type, inst, createLoadStoreFunc);
}

switch (ordering) {
@@ -1572,3 +1549,54 @@ Instruction *BufferOpLowering::makeLoop(Value *const loopStart, Value *const loo

return loopCounter;
}

// =====================================================================================================================
// Create global pointer access.
//
// @param bufferDesc: The buffer descriptor
// @param offset: The offset on the global memory
// @param type: The accessed data type
// @param inst: The instruction to be executed on the buffer
// @param callback: The callback function to perform the specific global access
Value *BufferOpLowering::createGlobalPointerAccess(Value *const bufferDesc, Value *const offset, Type *const type,
Instruction &inst, const function_ref<Value *(Value *)> callback) {
// The 2nd element (NUM_RECORDS) in the buffer descriptor is the byte bound.
Value *bound = m_builder.CreateExtractElement(bufferDesc, 2);
Value *inBound = m_builder.CreateICmpULT(offset, bound);

// If a null descriptor or extended robust buffer access is allowed, we create a branch and only perform the
// normal global access when the validity check passes.
Value *isValidAccess = m_builder.getTrue();
if (m_pipelineState.getOptions().allowNullDescriptor ||
m_pipelineState.getOptions().enableExtendedRobustBufferAccess) {
Value *isNonNullDesc = m_builder.getTrue();
if (m_pipelineState.getOptions().allowNullDescriptor) {
// Check dword2 against 0 for null descriptor
isNonNullDesc = m_builder.CreateICmpNE(bound, m_builder.getInt32(0));
}
Value *isInBound = m_pipelineState.getOptions().enableExtendedRobustBufferAccess ? inBound : m_builder.getTrue();
isValidAccess = m_builder.CreateAnd(isNonNullDesc, isInBound);
}

BasicBlock *const origBlock = inst.getParent();
Instruction *const terminator = SplitBlockAndInsertIfThen(isValidAccess, &inst, false);

// Global pointer access
m_builder.SetInsertPoint(terminator);
Value *baseAddr = getBaseAddressFromBufferDesc(bufferDesc);
// NOTE: Overriding an out-of-bounds offset with 0 may cause unexpected results when the extended robustness
// access is disabled.
Value *newOffset = m_builder.CreateSelect(inBound, offset, m_builder.getInt32(0));
// Add on the index to the address.
Value *pointer = m_builder.CreateGEP(m_builder.getInt8Ty(), baseAddr, newOffset);
pointer = m_builder.CreateBitCast(pointer, type->getPointerTo(ADDR_SPACE_GLOBAL));
Value *newValue = callback(pointer);

m_builder.SetInsertPoint(&inst);
assert(!type->isVoidTy());
auto phi = m_builder.CreatePHI(type, 2, "newValue");
phi->addIncoming(Constant::getNullValue(type), origBlock);
phi->addIncoming(newValue, terminator->getParent());

return phi;
}
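
To summarize the guard this helper emits: which checks contribute to isValidAccess depends on the pipeline
options, and when the predicate is false the access is skipped and the phi yields a zero value of the accessed
type. The sketch below restates that logic as plain C++ for readability; the pass itself builds it with
IRBuilder as shown above.

#include <cstdint>

// Plain-C++ restatement of the isValidAccess computation in createGlobalPointerAccess;
// numRecords is dword2 of the buffer descriptor.
bool computeIsValidAccess(bool allowNullDescriptor, bool enableExtendedRobustBufferAccess,
                          uint32_t numRecords, uint32_t offset) {
  if (!allowNullDescriptor && !enableExtendedRobustBufferAccess)
    return true; // the guard is a constant true and the conditional branch folds away
  const bool isNonNullDesc = allowNullDescriptor ? (numRecords != 0) : true;
  const bool isInBound = enableExtendedRobustBufferAccess ? (offset < numRecords) : true;
  return isNonNullDesc && isInBound; // false: skip the access and produce a zero result via the phi
}

Note also that, independently of this branch, the byte offset fed to the GEP is clamped to 0 whenever it is out
of bounds, which is what the in-line NOTE above warns about for the case where extended robust access is
disabled.
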