Skip to content

Commit

Permalink
Merge pull request #203 from retronym/topic/rt
Browse files Browse the repository at this point in the history
Deterministic output from the async macro
  • Loading branch information
adriaanm authored Nov 15, 2018
2 parents b662af6 + 2c4ac2f commit 7857e41
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 26 deletions.
46 changes: 46 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,49 @@ pomExtra := (
</developers>
)
OsgiKeys.exportPackage := Seq(s"scala.async.*;version=${version.value}")

commands += testDeterminism

def testDeterminism = Command.command("testDeterminism") { state =>
val extracted = Project.extract(state)
println("Running test:clean")
val (state1, _) = extracted.runTask(clean in Test in LocalRootProject, state)
println("Running test:compile")
val (state2, _) = extracted.runTask(compile in Test in LocalRootProject, state1)
val testClasses = extracted.get(classDirectory in Test)
val baseline: File = testClasses.getParentFile / (testClasses.getName + "-baseline")
baseline.mkdirs()
IO.copyDirectory(testClasses, baseline, overwrite = true)
IO.delete(testClasses)
println("Running test:compile")
val (state3, _) = extracted.runTask(compile in Test in LocalRootProject, state2)

import java.nio.file.FileVisitResult
import java.nio.file.{Files, Path}
import java.nio.file.SimpleFileVisitor
import java.nio.file.attribute.BasicFileAttributes
import java.util

def checkSameFileContents(one: Path, other: Path): Unit = {
Files.walkFileTree(one, new SimpleFileVisitor[Path]() {
override def visitFile(file: Path, attrs: BasicFileAttributes): FileVisitResult = {
val result: FileVisitResult = super.visitFile(file, attrs)
// get the relative file name from path "one"
val relativize: Path = one.relativize(file)
// construct the path for the counterpart file in "other"
val fileInOther: Path = other.resolve(relativize)
val otherBytes: Array[Byte] = Files.readAllBytes(fileInOther)
val thisBytes: Array[Byte] = Files.readAllBytes(file)
if (!(util.Arrays.equals(otherBytes, thisBytes))) {
throw new AssertionError(file + " is not equal to " + fileInOther)
}
return result
}
})
}
println("Comparing: " + baseline.toPath + " and " + testClasses.toPath)
checkSameFileContents(baseline.toPath, testClasses.toPath)
checkSameFileContents(testClasses.toPath, baseline.toPath)

state3
}
35 changes: 18 additions & 17 deletions src/main/scala/scala/async/internal/Lifter.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package scala.async.internal

import scala.collection.mutable

trait Lifter {
self: AsyncMacro =>
import c.universe._
Expand Down Expand Up @@ -37,7 +39,7 @@ trait Lifter {
}


val defs: Map[Tree, Int] = {
val defs: mutable.LinkedHashMap[Tree, Int] = {
/** Collect the DefTrees directly enclosed within `t` that have the same owner */
def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match {
case ld: LabelDef => Nil
Expand All @@ -48,33 +50,33 @@ trait Lifter {
companionship.record(childDefs)
childDefs
}
asyncStates.flatMap {
mutable.LinkedHashMap(asyncStates.flatMap {
asyncState =>
val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*))
defs.map((_, asyncState.state))
}.toMap
}: _*)
}

// In which block are these symbols defined?
val symToDefiningState: Map[Symbol, Int] = defs.map {
val symToDefiningState: mutable.LinkedHashMap[Symbol, Int] = defs.map {
case (k, v) => (k.symbol, v)
}

// The definitions trees
val symToTree: Map[Symbol, Tree] = defs.map {
val symToTree: mutable.LinkedHashMap[Symbol, Tree] = defs.map {
case (k, v) => (k.symbol, k)
}

// The direct references of each definition tree
val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map {
case tree => (tree.symbol, tree.collect {
val defSymToReferenced: mutable.LinkedHashMap[Symbol, List[Symbol]] = defs.map {
case (tree, _) => (tree.symbol, tree.collect {
case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol
})
}.toMap
}

// The direct references of each block, excluding references of `DefTree`-s which
// are already accounted for.
val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = {
val stateIdToDirectlyReferenced: mutable.LinkedHashMap[Int, List[Symbol]] = {
val refs: List[(Int, Symbol)] = asyncStates.flatMap(
asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).flatMap(_.collect {
case rt: RefTree
Expand All @@ -84,8 +86,8 @@ trait Lifter {
toMultiMap(refs)
}

def liftableSyms: Set[Symbol] = {
val liftableMutableSet = collection.mutable.Set[Symbol]()
def liftableSyms: mutable.LinkedHashSet[Symbol] = {
val liftableMutableSet = mutable.LinkedHashSet[Symbol]()
def markForLift(sym: Symbol): Unit = {
if (!liftableMutableSet(sym)) {
liftableMutableSet += sym
Expand All @@ -97,19 +99,19 @@ trait Lifter {
}
}
// Start things with DefTrees directly referenced from statements from other states...
val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap {
val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.iterator.flatMap {
case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i)
}
}.toList
// .. and likewise for DefTrees directly referenced by other DefTrees from other states
val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap {
case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee))
}
// Mark these for lifting, which will follow transitive references.
(liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift)
liftableMutableSet.toSet
liftableMutableSet
}

val lifted = liftableSyms.map(symToTree).toList.map {
liftableSyms.iterator.map(symToTree).map {
t =>
val sym = t.symbol
val treeLifted = t match {
Expand Down Expand Up @@ -147,7 +149,6 @@ trait Lifter {
treeCopy.TypeDef(td, Modifiers(sym.flags), sym.name, tparams, rhs)
}
atPos(t.pos)(treeLifted)
}
lifted
}.toList
}
}
16 changes: 9 additions & 7 deletions src/main/scala/scala/async/internal/LiveVariables.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package scala.async.internal

import scala.collection.mutable

import java.util
import java.util.function.{IntConsumer, IntPredicate}

Expand All @@ -19,12 +21,12 @@ trait LiveVariables {
* @return a map mapping a state to the fields that should be nulled out
* upon resuming that state
*/
def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = {
def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): mutable.LinkedHashMap[Int, List[Tree]] = {
// live variables analysis:
// the result map indicates in which states a given field should be nulled out
val liveVarsMap: Map[Tree, StateSet] = liveVars(asyncStates, liftables)
val liveVarsMap: mutable.LinkedHashMap[Tree, StateSet] = liveVars(asyncStates, liftables)

var assignsOf = Map[Int, List[Tree]]()
var assignsOf = mutable.LinkedHashMap[Int, List[Tree]]()

for ((fld, where) <- liveVarsMap) {
where.foreach { new IntConsumer { def accept(state: Int): Unit = {
Expand Down Expand Up @@ -54,7 +56,7 @@ trait LiveVariables {
* @param liftables the lifted fields
* @return a map which indicates for a given field (the key) the states in which it should be nulled out
*/
def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, StateSet] = {
def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): mutable.LinkedHashMap[Tree, StateSet] = {
val liftedSyms: Set[Symbol] = // include only vars
liftables.iterator.filter {
case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE)
Expand Down Expand Up @@ -262,15 +264,15 @@ trait LiveVariables {
result
}

val lastUsages: Map[Tree, StateSet] =
liftables.iterator.map(fld => fld -> lastUsagesOf(fld, finalState)).toMap
val lastUsages: mutable.LinkedHashMap[Tree, StateSet] =
mutable.LinkedHashMap(liftables.map(fld => fld -> lastUsagesOf(fld, finalState)): _*)

if(AsyncUtils.verbose) {
for ((fld, lastStates) <- lastUsages)
AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.iterator.mkString(", ")}")
}

val nullOutAt: Map[Tree, StateSet] =
val nullOutAt: mutable.LinkedHashMap[Tree, StateSet] =
for ((fld, lastStates) <- lastUsages) yield {
var result = new StateSet
lastStates.foreach(new IntConsumer { def accept(s: Int): Unit = {
Expand Down
13 changes: 11 additions & 2 deletions src/main/scala/scala/async/internal/TransformUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package scala.async.internal
import scala.reflect.macros.Context
import reflect.ClassTag
import scala.collection.immutable.ListMap
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/**
* Utilities used in both `ExprBuilder` and `AnfTransform`.
Expand Down Expand Up @@ -303,8 +305,15 @@ private[async] trait TransformUtils {
})
}

def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] =
as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap
def toMultiMap[A, B](abs: Iterable[(A, B)]): mutable.LinkedHashMap[A, List[B]] = {
// LinkedHashMap for stable order of results.
val result = new mutable.LinkedHashMap[A, ListBuffer[B]]()
for ((a, b) <- abs) {
val buffer = result.getOrElseUpdate(a, new ListBuffer[B])
buffer += b
}
result.map { case (a, b) => (a, b.toList) }
}

// Attributed version of `TreeGen#mkCastPreservingAnnotations`
def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = {
Expand Down

0 comments on commit 7857e41

Please sign in to comment.