Skip to content

Commit 6180de2

Browse files
committed
Handle TypeProxy of Named Tuples, minimal fix without refactoring
1 parent 5176f9f commit 6180de2

File tree

7 files changed

+51
-8
lines changed

7 files changed

+51
-8
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1744,7 +1744,7 @@ object desugar {
17441744
def adaptPatternArgs(elems: List[Tree], pt: Type)(using Context): List[Tree] =
17451745

17461746
def reorderedNamedArgs(wildcardSpan: Span): List[untpd.Tree] =
1747-
var selNames = pt.namedTupleElementTypes.map(_(0))
1747+
var selNames = pt.namedTupleElementTypes(false).map(_(0))
17481748
if selNames.isEmpty && pt.classSymbol.is(CaseClass) then
17491749
selNames = pt.classSymbol.caseAccessors.map(_.name.asTermName)
17501750
val nameToIdx = selNames.zipWithIndex.toMap

compiler/src/dotty/tools/dotc/core/TypeUtils.scala

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,33 @@ class TypeUtils:
127127
case Some(types) => TypeOps.nestedPairs(types)
128128
case None => throw new AssertionError("not a tuple")
129129

130-
def namedTupleElementTypesUpTo(bound: Int, normalize: Boolean = true)(using Context): List[(TermName, Type)] =
130+
def namedTupleElementTypesUpTo(bound: Int, derived: Boolean, normalize: Boolean = true)(using Context): List[(TermName, Type)] =
131131
(if normalize then self.normalized else self).dealias match
132132
case defn.NamedTuple(nmes, vals) =>
133133
val names = nmes.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil).map(_.dealias).map:
134134
case ConstantType(Constant(str: String)) => str.toTermName
135135
case t => throw TypeError(em"Malformed NamedTuple: names must be string types, but $t was found.")
136136
val values = vals.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil)
137137
names.zip(values)
138+
case tp: TypeProxy if derived =>
139+
tp.superType.namedTupleElementTypesUpTo(bound - 1, normalize)
140+
case tp: OrType if derived =>
141+
val lhs = tp.tp1.namedTupleElementTypesUpTo(bound - 1, normalize)
142+
val rhs = tp.tp2.namedTupleElementTypesUpTo(bound - 1, normalize)
143+
if (lhs.map(_._1) != rhs.map(_._1)) throw TypeError(em"Malformed Union Type: Named Tuple elements must be the same, but $lhs and $rhs were found.")
144+
lhs.zip(rhs).map((lhs, rhs) => (lhs._1, lhs._2 | rhs._2))
145+
case tp: AndType if derived =>
146+
(tp.tp1.namedTupleElementTypesUpTo(bound - 1, normalize), tp.tp2.namedTupleElementTypesUpTo(bound - 1, normalize)) match
147+
case (Nil, rhs) => rhs
148+
case (lhs, Nil) => lhs
149+
case (lhs, rhs) =>
150+
if (lhs.map(_._1) != rhs.map(_._1)) throw TypeError(em"Malformed Intersection Type: Named Tuple elements must be the same, but $lhs and $rhs were found.")
151+
lhs.zip(rhs).map((lhs, rhs) => (lhs._1, lhs._2 & rhs._2))
138152
case t =>
139153
Nil
140154

141-
def namedTupleElementTypes(using Context): List[(TermName, Type)] =
142-
namedTupleElementTypesUpTo(Int.MaxValue)
155+
def namedTupleElementTypes(derived: Boolean)(using Context): List[(TermName, Type)] =
156+
namedTupleElementTypesUpTo(Int.MaxValue, derived)
143157

144158
def isNamedTupleType(using Context): Boolean = self match
145159
case defn.NamedTuple(_, _) => true

compiler/src/dotty/tools/dotc/interactive/Completion.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ object Completion:
532532
def namedTupleCompletionsFromType(tpe: Type): CompletionMap =
533533
val freshCtx = ctx.fresh.setExploreTyperState()
534534
inContext(freshCtx):
535-
tpe.namedTupleElementTypes
535+
tpe.namedTupleElementTypes(true)
536536
.map { (name, tpe) =>
537537
val symbol = newSymbol(owner = NoSymbol, name, EmptyFlags, tpe)
538538
val denot = SymDenotation(symbol, NoSymbol, name, EmptyFlags, tpe)

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
248248
def appliedText(tp: Type): Text = tp match
249249
case tp @ AppliedType(tycon, args) =>
250250
val namedElems =
251-
try tp.namedTupleElementTypesUpTo(200, normalize = false)
251+
try tp.namedTupleElementTypesUpTo(200, false, normalize = false) // TODO: should the printer use derived or not?
252252
catch case ex: TypeError => Nil
253253
if namedElems.nonEmpty then
254254
toTextNamedTuple(namedElems)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
799799

800800
// Otherwise, try to expand a named tuple selection
801801
def tryNamedTupleSelection() =
802-
val namedTupleElems = qual.tpe.widenDealias.namedTupleElementTypes
802+
val namedTupleElems = qual.tpe.widenDealias.namedTupleElementTypes(true)
803803
val nameIdx = namedTupleElems.indexWhere(_._1 == selName)
804804
if nameIdx >= 0 && Feature.enabled(Feature.namedTuples) then
805805
typed(
@@ -875,7 +875,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
875875
then
876876
val pre = if !TypeOps.isLegalPrefix(qual.tpe) then SkolemType(qual.tpe) else qual.tpe
877877
val fieldsType = pre.select(tpnme.Fields).widenDealias.simplified
878-
val fields = fieldsType.namedTupleElementTypes
878+
val fields = fieldsType.namedTupleElementTypes(true)
879879
typr.println(i"try dyn select $qual, $selName, $fields")
880880
fields.find(_._1 == selName) match
881881
case Some((_, fieldType)) =>

tests/run/i22150.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
0
2+
1
3+
2

tests/run/i22150.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//> using options -experimental -language:experimental.namedTuples
2+
import language.experimental.namedTuples
3+
4+
val directionsNT = IArray(
5+
(dx = 0, dy = 1), // up
6+
(dx = 1, dy = 0), // right
7+
(dx = 0, dy = -1), // down
8+
(dx = -1, dy = 0), // left
9+
)
10+
val IArray(UpNT @ _, _, _, _) = directionsNT
11+
12+
object NT:
13+
// def foo[T <: (x: Int, y: String)](tup: T): Int =
14+
// tup.x
15+
16+
def union[T](tup: (x: Int, y: String) | (x: Int, y: String)): Int =
17+
tup.x
18+
19+
def intersect[T](tup: (x: Int, y: String) & T): Int =
20+
tup.x
21+
22+
23+
@main def Test =
24+
println(UpNT.dx)
25+
println(NT.union((1, "a")))
26+
println(NT.intersect((2, "b")))

0 commit comments

Comments
 (0)