Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split local write instructions. #1180

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 100 additions & 5 deletions tensilelite/Tensile/Components/SIA.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from ..TensileInstructions import Item, Module, HolderContainer, Instruction, \
GlobalReadInstruction, LocalReadInstruction, \
LocalWriteInstruction, SSetPrior, SWaitCnt, \
replaceHolder, fastdeepcopy, VMovB32
replaceHolder, fastdeepcopy, VMovB32, \
DSStoreB128, DSStoreB64, DSStoreB32
from ..Common import roundUp
from ..Component import SIA
from ..TensileInstructions.Containers import DSModifiers

import copy
from math import ceil
Expand Down Expand Up @@ -747,8 +749,16 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
itemsGRToSchedLater, itemsLWToSched, startIter, readsToWait, readsToWaitNGLL, \
firstIter, lastLc, maxVmcnt, startIterItem = None):
# schedule here
localwriteCnt = 0
localwriteCnt = 0
globalReadInstOffset = 0
additionalIndexList = {}
for u in range(startIter, localWriteEndIter+1):
# If we have some LW not scheduled in last Iter, add them.
newAdditionalIndexList = fastdeepcopy(additionalIndexList)
additionalIndexList = {}
for idx in newAdditionalIndexList:
additionalIndexList[idx - itemPerIter] = newAdditionalIndexList[idx]

if u==(localWriteEndIter):
itemPerIter = len(itemsLWToSched) # schedule all remaining activity
else:
Expand All @@ -759,6 +769,7 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
if u == startIter and startIterItem:
itemPerIter = startIterItem

itemsLWToSchedIndex = 0
for item in itemsLWToSched[:itemPerIter]:
# Use a module to ensure these pieces stay together in the sub-iter scheduler
imod = Module("LocalWriteMod%u"%u)
Expand All @@ -767,6 +778,15 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
if kernel["ProblemType"]["Sparse"] and not writesPerItem:
writesPerItem = item.name.startswith("MetadataWrite") and item.countType(VMovB32)
if writesPerItem:
# Split into several dsStore32
itemNew, numItemNew, globalReadInstOffset = splitDSInstructionIntoSmaller(writer, kernel, item, numLocalWritesPerSched, len(itemsLWToSched), itemsLWToSchedIndex)
if itemsLWToSchedIndex + globalReadInstOffset <= len(itemsLWToSched):
additionalIndexList = {}
for i in range(numItemNew):
additionalIndexList[i * numLocalWritesPerSched + itemsLWToSchedIndex] = itemNew[i]
else:
globalReadInstOffset = 0

imod.addComment0("sched write - iter %u writesPerItem=%u"%(u,writesPerItem))
imodNGLL.addComment0("sched write - iter %u writesPerItem=%u"%(u,writesPerItem))
# if writesPerItem>1 this indicates multiple LocalWrites in the same module
Expand All @@ -791,10 +811,15 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
readsToWaitAdjust = len(list(writer.codes.globalReadA.middle.items())) + len(list(writer.codes.globalReadB.middle.items()))
for wc in wcList:
replaceHolder(wc, (readsToWaitAdjust))

imod.add(item)

if itemsLWToSchedIndex in additionalIndexList:
imod.add(additionalIndexList[itemsLWToSchedIndex])
additionalIndexList.pop(itemsLWToSchedIndex)
else:
imod.add(item)
# schedule global read instructions that need to be scheduled later
if localwriteCnt % PRECISION == (numLocalWritesPerSched % PRECISION):
if localwriteCnt % PRECISION == ((numLocalWritesPerSched % PRECISION) + globalReadInstOffset):
globalReadInstOffset = 0
reads = 0
while itemsGRToSchedLater:
itemGR = itemsGRToSchedLater[0]
Expand Down Expand Up @@ -825,6 +850,7 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
# in that case, local write code for NGLL is not as expected.
writer.codes.perIterLocalWriteCodeNGLL[u].add(imodNGLL)

itemsLWToSchedIndex += 1
itemsLWToSched = itemsLWToSched[itemPerIter:]

# should never run out of items to schedule
Expand All @@ -838,6 +864,75 @@ def schedLocalWrite(writer, kernel, numLocalWriteModPerIter, numLocalWritesPerSc
imod.add(itemGR)
itemsGRToSchedLater.pop(0)

def splitDSInstructionIntoSmaller(writer, kernel, item, numLocalWritesPerSched, lenOfItems, currentModIdx):
    """Try to replace a single DSStoreB128 in *item* with narrower DS stores.

    The widest store class (B128, then B64, then B32) whose ``issueLatency()``
    is below ``writer.states.miLatency - 1`` is selected. If B128 already
    qualifies, or none of them do, nothing is split.

    Args:
        writer: object providing ``states.miLatency``.
        kernel: unused here; kept so the call signature stays unchanged.
        item: container exposing ``countType()`` and ``flatitems()``; it must
            hold exactly one instruction and that instruction must be a
            DSStoreB128, otherwise nothing is split.
        numLocalWritesPerSched: scheduling stride between local writes, in
            the same units that are compared against ``PRECISION``.
        lenOfItems: total number of local-write items still to be scheduled.
        currentModIdx: index of *item* within those items.

    Returns:
        A 3-tuple ``(instructions, count, offset)``:
        ``instructions`` is the list of replacement stores, ``count`` is its
        length and ``offset`` is ``numLocalWritesPerSched * (div - 1)`` where
        ``div`` is the number of pieces. When no split is performed the
        result is ``(None, 0, 0)``.

    Raises:
        AssertionError: if ``countType`` reported one DSStoreB128 but none
            was found while walking ``item.flatitems()``.
    """
    noSplit = (None, 0, 0)

    if not item:
        return noSplit
    # only support one b128
    if item.countType(DSStoreB128) != 1 or item.countType(Instruction) != 1:
        return noSplit

    instruction = next(
        (inst for inst in item.flatitems() if isinstance(inst, DSStoreB128)),
        None)
    if instruction is None:
        # Raised explicitly (rather than ``assert 0``) so the guard survives
        # running the interpreter with -O.
        raise AssertionError("no instructions to be splitted")

    miLatency = writer.states.miLatency
    latencyBudget = miLatency - 1

    if DSStoreB128.issueLatency() < latencyBudget:
        # no need to split
        return noSplit
    elif DSStoreB64.issueLatency() < latencyBudget:
        LocalWriteX = DSStoreB64
        dsOffset = 8  # offset step between consecutive pieces
        div = 2
    elif DSStoreB32.issueLatency() < latencyBudget:
        LocalWriteX = DSStoreB32
        dsOffset = 4  # offset step between consecutive pieces
        div = 4
    else:
        # miLatency is not enough
        return noSplit

    if numLocalWritesPerSched * div >= PRECISION:
        # not enough mfma to split
        return noSplit

    if (currentModIdx + numLocalWritesPerSched * div) >= lenOfItems:
        # not enough modules to schedule
        return noSplit

    # LW b32 ~ 56 cycles
    # 4-way bank conflict = (4-1) x 2(1 addr + 1 data) = 6 cycles
    # round with quad-cycle
    finalLWCycles = roundUp((56 + 6) / 4)
    extraSched = roundUp(finalLWCycles / miLatency)
    if (currentModIdx + numLocalWritesPerSched * (div - 1 + extraSched)) >= lenOfItems:
        # not enough cycles before barrier
        return noSplit

    # getParams() yields [dstAddr, src0, src1, ds]; src1 is not needed here.
    params = instruction.getParams()
    addr = params[0]
    srcr = params[1]
    ds = params[3]

    writeInst = []
    for d in range(div):
        # Each piece is written at the original offset advanced by d steps.
        ds1 = DSModifiers(na=1, offset=ds.offset + dsOffset * d)
        # Copy the source register, shrink it to 1/div of its size and
        # append a sub-offset of (4 // div) * d to its register name.
        r1 = fastdeepcopy(srcr)
        r1.regNum //= div
        r1.regName.offsets.append((4 // div) * d)
        writeInst.append(LocalWriteX(dstAddr=addr, src=r1, ds=ds1, comment=instruction.comment + " splitted"))

    return writeInst, len(writeInst), numLocalWritesPerSched * (div - 1)


################################################################################
################################################################################
###
Expand Down
2 changes: 1 addition & 1 deletion tensilelite/Tensile/TensileInstructions/Instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def __init__(self, instType: InstType, dstAddr, src0, src1, \
self.ds = ds

def getParams(self) -> list:
    """Return ``[dstAddr, src0, src1, ds]`` for this instruction.

    The first three entries keep their existing positions; ``ds`` is
    appended as the fourth entry so callers can read it at index 3.
    """
    return [self.dstAddr, self.src0, self.src1, self.ds]

def preStr(self):
if self.kernel.isa[0] < 11:
Expand Down