
Commit

Merge remote-tracking branch 'origin/master'
RichardAngersbach committed Nov 4, 2024
2 parents bd1aa12 + 9ef53d8 commit 146b0ea
Showing 14 changed files with 255 additions and 96 deletions.
3 changes: 2 additions & 1 deletion Compiler/src/exastencils/base/ir/IR_Reduction.scala
@@ -21,4 +21,5 @@ package exastencils.base.ir
/// IR_Reduction

// FIXME: op as BinOp
case class IR_Reduction(var op : String, var target : IR_Expression, var targetName : String, var skipMpi : Boolean = false) extends IR_Node
case class IR_Reduction(var op : String, var target : IR_Expression, var targetName : String,
var skipMpi : Boolean = false, var skipOpenMP : Boolean = false) extends IR_Node
35 changes: 19 additions & 16 deletions Compiler/src/exastencils/baseExt/ir/IR_LoopOverDimensions.scala
@@ -28,6 +28,7 @@ import exastencils.datastructures._
import exastencils.logger.Logger
import exastencils.optimization.ir._
import exastencils.parallelization.ir._
import exastencils.util.ir.IR_FragmentLoopCollector

// FIXME: refactor
object IR_LoopOverDimensions {
@@ -148,20 +149,8 @@ case class IR_LoopOverDimensions(
}
}

def parallelizationIsReasonable : Boolean = {
val maxItCount = maxIterationCount()
if (maxItCount == null)
return true // cannot determine iteration count, default is no change in parallelizability, i.e. true

var totalNumPoints : Long = 1
for (i <- maxItCount)
totalNumPoints *= i
totalNumPoints > Knowledge.omp_minWorkItemsPerThread * Knowledge.omp_numThreads
}

def explParLoop = lcCSEApplied && parallelization.potentiallyParallel &&
Knowledge.omp_enabled && Knowledge.omp_parallelizeLoopOverDimensions &&
parallelizationIsReasonable
Knowledge.omp_enabled && parallelizationOverDimensionsIsReasonable(maxIterationCount())

def createOMPThreadsWrapper(body : ListBuffer[IR_Statement]) : ListBuffer[IR_Statement] = {
if (explParLoop) {
@@ -214,9 +203,9 @@
nju
}

def expandSpecial() : ListBuffer[IR_Statement] = {
def expandSpecial(collector : IR_FragmentLoopCollector) : ListBuffer[IR_Statement] = {
def parallelizable(d : Int) = parallelization.potentiallyParallel && parDims.contains(d)
def parallelize(d : Int) = parallelizable(d) && Knowledge.omp_parallelizeLoopOverDimensions && parallelizationIsReasonable
def parallelize(d : Int) = parallelizable(d) && parallelizationOverDimensionsIsReasonable(maxIterationCount())

// TODO: check interaction between at1stIt and condition (see also: TODO in polyhedron.Extractor.enterLoop)
var wrappedBody : ListBuffer[IR_Statement] = body
@@ -250,6 +239,16 @@
wrappedBody = ListBuffer[IR_Statement](loop)
}

// propagate parallelization hints to enclosing fragment loop if parallel
if (Knowledge.omp_parallelizeLoopOverFragments && collector.getEnclosingFragmentLoop().isDefined) {
collector.getEnclosingFragmentLoop().get match {
case fragLoop : IR_LoopOverFragments =>
fragLoop.parallelization.parallelizationReasonable &&= parallelizationOverFragmentsIsReasonable(maxIterationCount())
case fragLoop @ IR_ForLoop(IR_VariableDeclaration(_, name, _, _), _, _, _, _) if name == IR_LoopOverFragments.defIt.name =>
fragLoop.parallelization.parallelizationReasonable &&= parallelizationOverFragmentsIsReasonable(maxIterationCount())
}
}

wrappedBody = createOMPThreadsWrapper(wrappedBody)

wrappedBody
/// IR_ResolveLoopOverDimensions

object IR_ResolveLoopOverDimensions extends DefaultStrategy("Resolve LoopOverDimensions nodes") {
var collector = new IR_FragmentLoopCollector
this.register(collector)
this.onBefore = () => this.resetCollectors()

this += new Transformation("Resolve", {
case loop : IR_LoopOverDimensions => loop.expandSpecial()
case loop : IR_LoopOverDimensions => loop.expandSpecial(collector)
})
}
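
The reasonableness helpers referenced above (parallelizationOverDimensionsIsReasonable, parallelizationOverFragmentsIsReasonable) are not part of this diff. Judging from the removed parallelizationIsReasonable, they presumably apply a work-per-thread threshold along the lines of this self-contained sketch, where plain parameters stand in for the Knowledge.omp_* options:

object ParallelizationHeuristic {
  // Sketch of the threshold check the removed parallelizationIsReasonable performed;
  // the renamed helpers are assumed to generalize it for dimension and fragment loops.
  def parallelizationIsReasonable(maxItCount : Array[Long],
                                  minWorkItemsPerThread : Long,
                                  numThreads : Long) : Boolean = {
    if (maxItCount == null)
      return true // iteration count unknown -> assume parallelization remains reasonable
    val totalNumPoints : Long = maxItCount.product
    totalNumPoints > minWorkItemsPerThread * numThreads
  }
}

IR_LoopOverFragments below feeds Array(Knowledge.domain_numFragmentsPerBlock) into the fragment variant of this check when its body contains no statement carrying its own parallelization info.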
@@ -45,6 +45,10 @@ case class IR_LoopOverFragments(
// TODO: separate omp and potentiallyParallel
parallelization.potentiallyParallel = Knowledge.omp_enabled && Knowledge.omp_parallelizeLoopOverFragments && parallelization.potentiallyParallel

// if there is no loop found, we determine here if omp parallelization is reasonable
if (Knowledge.omp_enabled && Knowledge.omp_parallelizeLoopOverFragments & !body.exists(_.isInstanceOf[IR_HasParallelizationInfo]))
parallelization.potentiallyParallel &&= parallelizationOverFragmentsIsReasonable(Array(Knowledge.domain_numFragmentsPerBlock))

val loop = IR_ForLoop(
IR_VariableDeclaration(defIt, 0),
IR_Lower(defIt, Knowledge.domain_numFragmentsPerBlock),
50 changes: 44 additions & 6 deletions Compiler/src/exastencils/parallelization/api/cuda/CUDA_Flags.scala
@@ -23,9 +23,33 @@ import exastencils.base.ir._
import exastencils.baseExt.ir._
import exastencils.communication.ir._
import exastencils.config._
import exastencils.domain.ir.IR_IV_IsValidForDomain
import exastencils.field.ir._
import exastencils.prettyprinting._

/// CUDA_DirtyFlagHelper

object CUDA_DirtyFlagHelper {
def fragmentIdxIsValid(fragIdx : IR_Expression, domainIdx : IR_Expression) = {
if (fragIdx != IR_LoopOverFragments.defIt)
fragIdx >= 0 AndAnd fragIdx < Knowledge.domain_numFragmentsPerBlock AndAnd IR_IV_IsValidForDomain(domainIdx, fragIdx)
else
IR_IV_IsValidForDomain(domainIdx, fragIdx)
}
}

/// CUDA_DirtyFlagCase

object CUDA_DirtyFlagCase extends Enumeration {
type Access = Value
final val ANNOT : String = "DirtyFlagCase"

// CLEAR : field/buffer was not updated -> no transfer needed
// INTERMEDIATE: field/buffer was updated -> possibly need to wait for event before setting to DIRTY
// DIRTY : field/buffer was updated -> transfer needed if execution hardware changes
final val CLEAR, INTERMEDIATE, DIRTY = Value
}

/// CUDA_HostDataUpdated

// TODO: move to communication package?
@@ -35,7 +59,14 @@ case class CUDA_HostDataUpdated(override var field : IR_Field, override var slot
override def usesFieldArrays : Boolean = !Knowledge.data_useFieldNamesAsIdx

override def resolveName() = s"hostDataUpdated" + resolvePostfix(fragmentIdx.prettyprint, "", if (Knowledge.data_useFieldNamesAsIdx) field.name else field.index.toString, field.level.toString, "")
override def resolveDefValue() = Some(true)
override def resolveDefValue() = Some(CUDA_DirtyFlagCase.DIRTY.id)

override def resolveDatatype() = {
if (field.numSlots > 1)
IR_ArrayDatatype(IR_IntegerDatatype, field.numSlots)
else
IR_IntegerDatatype
}
}

/// CUDA_DeviceDataUpdated
@@ -47,7 +78,14 @@ case class CUDA_DeviceDataUpdated(override var field : IR_Field, override var sl
override def usesFieldArrays : Boolean = !Knowledge.data_useFieldNamesAsIdx

override def resolveName() = s"deviceDataUpdated" + resolvePostfix(fragmentIdx.prettyprint, "", if (Knowledge.data_useFieldNamesAsIdx) field.name else field.index.toString, field.level.toString, "")
override def resolveDefValue() = Some(false)
override def resolveDefValue() = Some(CUDA_DirtyFlagCase.CLEAR.id)

override def resolveDatatype() = {
if (field.numSlots > 1)
IR_ArrayDatatype(IR_IntegerDatatype, field.numSlots)
else
IR_IntegerDatatype
}
}

/// CUDA_HostBufferDataUpdated
@@ -57,8 +95,8 @@ case class CUDA_HostBufferDataUpdated(var field : IR_Field, var direction : Stri
override def prettyprint(out : PpStream) : Unit = out << resolveAccess(resolveName(), fragmentIdx, IR_NullExpression, field.index, field.level, neighIdx)

override def resolveName() = s"hostBufferDataUpdated_$direction" + resolvePostfix(fragmentIdx.prettyprint, "", field.index.toString, field.level.toString, neighIdx.prettyprint)
override def resolveDatatype() = IR_BooleanDatatype
override def resolveDefValue() = Some(false)
override def resolveDatatype() = IR_IntegerDatatype
override def resolveDefValue() = Some(CUDA_DirtyFlagCase.CLEAR.id)
}

/// CUDA_DeviceBufferDataUpdated
@@ -68,8 +106,8 @@ case class CUDA_DeviceBufferDataUpdated(var field : IR_Field, var direction : St
override def prettyprint(out : PpStream) : Unit = out << resolveAccess(resolveName(), fragmentIdx, IR_NullExpression, field.index, field.level, neighIdx)

override def resolveName() = s"deviceBufferDataUpdated_$direction" + resolvePostfix(fragmentIdx.prettyprint, "", field.index.toString, field.level.toString, neighIdx.prettyprint)
override def resolveDatatype() = IR_BooleanDatatype
override def resolveDefValue() = Some(false)
override def resolveDatatype() = IR_IntegerDatatype
override def resolveDefValue() = Some(CUDA_DirtyFlagCase.CLEAR.id)
}

/// CUDA_ExecutionMode
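
The dirty flags above change from booleans to integers so they can hold one of the three CUDA_DirtyFlagCase ids (and become per-slot arrays for multi-slot fields). The following self-contained sketch summarizes the lifecycle those cases describe; only the INTERMEDIATE-to-DIRTY promotion appears verbatim in this commit (in CUDA_HandleFragmentLoops below), the other two transitions are assumptions derived from the comments in CUDA_DirtyFlagCase:

object DirtyFlagLifecycle {
  // simplified stand-in for CUDA_DirtyFlagCase
  object DirtyFlagCase extends Enumeration {
    val CLEAR, INTERMEDIATE, DIRTY = Value
  }
  import DirtyFlagCase._

  // assumption: an update on host/device first marks the flag as INTERMEDIATE,
  // since a transfer may still have to wait for the associated event
  def afterUpdate(flag : Value) : Value = INTERMEDIATE

  // shown in CUDA_HandleFragmentLoops: promote to DIRTY only for valid fragments
  def afterFragmentLoop(flag : Value, fragmentIsValid : Boolean) : Value =
    if (fragmentIsValid && flag == INTERMEDIATE) DIRTY else flag

  // assumption: once the transfer has completed, host and device copies match again
  def afterTransfer(flag : Value) : Value = CLEAR
}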
@@ -221,56 +221,84 @@ case class CUDA_HandleFragmentLoops(

for (access <- fieldAccesses.toSeq.sortBy(_._1)) {
val fieldData = access._2
val transferStream = CUDA_TransferStream(fieldData.field, Duplicate(fieldData.fragmentIdx))
val field = fieldData.field
val fragIdx = fieldData.fragmentIdx
val domainIdx = field.domain.index
val transferStream = CUDA_TransferStream(field, Duplicate(fragIdx))

// add data sync statements
if (syncBeforeHost(access._1, fieldAccesses.keys))
beforeHost += CUDA_UpdateHostData(Duplicate(fieldData), transferStream).expand().inner // expand here to avoid global expand afterwards

// update flags for written fields
if (syncAfterHost(access._1, fieldAccesses.keys))
afterHost += IR_Assignment(CUDA_HostDataUpdated(fieldData.field, Duplicate(fieldData.slot), Duplicate(fieldData.fragmentIdx)), IR_BooleanConstant(true))
if (syncAfterHost(access._1, fieldAccesses.keys)) {
val dirtyFlag = CUDA_HostDataUpdated(field, Duplicate(fieldData.slot), Duplicate(fragIdx))
val isValid = CUDA_DirtyFlagHelper.fragmentIdxIsValid(fragIdx, domainIdx)
afterHost += IR_IfCondition(isValid AndAnd (dirtyFlag EqEq CUDA_DirtyFlagCase.INTERMEDIATE.id),
IR_Assignment(dirtyFlag, CUDA_DirtyFlagCase.DIRTY.id))
}
}

for (access <- bufferAccesses.toSeq.sortBy(_._1)) {
val buffer = access._2
val transferStream = CUDA_TransferStream(buffer.field, Duplicate(buffer.fragmentIdx))
val field = buffer.field
val fragIdx = buffer.fragmentIdx
val domainIdx = field.domain.index
val transferStream = CUDA_TransferStream(field, Duplicate(fragIdx))

// add buffer sync statements
if (syncBeforeHost(access._1, bufferAccesses.keys))
beforeHost += CUDA_UpdateHostBufferData(Duplicate(buffer), transferStream).expand().inner // expand here to avoid global expand afterwards

// update flags for written buffers
if (syncAfterHost(access._1, bufferAccesses.keys))
afterHost += IR_Assignment(CUDA_HostBufferDataUpdated(buffer.field, buffer.direction, Duplicate(buffer.neighIdx), Duplicate(buffer.fragmentIdx)), IR_BooleanConstant(true))
if (syncAfterHost(access._1, bufferAccesses.keys)) {
val dirtyFlag = CUDA_HostBufferDataUpdated(field, buffer.direction, Duplicate(buffer.neighIdx), Duplicate(fragIdx))
val isValid = CUDA_DirtyFlagHelper.fragmentIdxIsValid(fragIdx, domainIdx)
afterHost += IR_IfCondition(isValid AndAnd (dirtyFlag EqEq CUDA_DirtyFlagCase.INTERMEDIATE.id),
IR_Assignment(dirtyFlag, CUDA_DirtyFlagCase.DIRTY.id))
}
}

// - device sync stmts -

if (isParallel) {
for (access <- fieldAccesses.toSeq.sortBy(_._1)) {
val fieldData = access._2
val transferStream = CUDA_TransferStream(fieldData.field, Duplicate(fieldData.fragmentIdx))
val field = fieldData.field
val fragIdx = fieldData.fragmentIdx
val domainIdx = field.domain.index
val transferStream = CUDA_TransferStream(field, Duplicate(fragIdx))

// add data sync statements
if (syncBeforeDevice(access._1, fieldAccesses.keys))
beforeDevice += CUDA_UpdateDeviceData(Duplicate(fieldData), transferStream).expand().inner // expand here to avoid global expand afterwards

// update flags for written fields
if (syncAfterDevice(access._1, fieldAccesses.keys))
afterDevice += IR_Assignment(CUDA_DeviceDataUpdated(fieldData.field, Duplicate(fieldData.slot), Duplicate(fieldData.fragmentIdx)), IR_BooleanConstant(true))
if (syncAfterDevice(access._1, fieldAccesses.keys)) {
val dirtyFlag = CUDA_DeviceDataUpdated(field, Duplicate(fieldData.slot), Duplicate(fragIdx))
val isValid = CUDA_DirtyFlagHelper.fragmentIdxIsValid(fragIdx, domainIdx)
afterDevice += IR_IfCondition(isValid AndAnd dirtyFlag EqEq CUDA_DirtyFlagCase.INTERMEDIATE.id,
IR_Assignment(dirtyFlag, CUDA_DirtyFlagCase.DIRTY.id))
}
}
for (access <- bufferAccesses.toSeq.sortBy(_._1)) {
val buffer = access._2
val transferStream = CUDA_TransferStream(buffer.field, Duplicate(buffer.fragmentIdx))
val field = buffer.field
val fragIdx = buffer.fragmentIdx
val domainIdx = field.domain.index
val transferStream = CUDA_TransferStream(field, Duplicate(fragIdx))

// add data sync statements
if (syncBeforeDevice(access._1, bufferAccesses.keys))
beforeDevice += CUDA_UpdateDeviceBufferData(Duplicate(buffer), transferStream).expand().inner // expand here to avoid global expand afterwards

// update flags for written fields
if (syncAfterDevice(access._1, bufferAccesses.keys))
afterDevice += IR_Assignment(CUDA_DeviceBufferDataUpdated(buffer.field, buffer.direction, Duplicate(buffer.neighIdx), Duplicate(buffer.fragmentIdx)), IR_BooleanConstant(true))
if (syncAfterDevice(access._1, bufferAccesses.keys)) {
val dirtyFlag = CUDA_DeviceBufferDataUpdated(field, buffer.direction, Duplicate(buffer.neighIdx), Duplicate(fragIdx))
val isValid = CUDA_DirtyFlagHelper.fragmentIdxIsValid(fragIdx, domainIdx)
afterDevice += IR_IfCondition(isValid AndAnd (dirtyFlag EqEq CUDA_DirtyFlagCase.INTERMEDIATE.id),
IR_Assignment(dirtyFlag, CUDA_DirtyFlagCase.DIRTY.id))
}
}
}

@@ -305,8 +333,11 @@ case class CUDA_HandleFragmentLoops(
val redTarget = Duplicate(red.target)

// move reduction towards "synchroFragLoop"
// -> OpenMP/MPI reduction occurs after accumulation in "synchroFragLoop"
loop.parallelization.reduction = None
// -> MPI reduction occurs after accumulation in "synchroFragLoop" (i.e. skip in enclosing frag loop and its inner dimension loop)
// -> OMP reduction occurs only for parallelization over IR_LoopOverDimensions, otherwise skipped as MPI reduction
loop.parallelization.reduction.get.skipMpi = true
loop.parallelization.reduction.get.skipOpenMP = !Knowledge.omp_parallelizeLoopOverDimensions

syncAfterFragLoop.parallelization.reduction = Some(red)

// force comp stream sync if comp kernels are not synced explicitly
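
The skipMpi and skipOpenMP flags added to IR_Reduction at the top of this commit are set in the hunk above: the MPI reduction is always deferred until after the per-fragment accumulation in syncAfterFragLoop, and the OpenMP reduction stays attached only when the dimension loops themselves are OpenMP-parallel. A condensed sketch of how a backend can gate the emitted reduction layers on these flags (the simplified Reduction model and emitReductionLayers are hypothetical and not ExaStencils API):

object EmitReductionLayers {
  // simplified stand-in for IR_Reduction with the two new skip flags
  case class Reduction(op : String, targetName : String,
                       skipMpi : Boolean = false, skipOpenMP : Boolean = false)

  // returns short descriptions of the reduction layers that would be generated
  def emitReductionLayers(red : Reduction, ompEnabled : Boolean, mpiEnabled : Boolean) : Seq[String] = {
    val omp =
      if (ompEnabled && !red.skipOpenMP)
        Seq(s"OpenMP reduction(${red.op} : ${red.targetName}) clause on the parallel loop")
      else Seq.empty
    val mpi =
      if (mpiEnabled && !red.skipMpi)
        Seq(s"MPI all-reduce of ${red.targetName} with op ${red.op}")
      else Seq.empty
    omp ++ mpi
  }
}

This mirrors the intent of the comments above: accumulate per fragment first, run the global reduction afterwards.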
(diffs for the remaining changed files are not shown here)
