Skip to content

Commit

Permalink
Handle null descriptor for patching global buffer operations
Browse files Browse the repository at this point in the history
Null descriptors were not taken into consideration when patching global
load/store/atomic operations with `allowNullDescriptor` enabled.
If dword3 of the descriptor is zero, we now return 0 for loads/atomics and
skip the store.
NOTE: we use dword3 to check for a null descriptor, so `CreateBufferDesc`
must not modify dword3 of the descriptor when the descriptor is null.
  • Loading branch information
xuechen417 committed Dec 12, 2023
1 parent daf27f7 commit 586f50b
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 145 deletions.
10 changes: 8 additions & 2 deletions lgc/builder/DescBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,14 @@ Value *BuilderImpl::CreateBufferDesc(uint64_t descSet, unsigned binding, Value *
desc = CreateInsertElement(desc, CreateAnd(desc1, getInt32(0xc000ffff)), 1);
desc = CreateInsertElement(desc, CreateMul(stride, desc2), 2);
// gfx10 and gfx11 have oob fields with 2 bits in dword3[29:28] here force to set to 3 as OOB_COMPLETE mode.
if (getPipelineState()->getTargetInfo().getGfxIpVersion().major >= 10)
desc = CreateInsertElement(desc, CreateOr(desc3, getInt32(0x30000000)), 3);
if (getPipelineState()->getTargetInfo().getGfxIpVersion().major >= 10) {
Value *newDesc3 = CreateOr(desc3, getInt32(0x30000000));
if (getPipelineState()->getOptions().allowNullDescriptor) {
Value *isNullDesc = CreateICmpEQ(desc3, getInt32(0));
newDesc3 = CreateSelect(isNullDesc, desc3, newDesc3);
}
desc = CreateInsertElement(desc, newDesc3, 3);
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions lgc/include/lgc/patch/PatchBufferOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class BufferOpLowering {
llvm::Value *replaceLoadStore(llvm::Instruction &inst);
llvm::Instruction *makeLoop(llvm::Value *const loopStart, llvm::Value *const loopEnd, llvm::Value *const loopStride,
llvm::Instruction *const insertPos);
Value *createGlobalPointerAccess(llvm::Value *bufferDesc, llvm::Value *offset, llvm::Type *type,
llvm::Instruction &inst, const llvm::function_ref<Value *(Value *)> &callback);

TypeLowering &m_typeLowering;
llvm::IRBuilder<> m_builder;
Expand Down
164 changes: 91 additions & 73 deletions lgc/patch/PatchBufferOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,34 +353,27 @@ void BufferOpLowering::visitAtomicCmpXchgInst(AtomicCmpXchgInst &atomicCmpXchgIn

// If our buffer descriptor is divergent, need to handle it differently.
if (getDescriptorInfo(bufferDesc).divergent.value()) {
Value *const baseAddr = getBaseAddressFromBufferDesc(bufferDesc);

// The 2nd element in the buffer descriptor is the byte bound, we do this to support robust buffer access.
Value *const bound = m_builder.CreateExtractElement(bufferDesc, 2);
Value *const inBound = m_builder.CreateICmpULT(baseIndex, bound);
Value *const newBaseIndex = m_builder.CreateSelect(inBound, baseIndex, m_builder.getInt32(0));

// Add on the index to the address.
Value *atomicPointer = m_builder.CreateGEP(m_builder.getInt8Ty(), baseAddr, newBaseIndex);

atomicPointer = m_builder.CreateBitCast(atomicPointer, storeType->getPointerTo(ADDR_SPACE_GLOBAL));

const AtomicOrdering successOrdering = atomicCmpXchgInst.getSuccessOrdering();
const AtomicOrdering failureOrdering = atomicCmpXchgInst.getFailureOrdering();

Value *const compareValue = atomicCmpXchgInst.getCompareOperand();
Value *const newValue = atomicCmpXchgInst.getNewValOperand();
AtomicCmpXchgInst *const newAtomicCmpXchg = m_builder.CreateAtomicCmpXchg(
atomicPointer, compareValue, newValue, MaybeAlign(), successOrdering, failureOrdering);
newAtomicCmpXchg->setVolatile(atomicCmpXchgInst.isVolatile());
newAtomicCmpXchg->setSyncScopeID(atomicCmpXchgInst.getSyncScopeID());
newAtomicCmpXchg->setWeak(atomicCmpXchgInst.isWeak());
copyMetadata(newAtomicCmpXchg, &atomicCmpXchgInst);
auto createAtomicCmpXchgFunc = [&](Value *pointer) {
const AtomicOrdering successOrdering = atomicCmpXchgInst.getSuccessOrdering();
const AtomicOrdering failureOrdering = atomicCmpXchgInst.getFailureOrdering();

Value *const compareValue = atomicCmpXchgInst.getCompareOperand();
Value *const newValue = atomicCmpXchgInst.getNewValOperand();
AtomicCmpXchgInst *const newAtomicCmpXchg = m_builder.CreateAtomicCmpXchg(
pointer, compareValue, newValue, MaybeAlign(), successOrdering, failureOrdering);
newAtomicCmpXchg->setVolatile(atomicCmpXchgInst.isVolatile());
newAtomicCmpXchg->setSyncScopeID(atomicCmpXchgInst.getSyncScopeID());
newAtomicCmpXchg->setWeak(atomicCmpXchgInst.isWeak());
copyMetadata(newAtomicCmpXchg, &atomicCmpXchgInst);
return newAtomicCmpXchg;
};
Value *result =
createGlobalPointerAccess(bufferDesc, baseIndex, storeType, atomicCmpXchgInst, createAtomicCmpXchgFunc);

// Record the atomic instruction so we remember to delete it later.
m_typeLowering.eraseInstruction(&atomicCmpXchgInst);

atomicCmpXchgInst.replaceAllUsesWith(newAtomicCmpXchg);
atomicCmpXchgInst.replaceAllUsesWith(result);
} else {
switch (atomicCmpXchgInst.getSuccessOrdering()) {
case AtomicOrdering::Release:
Expand Down Expand Up @@ -459,29 +452,21 @@ void BufferOpLowering::visitAtomicRMWInst(AtomicRMWInst &atomicRmwInst) {

// If our buffer descriptor is divergent, need to handle it differently.
if (getDescriptorInfo(bufferDesc).divergent.value()) {
Value *const baseAddr = getBaseAddressFromBufferDesc(bufferDesc);

// The 2nd element in the buffer descriptor is the byte bound, we do this to support robust buffer access.
Value *const bound = m_builder.CreateExtractElement(bufferDesc, 2);
Value *const inBound = m_builder.CreateICmpULT(baseIndex, bound);
Value *const newBaseIndex = m_builder.CreateSelect(inBound, baseIndex, m_builder.getInt32(0));

// Add on the index to the address.
Value *atomicPointer = m_builder.CreateGEP(m_builder.getInt8Ty(), baseAddr, newBaseIndex);

atomicPointer = m_builder.CreateBitCast(atomicPointer, storeType->getPointerTo(ADDR_SPACE_GLOBAL));

AtomicRMWInst *const newAtomicRmw =
m_builder.CreateAtomicRMW(atomicRmwInst.getOperation(), atomicPointer, atomicRmwInst.getValOperand(),
atomicRmwInst.getAlign(), atomicRmwInst.getOrdering());
newAtomicRmw->setVolatile(atomicRmwInst.isVolatile());
newAtomicRmw->setSyncScopeID(atomicRmwInst.getSyncScopeID());
copyMetadata(newAtomicRmw, &atomicRmwInst);
auto createAtomicRmwFunc = [&](Value *pointer) {
AtomicRMWInst *const newAtomicRmw =
m_builder.CreateAtomicRMW(atomicRmwInst.getOperation(), pointer, atomicRmwInst.getValOperand(),
atomicRmwInst.getAlign(), atomicRmwInst.getOrdering());
newAtomicRmw->setVolatile(atomicRmwInst.isVolatile());
newAtomicRmw->setSyncScopeID(atomicRmwInst.getSyncScopeID());
copyMetadata(newAtomicRmw, &atomicRmwInst);
return newAtomicRmw;
};
Value *result = createGlobalPointerAccess(bufferDesc, baseIndex, storeType, atomicRmwInst, createAtomicRmwFunc);

// Record the atomic instruction so we remember to delete it later.
m_typeLowering.eraseInstruction(&atomicRmwInst);

atomicRmwInst.replaceAllUsesWith(newAtomicRmw);
atomicRmwInst.replaceAllUsesWith(result);
} else {
switch (atomicRmwInst.getOrdering()) {
case AtomicOrdering::Release:
Expand Down Expand Up @@ -1292,36 +1277,28 @@ Value *BufferOpLowering::replaceLoadStore(Instruction &inst) {

// If our buffer descriptor is divergent, need to handle that differently.
if (getDescriptorInfo(bufferDesc).divergent.value()) {
Value *const baseAddr = getBaseAddressFromBufferDesc(bufferDesc);

// The 2nd element in the buffer descriptor is the byte bound, we do this to support robust buffer access.
Value *const bound = m_builder.CreateExtractElement(bufferDesc, 2);
Value *const inBound = m_builder.CreateICmpULT(baseIndex, bound);
Value *const newBaseIndex = m_builder.CreateSelect(inBound, baseIndex, m_builder.getInt32(0));

// Add on the index to the address.
Value *pointer = m_builder.CreateGEP(m_builder.getInt8Ty(), baseAddr, newBaseIndex);

pointer = m_builder.CreateBitCast(pointer, type->getPointerTo(ADDR_SPACE_GLOBAL));

if (isLoad) {
LoadInst *const newLoad = m_builder.CreateAlignedLoad(type, pointer, alignment, loadInst->isVolatile());
newLoad->setOrdering(ordering);
newLoad->setSyncScopeID(syncScopeID);
copyMetadata(newLoad, loadInst);

if (isInvariant)
newLoad->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(m_builder.getContext(), {}));

return newLoad;
}
StoreInst *const newStore =
m_builder.CreateAlignedStore(storeInst->getValueOperand(), pointer, alignment, storeInst->isVolatile());
newStore->setOrdering(ordering);
newStore->setSyncScopeID(syncScopeID);
copyMetadata(newStore, storeInst);

return newStore;
auto createLoadStoreFunc = [&](Value *pointer) {
Value *result = nullptr;
if (isLoad) {
LoadInst *const newLoad = m_builder.CreateAlignedLoad(type, pointer, alignment, loadInst->isVolatile());
newLoad->setOrdering(ordering);
newLoad->setSyncScopeID(syncScopeID);
copyMetadata(newLoad, loadInst);

if (isInvariant)
newLoad->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(m_builder.getContext(), {}));
result = newLoad;
} else {
StoreInst *const newStore =
m_builder.CreateAlignedStore(storeInst->getValueOperand(), pointer, alignment, storeInst->isVolatile());
newStore->setOrdering(ordering);
newStore->setSyncScopeID(syncScopeID);
copyMetadata(newStore, storeInst);
result = newStore;
}
return result;
};
return createGlobalPointerAccess(bufferDesc, baseIndex, type, inst, createLoadStoreFunc);
}

switch (ordering) {
Expand Down Expand Up @@ -1571,3 +1548,44 @@ Instruction *BufferOpLowering::makeLoop(Value *const loopStart, Value *const loo

return loopCounter;
}

// =====================================================================================================================
// Create global pointer access.
//
// Emits a conditional structure around the access so that a null descriptor (when allowed by the pipeline options)
// yields zero for loads/atomics and skips stores entirely.
//
// @param bufferDesc: The buffer descriptor
// @param offset: The offset on the global memory
// @param type: The accessed data type
// @param inst: The instruction to be executed on the buffer
// @param callback: The callback function to perform the specific global access
Value *BufferOpLowering::createGlobalPointerAccess(Value *bufferDesc, Value *offset, Type *type, Instruction &inst,
                                                   const function_ref<Value *(Value *)> &callback) {
  // Handle null descriptor if it is allowed: load/atomic yields zero and a store is skipped.
  // NOTE: SplitBlockAndInsertIfThen executes the "then" block when the condition is TRUE, so the condition must be
  // "descriptor is non-null" — the global access is only performed for a valid descriptor, and the fall-through
  // (cond-false) edge is the null-descriptor skip path.
  Value *isNonNullDesc = m_builder.getTrue();
  if (m_pipelineState.getOptions().allowNullDescriptor) {
    // Check dword3 against 0 for a null descriptor (CreateBufferDesc leaves dword3 of a null descriptor at zero).
    Value *descWord3 = m_builder.CreateExtractElement(bufferDesc, 3);
    isNonNullDesc = m_builder.CreateICmpNE(descWord3, m_builder.getInt32(0));
  }
  BasicBlock *const origBlock = inst.getParent();
  Instruction *const terminator = SplitBlockAndInsertIfThen(isNonNullDesc, &inst, false);

  // Global pointer access, emitted in the conditionally executed "then" block.
  m_builder.SetInsertPoint(terminator);
  Value *baseAddr = getBaseAddressFromBufferDesc(bufferDesc);
  // The 2nd element in the buffer descriptor is the byte bound; clamp the offset to support robust buffer access.
  Value *bound = m_builder.CreateExtractElement(bufferDesc, 2);
  Value *inBound = m_builder.CreateICmpULT(offset, bound);
  Value *newOffset = m_builder.CreateSelect(inBound, offset, m_builder.getInt32(0));
  // Add on the index to the address.
  Value *pointer = m_builder.CreateGEP(m_builder.getInt8Ty(), baseAddr, newOffset);
  pointer = m_builder.CreateBitCast(pointer, type->getPointerTo(ADDR_SPACE_GLOBAL));
  Value *newValue = callback(pointer);

  // A store produces no result to merge (and a StoreInst is of void type, so it must not feed a phi of 'type');
  // the skipped null-descriptor path simply performs no store. For loads/atomics, merge the real result with zero
  // coming from the null-descriptor (fall-through) path.
  if (!isa<StoreInst>(&inst)) {
    m_builder.SetInsertPoint(&inst);
    assert(!type->isVoidTy());
    auto phi = m_builder.CreatePHI(type, 2, "newValue");
    phi->addIncoming(Constant::getNullValue(type), origBlock);
    phi->addIncoming(newValue, terminator->getParent());
    newValue = phi;
  }
  return newValue;
}
Loading

0 comments on commit 586f50b

Please sign in to comment.