Skip to content

Commit

Permalink
Reduce divide/remainder in WGMXCC calculation (#1168)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex391a authored Sep 30, 2024
1 parent a687c7b commit c324499
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tensilelite/Tensile/Common.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ def supportedCompiler(compiler: str) -> bool:
# Formula for wgSerial:
# wgSerial = wg0 + (wg1 % WorkGroupMapping) * nwg0
"WorkGroupMapping": list(range(-1024, 1024+1)), # change a workgroup's id so that the all the workgroups on the gpu at a time are hitting L2 cache the best
"WorkGroupMappingXCC": list(range(1, 64)), # change a workgroup's id so that contiguous workgroup can map on same XCC
"WorkGroupMappingXCC": [1,2,4,8,16,32], # change a workgroup's id so that contiguous workgroup can map on same XCC
# -1 : WorkGroupMappingXCCGroup will be set to CU_count at runtime. Please ensure that (CU_count % WGMXCC == 0).
"WorkGroupMappingXCCGroup": list(range(-1, 1024)), # change a workgroup's id so that contiguous workgroup can map on same XCC, remap workgroup in a group of WGMXCCG.

Expand Down
23 changes: 13 additions & 10 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,20 +1265,20 @@ def wgmXCC(self, kernel, tmpSgprNumWorkGroups):
CU_CountSgpr = WGMXCCSgpr+1

module.add(SLShiftRightB32(dst=sgpr(WGMXCCSgpr), shiftHex=hex(16), src=sgpr("WGM"), comment="Get WGMXCC"))
module.add(SAndB32(dst=sgpr(WGMXCCSgpr), src0=hex(0x3F), src1=sgpr(WGMXCCSgpr), comment="Get WGMXCC"))
module.add(SFf1B32(dst=sgpr(WGMXCCSgpr), src=sgpr(WGMXCCSgpr), comment="Get log(WGMXCC)"))
module.add(SLShiftRightB32(dst=sgpr(CU_CountSgpr), shiftHex=hex(22), src=sgpr("WGM"), comment="Get CU_Count"))

label_skipWGMXCC = Label(label="skip_WGMXCC", comment="skip WGMXCC if no enough WGs to remap")

module.addComment0("remap WGs if WGMXCC > 1")
module.add(SCmpGtU32(src0=sgpr(WGMXCCSgpr), src1=1))
module.addComment0("remap WGs if WGMXCC > 1 ( log(WGMXCC) > 0 )")
module.add(SCmpGtI32(src0=sgpr(WGMXCCSgpr), src1=0))
module.add(SCBranchSCC0(label_skipWGMXCC.getLabelName()))

module.addComment0("only remap WGs in the range")
tmpVgpr = self.vgprPool.checkOut(2)
tmpVgprRes = RegisterPoolResource(tmpVgpr, 2)
module.add(scalarUInt32DivideAndRemainder(qReg=tmpSgpr0, dReg=tmpSgprNumWorkGroups, divReg=WGMXCCSgpr, rReg=None, tmpVgprRes=tmpVgprRes, wavewidth=kernel["WavefrontSize"], doRemainder=0))
module.add(SMulI32(dst=sgpr(tmpSgpr0), src0=sgpr(tmpSgpr0), src1=sgpr(WGMXCCSgpr)))
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr0), shiftHex=sgpr(WGMXCCSgpr), src=sgpr(tmpSgprNumWorkGroups)))
module.add(SLShiftLeftB32(dst=sgpr(tmpSgpr0), shiftHex=sgpr(WGMXCCSgpr), src=sgpr(tmpSgpr0)))
module.add(SCmpGeU32(src0=sgpr("WorkGroup0"), src1=sgpr(tmpSgpr0)))
module.add(SCBranchSCC1(label_skipWGMXCC.getLabelName()))

Expand All @@ -1287,8 +1287,10 @@ def wgmXCC(self, kernel, tmpSgprNumWorkGroups):
module.add(SCBranchSCC0(label_XCCG_nonzero.getLabelName()))

# CU_count == 0
module.add(scalarUInt32DivideAndRemainder(qReg=tmpSgpr0, dReg="WorkGroup0", divReg=WGMXCCSgpr, rReg=tmpSgpr1, tmpVgprRes=tmpVgprRes, wavewidth=kernel["WavefrontSize"], doRemainder=1))
module.add(scalarUInt32DivideAndRemainder(qReg=tmpSgpr2, dReg=tmpSgprNumWorkGroups, divReg=WGMXCCSgpr, rReg=None, tmpVgprRes=tmpVgprRes, wavewidth=kernel["WavefrontSize"], doRemainder=0))
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr0), shiftHex=sgpr(WGMXCCSgpr), src=sgpr("WorkGroup0")))
module.add(SBfmB32(dst=sgpr(tmpSgpr1), src0=sgpr(WGMXCCSgpr), src1=0))
module.add(SAndB32(dst=sgpr(tmpSgpr1), src0=sgpr("WorkGroup0"), src1=sgpr(tmpSgpr1)))
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr2), shiftHex=sgpr(WGMXCCSgpr), src=sgpr(tmpSgprNumWorkGroups)))
module.add(SMulI32(dst=sgpr(tmpSgpr1), src0=sgpr(tmpSgpr1), src1=sgpr(tmpSgpr2)))
module.add(SAddU32(dst=sgpr("WorkGroup0"), src0=sgpr(tmpSgpr0), src1=sgpr(tmpSgpr1)))
module.add(SBranch(label_skipWGMXCC.getLabelName()))
Expand All @@ -1299,7 +1301,7 @@ def wgmXCC(self, kernel, tmpSgprNumWorkGroups):
module.add(scalarUInt32DivideAndRemainder(qReg=tmpSgpr0, dReg="WorkGroup0", divReg=CU_CountSgpr, rReg=tmpSgpr1, tmpVgprRes=tmpVgprRes, wavewidth=kernel["WavefrontSize"], doRemainder=1, comment="wg//CU_Count"))
module.add(SMulI32(dst=sgpr(tmpSgpr0), src0=sgpr(tmpSgpr0), src1=sgpr(CU_CountSgpr)))
module.addComment0("temp1 = (wg%CU_Count)//WGMXCC")
module.add(scalarUInt32DivideAndRemainder(qReg=tmpSgpr1, dReg=tmpSgpr1, divReg=WGMXCCSgpr, rReg=None, tmpVgprRes=tmpVgprRes, wavewidth=kernel["WavefrontSize"], doRemainder=0))
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr1), shiftHex=sgpr(WGMXCCSgpr), src=sgpr(tmpSgpr1)))
module.addComment0("temp0 = temp0 + temp1")
module.add(SAddU32(dst=sgpr(tmpSgpr0), src0=sgpr(tmpSgpr0), src1=sgpr(tmpSgpr1)))
module.addComment0("temp1 = (wg%WGMXCC) * ((WGs - (WGs//CU_Count) * CU_Count) if (wg > (WGs//CU_Count) * CU_Count) else CU_Count)//WGMXCC")
Expand All @@ -1308,8 +1310,9 @@ def wgmXCC(self, kernel, tmpSgprNumWorkGroups):
module.add(SSubU32(dst=sgpr(tmpSgpr2), src0=sgpr(tmpSgprNumWorkGroups), src1=sgpr(tmpSgpr1)))
module.add(SCmpGtU32(src0=sgpr("WorkGroup0"), src1=sgpr(tmpSgpr1)))
module.add(SCSelectB32(dst=sgpr(tmpSgpr1), src0=sgpr(tmpSgpr2), src1=sgpr(CU_CountSgpr)))
module.add(scalarUInt32DivideAndRemainder(qReg=tmpSgpr1, dReg=tmpSgpr1, divReg=WGMXCCSgpr, rReg=None, tmpVgprRes=tmpVgprRes, wavewidth=kernel["WavefrontSize"], doRemainder=0))
module.add(scalarUInt32DivideAndRemainder(qReg=tmpSgpr, dReg="WorkGroup0", divReg=WGMXCCSgpr, rReg=tmpSgpr2, tmpVgprRes=tmpVgprRes, wavewidth=kernel["WavefrontSize"], doRemainder=1))
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr1), shiftHex=sgpr(WGMXCCSgpr), src=sgpr(tmpSgpr1)))
module.add(SBfmB32(dst=sgpr(tmpSgpr2), src0=sgpr(WGMXCCSgpr), src1=0))
module.add(SAndB32(dst=sgpr(tmpSgpr2), src0=sgpr("WorkGroup0"), src1=sgpr(tmpSgpr2)))
self.vgprPool.checkIn(tmpVgpr)
module.add(SMulI32(dst=sgpr(tmpSgpr1), src0=sgpr(tmpSgpr1), src1=sgpr(tmpSgpr2)))
module.addComment0("WorkGroup0 = temp0 + temp1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2603,7 +2603,7 @@ namespace Tensile
virtual bool operator()(ContractionProblemGemm const& problem) const override
{
size_t WGMXCCG = (value[1] == -1) ? cuCount : value[1];
return WGMXCCG % value[0] == 0;
return ((value[0] & (value[0] - 1)) == 0) && WGMXCCG % value[0] == 0;
}

virtual bool debugEval(ContractionProblemGemm const& problem,
Expand Down
12 changes: 12 additions & 0 deletions tensilelite/Tensile/TensileInstructions/Instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,18 @@ def __init__(self, dst, src, comment="") -> None:
super().__init__(InstType.INST_B64, dst, [src], None, None, comment)
self.setInst("s_cmov_b64")

# Find first bit
class SFf1B32(CommonInstruction):
def __init__(self, dst, src, comment="") -> None:
super().__init__(InstType.INST_B32, dst, [src], None, None, comment)
self.setInst("s_ff1_i32_b32")

# Bit field mask
class SBfmB32(CommonInstruction):
def __init__(self, dst, src0, src1, comment="") -> None:
super().__init__(InstType.INST_B32, dst, [src0, src1], None, None, comment)
self.setInst("s_bfm_b32")

# Sign ext
class SMovkI32(CommonInstruction):
def __init__(self, dst, src, comment="") -> None:
Expand Down

0 comments on commit c324499

Please sign in to comment.