Skip to content

Commit

Permalink
Refactor handling of rechecked types
Browse files Browse the repository at this point in the history
 - Always store new types on rechecking
 - Store them in a hashmap which is associated with the rechecker of the
   current compilation unit
 - After rechecking is done, the map is forgotten, unless keepTypes is true.
   Under keepTypes, then map is kept in an attachment of the unit's root tree.

Change in nomenclature:

    knownType --> nuType
    rememberType --> setNuType
    hasRememberedType --> hasNuType
  • Loading branch information
odersky committed Dec 17, 2024
1 parent 47f7d14 commit 4285536
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 89 deletions.
46 changes: 29 additions & 17 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,22 @@ object CheckCaptures:
checkNotUniversal.traverse(tpe.widen)
end checkNotUniversalInUnboxedResult

trait CheckerAPI:
/** Complete symbol info of a val or a def */
def completeDef(tree: ValOrDefDef, sym: Symbol)(using Context): Type

extension [T <: Tree](tree: T)

/** Set new type of the tree if none was installed yet. */
def setNuType(tpe: Type): Unit

/** The new type of the tree, or if none was installed, the original type */
def nuType(using Context): Type

/** Was a new type installed for this tree? */
def hasNuType: Boolean
end CheckerAPI

class CheckCaptures extends Recheck, SymTransformer:
thisPhase =>

Expand All @@ -243,7 +259,7 @@ class CheckCaptures extends Recheck, SymTransformer:

val ccState1 = new CCState // Dotty problem: Rename to ccState ==> Crash in ExplicitOuter

class CaptureChecker(ictx: Context) extends Rechecker(ictx):
class CaptureChecker(ictx: Context) extends Rechecker(ictx), CheckerAPI:

/** The current environment */
private val rootEnv: Env = inContext(ictx):
Expand All @@ -261,10 +277,6 @@ class CheckCaptures extends Recheck, SymTransformer:
*/
private val todoAtPostCheck = new mutable.ListBuffer[() => Unit]

override def keepType(tree: Tree) =
super.keepType(tree)
|| tree.isInstanceOf[Try] // type of `try` needs tp be checked for * escapes

/** Instantiate capture set variables appearing contra-variantly to their
* upper approximation.
*/
Expand All @@ -286,8 +298,8 @@ class CheckCaptures extends Recheck, SymTransformer:
*/
private def interpolateVarsIn(tpt: Tree)(using Context): Unit =
if tpt.isInstanceOf[InferredTypeTree] then
interpolator().traverse(tpt.knownType)
.showing(i"solved vars in ${tpt.knownType}", capt)
interpolator().traverse(tpt.nuType)
.showing(i"solved vars in ${tpt.nuType}", capt)
for msg <- ccState.approxWarnings do
report.warning(msg, tpt.srcPos)
ccState.approxWarnings.clear()
Expand Down Expand Up @@ -501,11 +513,11 @@ class CheckCaptures extends Recheck, SymTransformer:
then ("\nThis is often caused by a local capability$where\nleaking as part of its result.", fn.srcPos)
else if arg.span.exists then ("", arg.srcPos)
else ("", fn.srcPos)
disallowRootCapabilitiesIn(arg.knownType, NoSymbol,
disallowRootCapabilitiesIn(arg.nuType, NoSymbol,
i"Type variable $pname of $sym", "be instantiated to", addendum, pos)

val param = fn.symbol.paramNamed(pname)
if param.isUseParam then markFree(arg.knownType.deepCaptureSet, pos)
if param.isUseParam then markFree(arg.nuType.deepCaptureSet, pos)
end disallowCapInTypeArgs

override def recheckIdent(tree: Ident, pt: Type)(using Context): Type =
Expand Down Expand Up @@ -769,8 +781,8 @@ class CheckCaptures extends Recheck, SymTransformer:
*/
def checkContains(tree: TypeApply)(using Context): Unit = tree match
case ContainsImpl(csArg, refArg) =>
val cs = csArg.knownType.captureSet
val ref = refArg.knownType
val cs = csArg.nuType.captureSet
val ref = refArg.nuType
capt.println(i"check contains $cs , $ref")
ref match
case ref: CaptureRef if ref.isTracked =>
Expand Down Expand Up @@ -852,7 +864,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case _ =>
(sym, "")
disallowRootCapabilitiesIn(
tree.tpt.knownType, carrier, i"Mutable $sym", "have type", addendum, sym.srcPos)
tree.tpt.nuType, carrier, i"Mutable $sym", "have type", addendum, sym.srcPos)
checkInferredResult(super.recheckValDef(tree, sym), tree)
finally
if !sym.is(Param) then
Expand Down Expand Up @@ -1533,7 +1545,7 @@ class CheckCaptures extends Recheck, SymTransformer:
private val setup: SetupAPI = thisPhase.prev.asInstanceOf[Setup]

override def checkUnit(unit: CompilationUnit)(using Context): Unit =
setup.setupUnit(unit.tpdTree, completeDef)
setup.setupUnit(unit.tpdTree, this)
collectCapturedMutVars.traverse(unit.tpdTree)

if ctx.settings.YccPrintSetup.value then
Expand Down Expand Up @@ -1676,7 +1688,7 @@ class CheckCaptures extends Recheck, SymTransformer:
traverseChildren(tp)

if tree.isInstanceOf[InferredTypeTree] then
checker.traverse(tree.knownType)
checker.traverse(tree.nuType)
end healTypeParam

/** Under the unsealed policy: Arrays are like vars, check that their element types
Expand Down Expand Up @@ -1716,10 +1728,10 @@ class CheckCaptures extends Recheck, SymTransformer:
check(tree)
def check(tree: Tree)(using Context) = tree match
case TypeApply(fun, args) =>
fun.knownType.widen match
fun.nuType.widen match
case tl: PolyType =>
val normArgs = args.lazyZip(tl.paramInfos).map: (arg, bounds) =>
arg.withType(arg.knownType.forceBoxStatus(
arg.withType(arg.nuType.forceBoxStatus(
bounds.hi.isBoxedCapturing | bounds.lo.isBoxedCapturing))
checkBounds(normArgs, tl)
args.lazyZip(tl.paramNames).foreach(healTypeParam(_, _, fun.symbol))
Expand All @@ -1739,7 +1751,7 @@ class CheckCaptures extends Recheck, SymTransformer:
def traverse(t: Tree)(using Context) = t match
case tree: InferredTypeTree =>
case tree: New =>
case tree: TypeTree => checkAppliedTypesIn(tree.withKnownType)
case tree: TypeTree => checkAppliedTypesIn(tree.withType(tree.nuType))
case _ => traverseChildren(t)
checkApplied.traverse(unit)
end postCheck
Expand Down
47 changes: 24 additions & 23 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import printing.{Printer, Texts}, Texts.{Text, Str}
import collection.mutable
import CCState.*
import dotty.tools.dotc.util.NoSourcePosition
import CheckCaptures.CheckerAPI

/** Operations accessed from CheckCaptures */
trait SetupAPI:
Expand All @@ -28,10 +29,9 @@ trait SetupAPI:

/** Setup procedure to run for each compilation unit
* @param tree the typed tree of the unit to check
* @param recheckDef the recheck method to run on completion of symbols with
* inferred (result-) types
* @param checker the capture checker which will run subsequently.
*/
def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit
def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit

/** Symbol is a term member of a class that was not capture checked
* The info of these symbols is made fluid.
Expand Down Expand Up @@ -378,15 +378,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
tp2
end transformExplicitType

/** Transform type of tree, and remember the transformed type as the type the tree */
private def transformTT(tree: TypeTree, boxed: Boolean)(using Context): Unit =
if !tree.hasRememberedType then
val transformed =
if tree.isInferred
then transformInferredType(tree.tpe)
else transformExplicitType(tree.tpe, tptToCheck = tree)
tree.rememberType(if boxed then box(transformed) else transformed)

/** Substitute parameter symbols in `from` to paramRefs in corresponding
* method or poly types `to`. We use a single BiTypeMap to do everything.
* @param from a list of lists of type or term parameter symbols of a curried method
Expand Down Expand Up @@ -436,7 +427,17 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
atPhase(thisPhase.next)(sym.info)

/** A traverser that adds knownTypes and updates symbol infos */
def setupTraverser(recheckDef: DefRecheck) = new TreeTraverserWithPreciseImportContexts:
def setupTraverser(checker: CheckerAPI) = new TreeTraverserWithPreciseImportContexts:
import checker.*

/** Transform type of tree, and remember the transformed type as the type the tree */
private def transformTT(tree: TypeTree, boxed: Boolean)(using Context): Unit =
if !tree.hasNuType then
val transformed =
if tree.isInferred
then transformInferredType(tree.tpe)
else transformExplicitType(tree.tpe, tptToCheck = tree)
tree.setNuType(if boxed then box(transformed) else transformed)

/** Transform the type of a val or var or the result type of a def */
def transformResultType(tpt: TypeTree, sym: Symbol)(using Context): Unit =
Expand Down Expand Up @@ -464,7 +465,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
traverse(parent)
case _ =>
traverseChildren(tp)
addDescription.traverse(tpt.knownType)
addDescription.traverse(tpt.nuType)
end transformResultType

def traverse(tree: Tree)(using Context): Unit =
Expand Down Expand Up @@ -504,7 +505,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:

case tree @ SeqLiteral(elems, tpt: TypeTree) =>
traverse(elems)
tpt.rememberType(box(transformInferredType(tpt.tpe)))
tpt.setNuType(box(transformInferredType(tpt.tpe)))

case tree: Block =>
inNestedLevel(traverseChildren(tree))
Expand Down Expand Up @@ -537,22 +538,22 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
// with special treatment for constructors.
def localReturnType =
if sym.isConstructor then constrReturnType(sym.info, sym.paramSymss)
else tree.tpt.knownType
else tree.tpt.nuType

// A test whether parameter signature might change. This returns true if one of
// the parameters has a remembered type. The idea here is that we store a remembered
// the parameters has a new type installee. The idea here is that we store a new
// type only if the transformed type is different from the original.
def paramSignatureChanges = tree.match
case tree: DefDef =>
tree.paramss.nestedExists:
case param: ValDef => param.tpt.hasRememberedType
case param: TypeDef => param.rhs.hasRememberedType
case param: ValDef => param.tpt.hasNuType
case param: TypeDef => param.rhs.hasNuType
case _ => false

// A symbol's signature changes if some of its parameter types or its result type
// have a new type installed here (meaning hasRememberedType is true)
def signatureChanges =
tree.tpt.hasRememberedType && !sym.isConstructor || paramSignatureChanges
tree.tpt.hasNuType && !sym.isConstructor || paramSignatureChanges

// Replace an existing symbol info with inferred types where capture sets of
// TypeParamRefs and TermParamRefs are put in correspondence by BiTypeMaps with the
Expand Down Expand Up @@ -616,7 +617,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
capt.println(i"forcing $sym, printing = ${ctx.mode.is(Mode.Printing)}")
//if ctx.mode.is(Mode.Printing) then new Error().printStackTrace()
denot.info = newInfo
recheckDef(tree, sym)
completeDef(tree, sym)
updateInfo(sym, updatedInfo)

case tree: Bind =>
Expand Down Expand Up @@ -833,8 +834,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
/** Run setup on a compilation unit with given `tree`.
* @param recheckDef the function to run for completing a val or def
*/
def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit =
setupTraverser(recheckDef).traverse(tree)(using ctx.withPhase(thisPhase))
def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit =
setupTraverser(checker).traverse(tree)(using ctx.withPhase(thisPhase))

// ------ Checks to run after main capture checking --------------------------

Expand Down
Loading

0 comments on commit 4285536

Please sign in to comment.