Skip to content

Commit c10def4

Browse files
authored
Handle TypeProxy of Named Tuples in unapply (#22325)
Fixes #22150. Previously, there were several ways to check if something was a Named Tuple (`derivesFromNamedTuple`, `isNamedTupleType` and `NamedTuple.unapply`), this PR moves everything into `NamedTuple.unapply`. `namedTupleElementTypes` now takes an argument `derived` that when false will skip `unapply` (to avoid infinite recursion, used in desugaring and RefinedPrinter where trees can have invalid cycles).
2 parents 49839cd + 83ae00d commit c10def4

File tree

10 files changed

+71
-26
lines changed

10 files changed

+71
-26
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1744,7 +1744,7 @@ object desugar {
17441744
def adaptPatternArgs(elems: List[Tree], pt: Type)(using Context): List[Tree] =
17451745

17461746
def reorderedNamedArgs(wildcardSpan: Span): List[untpd.Tree] =
1747-
var selNames = pt.namedTupleElementTypes.map(_(0))
1747+
var selNames = pt.namedTupleElementTypes(false).map(_(0))
17481748
if selNames.isEmpty && pt.classSymbol.is(CaseClass) then
17491749
selNames = pt.classSymbol.caseAccessors.map(_.name.asTermName)
17501750
val nameToIdx = selNames.zipWithIndex.toMap

compiler/src/dotty/tools/dotc/core/Definitions.scala

+19-4
Original file line numberDiff line numberDiff line change
@@ -1337,10 +1337,25 @@ class Definitions {
13371337
object NamedTuple:
13381338
def apply(nmes: Type, vals: Type)(using Context): Type =
13391339
AppliedType(NamedTupleTypeRef, nmes :: vals :: Nil)
1340-
def unapply(t: Type)(using Context): Option[(Type, Type)] = t match
1341-
case AppliedType(tycon, nmes :: vals :: Nil) if tycon.typeSymbol == NamedTupleTypeRef.symbol =>
1342-
Some((nmes, vals))
1343-
case _ => None
1340+
def unapply(t: Type)(using Context): Option[(Type, Type)] =
1341+
t match
1342+
case AppliedType(tycon, nmes :: vals :: Nil) if tycon.typeSymbol == NamedTupleTypeRef.symbol =>
1343+
Some((nmes, vals))
1344+
case tp: TypeProxy =>
1345+
val t = unapply(tp.superType); t
1346+
case tp: OrType =>
1347+
(unapply(tp.tp1), unapply(tp.tp2)) match
1348+
case (Some(lhsName, lhsVal), Some(rhsName, rhsVal)) if lhsName == rhsName =>
1349+
Some(lhsName, lhsVal | rhsVal)
1350+
case _ => None
1351+
case tp: AndType =>
1352+
(unapply(tp.tp1), unapply(tp.tp2)) match
1353+
case (Some(lhsName, lhsVal), Some(rhsName, rhsVal)) if lhsName == rhsName =>
1354+
Some(lhsName, lhsVal & rhsVal)
1355+
case (lhs, None) => lhs
1356+
case (None, rhs) => rhs
1357+
case _ => None
1358+
case _ => None
13441359

13451360
final def isCompiletime_S(sym: Symbol)(using Context): Boolean =
13461361
sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass

compiler/src/dotty/tools/dotc/core/TypeUtils.scala

+12-12
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,17 @@ class TypeUtils:
127127
case Some(types) => TypeOps.nestedPairs(types)
128128
case None => throw new AssertionError("not a tuple")
129129

130-
def namedTupleElementTypesUpTo(bound: Int, normalize: Boolean = true)(using Context): List[(TermName, Type)] =
130+
def namedTupleElementTypesUpTo(bound: Int, derived: Boolean, normalize: Boolean = true)(using Context): List[(TermName, Type)] =
131131
(if normalize then self.normalized else self).dealias match
132+
// for desugaring and printer, ignore derived types to avoid infinite recursion in NamedTuple.unapply
133+
case AppliedType(tycon, nmes :: vals :: Nil) if !derived && tycon.typeSymbol == defn.NamedTupleTypeRef.symbol =>
134+
val names = nmes.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil).map(_.dealias).map:
135+
case ConstantType(Constant(str: String)) => str.toTermName
136+
case t => throw TypeError(em"Malformed NamedTuple: names must be string types, but $t was found.")
137+
val values = vals.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil)
138+
names.zip(values)
139+
case t if !derived => Nil
140+
// default cause, used for post-typing
132141
case defn.NamedTuple(nmes, vals) =>
133142
val names = nmes.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil).map(_.dealias).map:
134143
case ConstantType(Constant(str: String)) => str.toTermName
@@ -138,22 +147,13 @@ class TypeUtils:
138147
case t =>
139148
Nil
140149

141-
def namedTupleElementTypes(using Context): List[(TermName, Type)] =
142-
namedTupleElementTypesUpTo(Int.MaxValue)
150+
def namedTupleElementTypes(derived: Boolean)(using Context): List[(TermName, Type)] =
151+
namedTupleElementTypesUpTo(Int.MaxValue, derived)
143152

144153
def isNamedTupleType(using Context): Boolean = self match
145154
case defn.NamedTuple(_, _) => true
146155
case _ => false
147156

148-
def derivesFromNamedTuple(using Context): Boolean = self match
149-
case defn.NamedTuple(_, _) => true
150-
case tp: MatchType =>
151-
tp.bound.derivesFromNamedTuple || tp.reduced.derivesFromNamedTuple
152-
case tp: TypeProxy => tp.superType.derivesFromNamedTuple
153-
case tp: AndType => tp.tp1.derivesFromNamedTuple || tp.tp2.derivesFromNamedTuple
154-
case tp: OrType => tp.tp1.derivesFromNamedTuple && tp.tp2.derivesFromNamedTuple
155-
case _ => false
156-
157157
/** Drop all named elements in tuple type */
158158
def stripNamedTuple(using Context): Type = self.normalized.dealias match
159159
case defn.NamedTuple(_, vals) =>

compiler/src/dotty/tools/dotc/interactive/Completion.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ object Completion:
532532
def namedTupleCompletionsFromType(tpe: Type): CompletionMap =
533533
val freshCtx = ctx.fresh.setExploreTyperState()
534534
inContext(freshCtx):
535-
tpe.namedTupleElementTypes
535+
tpe.namedTupleElementTypes(true)
536536
.map { (name, tpe) =>
537537
val symbol = newSymbol(owner = NoSymbol, name, EmptyFlags, tpe)
538538
val denot = SymDenotation(symbol, NoSymbol, name, EmptyFlags, tpe)
@@ -543,7 +543,7 @@ object Completion:
543543
.groupByName
544544

545545
val qualTpe = qual.typeOpt
546-
if qualTpe.derivesFromNamedTuple then
546+
if qualTpe.isNamedTupleType then
547547
namedTupleCompletionsFromType(qualTpe)
548548
else if qualTpe.derivesFrom(defn.SelectableClass) then
549549
val pre = if !TypeOps.isLegalPrefix(qualTpe) then Types.SkolemType(qualTpe) else qualTpe

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
248248
def appliedText(tp: Type): Text = tp match
249249
case tp @ AppliedType(tycon, args) =>
250250
val namedElems =
251-
try tp.namedTupleElementTypesUpTo(200, normalize = false)
252-
catch case ex: TypeError => Nil
251+
try tp.namedTupleElementTypesUpTo(200, false, normalize = false)
252+
catch
253+
case ex: TypeError => Nil
253254
if namedElems.nonEmpty then
254255
toTextNamedTuple(namedElems)
255256
else tp.tupleElementTypesUpTo(200, normalize = false) match

compiler/src/dotty/tools/dotc/typer/Applications.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ object Applications {
110110
}
111111

112112
def namedTupleOrProductTypes(tp: Type)(using Context): List[Type] =
113-
if tp.isNamedTupleType then tp.namedTupleElementTypes.map(_(1))
113+
if tp.isNamedTupleType then tp.namedTupleElementTypes(true).map(_(1))
114114
else productSelectorTypes(tp, NoSourcePosition)
115115

116116
def productSelectorTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = {

compiler/src/dotty/tools/dotc/typer/Implicits.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ trait Implicits:
876876
|| inferView(dummyTreeOfType(from), to)
877877
(using ctx.fresh.addMode(Mode.ImplicitExploration).setExploreTyperState()).isSuccess
878878
// TODO: investigate why we can't TyperState#test here
879-
|| from.widen.derivesFromNamedTuple && to.derivesFrom(defn.TupleClass)
879+
|| from.widen.isNamedTupleType && to.derivesFrom(defn.TupleClass)
880880
&& from.widen.stripNamedTuple <:< to
881881
)
882882

compiler/src/dotty/tools/dotc/typer/Typer.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
799799

800800
// Otherwise, try to expand a named tuple selection
801801
def tryNamedTupleSelection() =
802-
val namedTupleElems = qual.tpe.widenDealias.namedTupleElementTypes
802+
val namedTupleElems = qual.tpe.widenDealias.namedTupleElementTypes(true)
803803
val nameIdx = namedTupleElems.indexWhere(_._1 == selName)
804804
if nameIdx >= 0 && Feature.enabled(Feature.namedTuples) then
805805
typed(
@@ -875,7 +875,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
875875
then
876876
val pre = if !TypeOps.isLegalPrefix(qual.tpe) then SkolemType(qual.tpe) else qual.tpe
877877
val fieldsType = pre.select(tpnme.Fields).widenDealias.simplified
878-
val fields = fieldsType.namedTupleElementTypes
878+
val fields = fieldsType.namedTupleElementTypes(true)
879879
typr.println(i"try dyn select $qual, $selName, $fields")
880880
fields.find(_._1 == selName) match
881881
case Some((_, fieldType)) =>
@@ -4663,7 +4663,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
46634663
case _: SelectionProto =>
46644664
tree // adaptations for selections are handled in typedSelect
46654665
case _ if ctx.mode.is(Mode.ImplicitsEnabled) && tree.tpe.isValueType =>
4666-
if tree.tpe.derivesFromNamedTuple && pt.derivesFrom(defn.TupleClass) then
4666+
if tree.tpe.isNamedTupleType && pt.derivesFrom(defn.TupleClass) then
46674667
readapt(typed(untpd.Select(untpd.TypedSplice(tree), nme.toTuple)))
46684668
else if pt.isRef(defn.AnyValClass, skipRefined = false)
46694669
|| pt.isRef(defn.ObjectClass, skipRefined = false)

tests/run/i22150.check

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
0
2+
1
3+
2

tests/run/i22150.scala

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//> using options -experimental -language:experimental.namedTuples
2+
import language.experimental.namedTuples
3+
4+
val directionsNT = IArray(
5+
(dx = 0, dy = 1), // up
6+
(dx = 1, dy = 0), // right
7+
(dx = 0, dy = -1), // down
8+
(dx = -1, dy = 0), // left
9+
)
10+
val IArray(UpNT @ _, _, _, _) = directionsNT
11+
12+
object NT:
13+
def foo[T <: (x: Int, y: String)](tup: T): Int =
14+
tup.x
15+
16+
def union[T](tup: (x: Int, y: String) | (x: Int, y: String)): Int =
17+
tup.x
18+
19+
def intersect[T](tup: (x: Int, y: String) & T): Int =
20+
tup.x
21+
22+
23+
@main def Test =
24+
println(UpNT.dx)
25+
println(NT.union((1, "a")))
26+
println(NT.intersect((2, "b")))

0 commit comments

Comments
 (0)