Skip to content

Commit

Permalink
split local write instructions.
Browse files Browse the repository at this point in the history
  • Loading branch information
hcman2 committed Oct 1, 2024
1 parent c324499 commit 276e5a6
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 6 deletions.
92 changes: 87 additions & 5 deletions tensilelite/Tensile/Components/SIA.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from ..TensileInstructions import Item, Module, HolderContainer, Instruction, \
GlobalReadInstruction, LocalReadInstruction, \
LocalWriteInstruction, SSetPrior, SWaitCnt, \
replaceHolder, fastdeepcopy, VMovB32
replaceHolder, fastdeepcopy, VMovB32, \
DSStoreB128, DSStoreB64, DSStoreB32
from ..Common import roundUp
from ..Component import SIA
from ..TensileInstructions.Containers import DSModifiers

import copy
from math import ceil
Expand Down Expand Up @@ -747,8 +749,16 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
itemsGRToSchedLater, itemsLWToSched, startIter, readsToWait, readsToWaitNGLL, \
firstIter, lastLc, maxVmcnt, startIterItem = None):
# schedule here
localwriteCnt = 0
localwriteCnt = 0
globalReadInstOffset = 0
additionalIndexList = {}
for u in range(startIter, localWriteEndIter+1):
# If we have some LW not scheduled in last Iter, add them.
newAdditionalIndexList = fastdeepcopy(additionalIndexList)
additionalIndexList = {}
for idx in newAdditionalIndexList:
additionalIndexList[idx - itemPerIter] = newAdditionalIndexList[idx]

if u==(localWriteEndIter):
itemPerIter = len(itemsLWToSched) # schedule all remaining activity
else:
Expand All @@ -759,6 +769,7 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
if u == startIter and startIterItem:
itemPerIter = startIterItem

itemsLWToSchedIndex = 0
for item in itemsLWToSched[:itemPerIter]:
# Use a module to ensure these pieces stay together in the sub-iter scheduler
imod = Module("LocalWriteMod%u"%u)
Expand All @@ -767,6 +778,15 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
if kernel["ProblemType"]["Sparse"] and not writesPerItem:
writesPerItem = item.name.startswith("MetadataWrite") and item.countType(VMovB32)
if writesPerItem:
# Split into several dsStore32
itemNew, numItemNew, globalReadInstOffset = splitDSInstructionIntoSmaller(writer, kernel, item, numLocalWritesPerSched, itemPerIter)
if itemsLWToSchedIndex + globalReadInstOffset <= len(itemsLWToSched):
additionalIndexList = {}
for i in range(numItemNew):
additionalIndexList[i * numLocalWritesPerSched + itemsLWToSchedIndex] = itemNew[i]
else:
globalReadInstOffset = 0

imod.addComment0("sched write - iter %u writesPerItem=%u"%(u,writesPerItem))
imodNGLL.addComment0("sched write - iter %u writesPerItem=%u"%(u,writesPerItem))
# if writesPerItem>1 this indicates multiple LocalWrites in the same module
Expand All @@ -791,10 +811,15 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
readsToWaitAdjust = len(list(writer.codes.globalReadA.middle.items())) + len(list(writer.codes.globalReadB.middle.items()))
for wc in wcList:
replaceHolder(wc, (readsToWaitAdjust))

imod.add(item)

if itemsLWToSchedIndex in additionalIndexList:
imod.add(additionalIndexList[itemsLWToSchedIndex])
additionalIndexList.pop(itemsLWToSchedIndex)
else:
imod.add(item)
# schedule global instruction that need to be scheduled later
if localwriteCnt % PRECISION == (numLocalWritesPerSched % PRECISION):
if localwriteCnt % PRECISION == ((numLocalWritesPerSched % PRECISION) + globalReadInstOffset):
globalReadInstOffset = 0
reads = 0
while itemsGRToSchedLater:
itemGR = itemsGRToSchedLater[0]
Expand Down Expand Up @@ -825,6 +850,7 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
# in that case, local write code for NGLL is not as expected.
writer.codes.perIterLocalWriteCodeNGLL[u].add(imodNGLL)

itemsLWToSchedIndex += 1
itemsLWToSched = itemsLWToSched[itemPerIter:]

# should never run out of items to schedule
Expand All @@ -838,6 +864,62 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
imod.add(itemGR)
itemsGRToSchedLater.pop(0)

def splitDSInstructionIntoSmaller(writer, kernel, item, numLocalWritesPerSched, itemPerIter):
if not item:
return None, 0, 0
if item.countType(DSStoreB128) != 1 or item.countType(Instruction) != 1:
# only support one b128
return None, 0, 0

instruction = None
itemList = item.flatitems()
for inst in itemList:
if isinstance(inst, DSStoreB128):
instruction = inst
break
if instruction == None:
# no instructions to be splitted
return None, 0, 0

lwLatency = DSStoreB128.issueLatency()
miLatency = writer.states.miLatency
div = 1
dsOffset = 0

LocalWriteX = DSStoreB128
if DSStoreB128.issueLatency() < (miLatency - 1):
# no need to split
return None, 0, 0
elif DSStoreB64.issueLatency() < (miLatency - 1):
LocalWriteX = DSStoreB64
dsOffset = 8
div = 2
elif DSStoreB32.issueLatency() < (miLatency - 1):
LocalWriteX = DSStoreB32
dsOffset = 4
div = 4
else:
# miLatency is not enough
return None, 0, 0

if numLocalWritesPerSched * (div + 1) > PRECISION or numLocalWritesPerSched * div >= itemPerIter:
# no enough mfma to split
return None, 0, 0
addr = instruction.getParams()[0]
srcr = instruction.getParams()[1]
offs = instruction.getParams()[2]
ds = instruction.getParams()[3]
writeInst = []
for d in range(div):
ds1 = DSModifiers(na=1, offset=ds.offset + dsOffset * d)
r1 = fastdeepcopy(srcr)
r1.regNum //= div
r1.regName.offsets.append(4 // div * d)
writeInst.append(LocalWriteX(dstAddr=addr, src=r1, ds=ds1, comment=instruction.comment + " splitted"))

return writeInst, len(writeInst), numLocalWritesPerSched * (div - 1)


################################################################################
################################################################################
###
Expand Down
2 changes: 1 addition & 1 deletion tensilelite/Tensile/TensileInstructions/Instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def __init__(self, instType: InstType, dstAddr, src0, src1, \
self.ds = ds

def getParams(self) -> list:
return [self.dstAddr, self.src0, self.src1]
return [self.dstAddr, self.src0, self.src1, self.ds]

def preStr(self):
if self.kernel.isa[0] < 11:
Expand Down

0 comments on commit 276e5a6

Please sign in to comment.