Skip to content

Commit

Permalink
Merge pull request #141 from retronym/ticket/await-extractor
Browse files Browse the repository at this point in the history
Enable a compiler plugin to use the async transform after patmat
  • Loading branch information
retronym committed Sep 24, 2015
2 parents 93f207f + 168e10c commit 7263aaa
Show file tree
Hide file tree
Showing 11 changed files with 459 additions and 55 deletions.
81 changes: 70 additions & 11 deletions src/main/scala/scala/async/internal/AnfTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ private[async] trait AnfTransform {
import c.internal._
import decorators._

def anfTransform(tree: Tree): Block = {
def anfTransform(tree: Tree, owner: Symbol): Block = {
// Must prepend the () for issue #31.
val block = c.typecheck(atPos(tree.pos)(Block(List(Literal(Constant(()))), tree))).setType(tree.tpe)
val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe)

sealed abstract class AnfMode
case object Anf extends AnfMode
case object Linearizing extends AnfMode

val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner)

var mode: AnfMode = Anf
typingTransform(block)((tree, api) => {
typingTransform(tree1, owner)((tree, api) => {
def blockToList(tree: Tree): List[Tree] = tree match {
case Block(stats, expr) => stats :+ expr
case t => t :: Nil
Expand All @@ -34,7 +36,7 @@ private[async] trait AnfTransform {
def listToBlock(trees: List[Tree]): Block = trees match {
case trees @ (init :+ last) =>
val pos = trees.map(_.pos).reduceLeft(_ union _)
Block(init, last).setType(last.tpe).setPos(pos)
newBlock(init, last).setType(last.tpe).setPos(pos)
}

object linearize {
Expand Down Expand Up @@ -66,6 +68,17 @@ private[async] trait AnfTransform {
stats :+ valDef :+ atPos(tree.pos)(ref1)

case If(cond, thenp, elsep) =>
// If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}`
// as though it was typed with `Unit`.
def isPatMatGeneratedJump(t: Tree): Boolean = t match {
case Block(_, expr) => isPatMatGeneratedJump(expr)
case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep)
case _: Apply if isLabel(t.symbol) => true
case _ => false
}
if (isPatMatGeneratedJump(expr)) {
internal.setType(expr, definitions.UnitTpe)
}
// if type of if-else is Unit don't introduce assignment,
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
Expand All @@ -77,7 +90,7 @@ private[async] trait AnfTransform {
def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) {
def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol))
orig match {
case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
case _ => Assign(Ident(varDef.symbol), cast(orig))
}
})
Expand Down Expand Up @@ -115,7 +128,7 @@ private[async] trait AnfTransform {
}
}

private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos)
}
Expand Down Expand Up @@ -152,8 +165,7 @@ private[async] trait AnfTransform {
}

def _transformToList(tree: Tree): List[Tree] = trace(tree) {
val containsAwait = tree exists isAwait
if (!containsAwait) {
if (!containsAwait(tree)) {
tree match {
case Block(stats, expr) =>
// avoids nested block in `while(await(false)) ...`.
Expand Down Expand Up @@ -207,10 +219,11 @@ private[async] trait AnfTransform {
funStats ++ argStatss.flatten.flatten :+ typedNewApply

case Block(stats, expr) =>
(stats :+ expr).flatMap(linearize.transformToList)
val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr)
eliminateMatchEndLabelParameter(trees)

case ValDef(mods, name, tpt, rhs) =>
if (rhs exists isAwait) {
if (containsAwait(rhs)) {
val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs))
stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner))
stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr)
Expand Down Expand Up @@ -247,7 +260,7 @@ private[async] trait AnfTransform {
scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs)

case LabelDef(name, params, rhs) =>
List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))

case TypeApply(fun, targs) =>
val funStats :+ simpleFun = linearize.transformToList(fun)
Expand All @@ -259,6 +272,52 @@ private[async] trait AnfTransform {
}
}

// Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable
//
// CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts
// a parameter which is the result of the match (this is regular, so even Unit-typed matches have this).
//
// For our purposes, it is easier to:
// - extract a `matchRes` variable
// - rewrite the terminal label def to take no parameters, and instead read this temp variable
// - change jumps to the terminal label to an assignment and a no-arg label application
def eliminateMatchEndLabelParameter(statsExpr: List[Tree]): List[Tree] = {
import internal.{methodType, setInfo}
val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]()

val matchResults = collection.mutable.Buffer[Tree]()
val statsExpr0 = statsExpr.reverseMap {
case ld @ LabelDef(_, param :: Nil, body) =>
val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
matchResults += matchResult
caseDefToMatchResult(ld.symbol) = matchResult.symbol
val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil))
setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType))
ld2
case t =>
if (caseDefToMatchResult.isEmpty) t
else typingTransform(t)((tree, api) =>
tree match {
case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) =>
api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil))))
case Block(stats, expr) =>
api.default(tree) match {
case Block(stats, Block(stats1, expr)) =>
treeCopy.Block(tree, stats ::: stats1, expr)
case t => t
}
case _ =>
api.default(tree)
}
)
}
matchResults.toList match {
case Nil => statsExpr
case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr
}
}

def anfLinearize(tree: Tree): Block = {
val trees: List[Tree] = mode match {
case Anf => anf._transformToList(tree)
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/scala/async/internal/AsyncBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ abstract class AsyncBase {
(body: c.Expr[T])
(execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._, c.internal._, decorators._
val asyncMacro = AsyncMacro(c, self)
val asyncMacro = AsyncMacro(c, self)(body.tree)

val code = asyncMacro.asyncTransform[T](body.tree, execContext.tree)(c.weakTypeTag[T])
val code = asyncMacro.asyncTransform[T](execContext.tree)(c.weakTypeTag[T])
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}")

// Mark range positions for synthetic code as transparent to allow some wiggle room for overlapping ranges
Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/scala/async/internal/AsyncId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ object AsyncTestLV extends AsyncBase {
* A trivial implementation of [[FutureSystem]] that performs computations
* on the current thread. Useful for testing.
*/
class Box[A] {
var a: A = _
}
object IdentityFutureSystem extends FutureSystem {

class Prom[A] {
var a: A = _
}
type Prom[A] = Box[A]

type Fut[A] = A
type ExecContext = Unit
Expand All @@ -57,7 +57,7 @@ object IdentityFutureSystem extends FutureSystem {

def execContext: Expr[ExecContext] = c.Expr[Unit](Literal(Constant(())))

def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]]
def promType[A: WeakTypeTag]: Type = weakTypeOf[Box[A]]
def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]]
def execContextType: Type = weakTypeOf[Unit]

Expand Down
7 changes: 6 additions & 1 deletion src/main/scala/scala/async/internal/AsyncMacro.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package scala.async.internal

object AsyncMacro {
def apply(c0: reflect.macros.Context, base: AsyncBase): AsyncMacro { val c: c0.type } = {
def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = {
import language.reflectiveCalls
new AsyncMacro { self =>
val c: c0.type = c0
val body: c.Tree = body0
// This member is required by `AsyncTransform`:
val asyncBase: AsyncBase = base
// These members are required by `ExprBuilder`:
val futureSystem: FutureSystem = base.futureSystem
val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c)
val containsAwait: c.Tree => Boolean = containsAwaitCached(body0)
}
}
}
Expand All @@ -19,7 +21,10 @@ private[async] trait AsyncMacro
with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables {

val c: scala.reflect.macros.Context
val body: c.Tree
val containsAwait: c.Tree => Boolean

lazy val macroPos = c.macroApplication.pos.makeTransparent
def atMacroPos(t: c.Tree) = c.universe.atPos(macroPos)(t)

}
11 changes: 6 additions & 5 deletions src/main/scala/scala/async/internal/AsyncTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ trait AsyncTransform {

val asyncBase: AsyncBase

def asyncTransform[T](body: Tree, execContext: Tree)
def asyncTransform[T](execContext: Tree)
(resultType: WeakTypeTag[T]): Tree = {

// We annotate the type of the whole expression as `T @uncheckedBounds` so as not to introduce
Expand All @@ -22,7 +22,7 @@ trait AsyncTransform {
// Transform to A-normal form:
// - no await calls in qualifiers or arguments,
// - if/match only used in statement position.
val anfTree0: Block = anfTransform(body)
val anfTree0: Block = anfTransform(body, c.internal.enclosingOwner)

val anfTree = futureSystemOps.postAnfTransform(anfTree0)

Expand All @@ -35,15 +35,15 @@ trait AsyncTransform {
val stateMachine: ClassDef = {
val body: List[Tree] = {
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial)))
val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
val resultAndAccessors = mkMutableField(futureSystemOps.promType[T](uncheckedBoundsResultTag), name.result, futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)

val apply0DefDef: DefDef = {
// We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
// See SI-1247 for the the optimization that avoids creation.
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil))
}
List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
}

val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit])
Expand Down Expand Up @@ -98,10 +98,11 @@ trait AsyncTransform {
}

val isSimple = asyncBlock.asyncStates.size == 1
if (isSimple)
val result = if (isSimple)
futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }`
else
startStateMachine
cleanupContainsAwaitAttachments(result)
}

def logDiagnostics(anfTree: Tree, states: Seq[String]) {
Expand Down
Loading

0 comments on commit 7263aaa

Please sign in to comment.