Skip to content

Commit

Permalink
Handle TypeProxy of Named Tuples in unapply (#22325)
Browse files Browse the repository at this point in the history
Fixes #22150.
Previously, there were several ways to check if something was a Named
Tuple (`derivesFromNamedTuple`, `isNamedTupleType` and
`NamedTuple.unapply`), this PR moves everything into
`NamedTuple.unapply`. `namedTupleElementTypes` now takes an argument
`derived` that when false will skip `unapply` (to avoid infinite
recursion, used in desugaring and RefinedPrinter where trees can have
invalid cycles).
  • Loading branch information
odersky authored Jan 13, 2025
2 parents 49839cd + 83ae00d commit c10def4
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 26 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1744,7 +1744,7 @@ object desugar {
def adaptPatternArgs(elems: List[Tree], pt: Type)(using Context): List[Tree] =

def reorderedNamedArgs(wildcardSpan: Span): List[untpd.Tree] =
var selNames = pt.namedTupleElementTypes.map(_(0))
var selNames = pt.namedTupleElementTypes(false).map(_(0))
if selNames.isEmpty && pt.classSymbol.is(CaseClass) then
selNames = pt.classSymbol.caseAccessors.map(_.name.asTermName)
val nameToIdx = selNames.zipWithIndex.toMap
Expand Down
23 changes: 19 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1337,10 +1337,25 @@ class Definitions {
object NamedTuple:
def apply(nmes: Type, vals: Type)(using Context): Type =
AppliedType(NamedTupleTypeRef, nmes :: vals :: Nil)
def unapply(t: Type)(using Context): Option[(Type, Type)] = t match
case AppliedType(tycon, nmes :: vals :: Nil) if tycon.typeSymbol == NamedTupleTypeRef.symbol =>
Some((nmes, vals))
case _ => None
def unapply(t: Type)(using Context): Option[(Type, Type)] =
t match
case AppliedType(tycon, nmes :: vals :: Nil) if tycon.typeSymbol == NamedTupleTypeRef.symbol =>
Some((nmes, vals))
case tp: TypeProxy =>
val t = unapply(tp.superType); t
case tp: OrType =>
(unapply(tp.tp1), unapply(tp.tp2)) match
case (Some(lhsName, lhsVal), Some(rhsName, rhsVal)) if lhsName == rhsName =>
Some(lhsName, lhsVal | rhsVal)
case _ => None
case tp: AndType =>
(unapply(tp.tp1), unapply(tp.tp2)) match
case (Some(lhsName, lhsVal), Some(rhsName, rhsVal)) if lhsName == rhsName =>
Some(lhsName, lhsVal & rhsVal)
case (lhs, None) => lhs
case (None, rhs) => rhs
case _ => None
case _ => None

final def isCompiletime_S(sym: Symbol)(using Context): Boolean =
sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass
Expand Down
24 changes: 12 additions & 12 deletions compiler/src/dotty/tools/dotc/core/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,17 @@ class TypeUtils:
case Some(types) => TypeOps.nestedPairs(types)
case None => throw new AssertionError("not a tuple")

def namedTupleElementTypesUpTo(bound: Int, normalize: Boolean = true)(using Context): List[(TermName, Type)] =
def namedTupleElementTypesUpTo(bound: Int, derived: Boolean, normalize: Boolean = true)(using Context): List[(TermName, Type)] =
(if normalize then self.normalized else self).dealias match
// for desugaring and printer, ignore derived types to avoid infinite recursion in NamedTuple.unapply
case AppliedType(tycon, nmes :: vals :: Nil) if !derived && tycon.typeSymbol == defn.NamedTupleTypeRef.symbol =>
val names = nmes.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil).map(_.dealias).map:
case ConstantType(Constant(str: String)) => str.toTermName
case t => throw TypeError(em"Malformed NamedTuple: names must be string types, but $t was found.")
val values = vals.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil)
names.zip(values)
case t if !derived => Nil
// default cause, used for post-typing
case defn.NamedTuple(nmes, vals) =>
val names = nmes.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil).map(_.dealias).map:
case ConstantType(Constant(str: String)) => str.toTermName
Expand All @@ -138,22 +147,13 @@ class TypeUtils:
case t =>
Nil

def namedTupleElementTypes(using Context): List[(TermName, Type)] =
namedTupleElementTypesUpTo(Int.MaxValue)
def namedTupleElementTypes(derived: Boolean)(using Context): List[(TermName, Type)] =
namedTupleElementTypesUpTo(Int.MaxValue, derived)

def isNamedTupleType(using Context): Boolean = self match
case defn.NamedTuple(_, _) => true
case _ => false

def derivesFromNamedTuple(using Context): Boolean = self match
case defn.NamedTuple(_, _) => true
case tp: MatchType =>
tp.bound.derivesFromNamedTuple || tp.reduced.derivesFromNamedTuple
case tp: TypeProxy => tp.superType.derivesFromNamedTuple
case tp: AndType => tp.tp1.derivesFromNamedTuple || tp.tp2.derivesFromNamedTuple
case tp: OrType => tp.tp1.derivesFromNamedTuple && tp.tp2.derivesFromNamedTuple
case _ => false

/** Drop all named elements in tuple type */
def stripNamedTuple(using Context): Type = self.normalized.dealias match
case defn.NamedTuple(_, vals) =>
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/interactive/Completion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ object Completion:
def namedTupleCompletionsFromType(tpe: Type): CompletionMap =
val freshCtx = ctx.fresh.setExploreTyperState()
inContext(freshCtx):
tpe.namedTupleElementTypes
tpe.namedTupleElementTypes(true)
.map { (name, tpe) =>
val symbol = newSymbol(owner = NoSymbol, name, EmptyFlags, tpe)
val denot = SymDenotation(symbol, NoSymbol, name, EmptyFlags, tpe)
Expand All @@ -543,7 +543,7 @@ object Completion:
.groupByName

val qualTpe = qual.typeOpt
if qualTpe.derivesFromNamedTuple then
if qualTpe.isNamedTupleType then
namedTupleCompletionsFromType(qualTpe)
else if qualTpe.derivesFrom(defn.SelectableClass) then
val pre = if !TypeOps.isLegalPrefix(qualTpe) then Types.SkolemType(qualTpe) else qualTpe
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
def appliedText(tp: Type): Text = tp match
case tp @ AppliedType(tycon, args) =>
val namedElems =
try tp.namedTupleElementTypesUpTo(200, normalize = false)
catch case ex: TypeError => Nil
try tp.namedTupleElementTypesUpTo(200, false, normalize = false)
catch
case ex: TypeError => Nil
if namedElems.nonEmpty then
toTextNamedTuple(namedElems)
else tp.tupleElementTypesUpTo(200, normalize = false) match
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ object Applications {
}

def namedTupleOrProductTypes(tp: Type)(using Context): List[Type] =
if tp.isNamedTupleType then tp.namedTupleElementTypes.map(_(1))
if tp.isNamedTupleType then tp.namedTupleElementTypes(true).map(_(1))
else productSelectorTypes(tp, NoSourcePosition)

def productSelectorTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ trait Implicits:
|| inferView(dummyTreeOfType(from), to)
(using ctx.fresh.addMode(Mode.ImplicitExploration).setExploreTyperState()).isSuccess
// TODO: investigate why we can't TyperState#test here
|| from.widen.derivesFromNamedTuple && to.derivesFrom(defn.TupleClass)
|| from.widen.isNamedTupleType && to.derivesFrom(defn.TupleClass)
&& from.widen.stripNamedTuple <:< to
)

Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

// Otherwise, try to expand a named tuple selection
def tryNamedTupleSelection() =
val namedTupleElems = qual.tpe.widenDealias.namedTupleElementTypes
val namedTupleElems = qual.tpe.widenDealias.namedTupleElementTypes(true)
val nameIdx = namedTupleElems.indexWhere(_._1 == selName)
if nameIdx >= 0 && Feature.enabled(Feature.namedTuples) then
typed(
Expand Down Expand Up @@ -875,7 +875,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
then
val pre = if !TypeOps.isLegalPrefix(qual.tpe) then SkolemType(qual.tpe) else qual.tpe
val fieldsType = pre.select(tpnme.Fields).widenDealias.simplified
val fields = fieldsType.namedTupleElementTypes
val fields = fieldsType.namedTupleElementTypes(true)
typr.println(i"try dyn select $qual, $selName, $fields")
fields.find(_._1 == selName) match
case Some((_, fieldType)) =>
Expand Down Expand Up @@ -4663,7 +4663,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case _: SelectionProto =>
tree // adaptations for selections are handled in typedSelect
case _ if ctx.mode.is(Mode.ImplicitsEnabled) && tree.tpe.isValueType =>
if tree.tpe.derivesFromNamedTuple && pt.derivesFrom(defn.TupleClass) then
if tree.tpe.isNamedTupleType && pt.derivesFrom(defn.TupleClass) then
readapt(typed(untpd.Select(untpd.TypedSplice(tree), nme.toTuple)))
else if pt.isRef(defn.AnyValClass, skipRefined = false)
|| pt.isRef(defn.ObjectClass, skipRefined = false)
Expand Down
3 changes: 3 additions & 0 deletions tests/run/i22150.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
0
1
2
26 changes: 26 additions & 0 deletions tests/run/i22150.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//> using options -experimental -language:experimental.namedTuples
import language.experimental.namedTuples

val directionsNT = IArray(
(dx = 0, dy = 1), // up
(dx = 1, dy = 0), // right
(dx = 0, dy = -1), // down
(dx = -1, dy = 0), // left
)
val IArray(UpNT @ _, _, _, _) = directionsNT

object NT:
def foo[T <: (x: Int, y: String)](tup: T): Int =
tup.x

def union[T](tup: (x: Int, y: String) | (x: Int, y: String)): Int =
tup.x

def intersect[T](tup: (x: Int, y: String) & T): Int =
tup.x


@main def Test =
println(UpNT.dx)
println(NT.union((1, "a")))
println(NT.intersect((2, "b")))

0 comments on commit c10def4

Please sign in to comment.