Skip to content

Fix 13493: compute union of child types for mirror #15007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import transform.SyntheticMembers._
import util.Property
import annotation.{tailrec, constructorOnly}

import scala.collection.mutable

/** Synthesize terms for special classes */
class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
import ast.tpd._
Expand Down Expand Up @@ -339,7 +341,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
if acceptable(mirroredType) && cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))

def solve(sym: Symbol): Type = sym match
def solve(target: Type)(sym: Symbol): Type = sym match
case childClass: ClassSymbol =>
assert(childClass.isOneOf(Case | Sealed))
if childClass.is(Module) then
Expand All @@ -350,36 +352,50 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
// Compute the the full child type by solving the subtype constraint
// `C[X1, ..., Xn] <: P`, where
//
// - P is the current `mirroredType`
// - P is the current `targetPart`
// - C is the child class, with type parameters X1, ..., Xn
//
// Contravariant type parameters are minimized, all other type parameters are maximized.
def instantiate(using Context) =
val poly = constrained(info, untpd.EmptyTree)._1
def instantiate(targetPart: Type)(using Context) =
val poly = constrained(info)
val resType = poly.finalResultType
val target = mirroredType match
case tp: HKTypeLambda => tp.resultType
case tp => tp
resType <:< target
resType <:< targetPart // record constraints
val tparams = poly.paramRefs
val variances = childClass.typeParams.map(_.paramVarianceSign)
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
resType.substParams(poly, instanceTypes)
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))

def instantiateAll(using Context): Type =

// instantiate for each part of a union type, compute lub of the results
def loop(explore: List[Type], acc: mutable.ListBuffer[Type]): Type = explore match
case OrType(tp1, tp2) :: rest => loop(tp1 :: tp2 :: rest, acc )
case tp :: rest => loop(rest , acc += instantiate(tp))
case _ => TypeComparer.lub(acc.toList)

def instantiateLub(tp1: Type, tp2: Type): Type =
loop(tp1 :: tp2 :: Nil, new mutable.ListBuffer[Type])

target match
case OrType(tp1, tp2) => instantiateLub(tp1, tp2)
case _ => instantiate(target)

instantiateAll(using ctx.fresh.setExploreTyperState().setOwner(childClass))
case _ =>
childClass.typeRef
case child => child.termRef
end solve

val (monoType, elemsType) = mirroredType match
case mirroredType: HKTypeLambda =>
val target = mirroredType.resultType
val elems = mirroredType.derivedLambdaType(
resType = TypeOps.nestedPairs(cls.children.map(solve))
resType = TypeOps.nestedPairs(cls.children.map(solve(target)))
)
(mkMirroredMonoType(mirroredType), elems)
case _ =>
val elems = TypeOps.nestedPairs(cls.children.map(solve))
case target =>
val elems = TypeOps.nestedPairs(cls.children.map(solve(target)))
(mirroredType, elems)

val mirrorType =
Expand Down
43 changes: 43 additions & 0 deletions tests/pos/i13493.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import deriving.Mirror

sealed trait Box[T]
object Box

case class Child[T](t: T) extends Box[T]

object MirrorK1:
type Of[F[_]] = Mirror { type MirroredType[A] = F[A] }

def testSums =

val foo = summon[Mirror.Of[Option[Int] | Option[String]]]
summon[foo.MirroredElemTypes =:= (None.type, Some[Int] | Some[String])]

val bar = summon[Mirror.Of[Box[Int] | Box[String]]]
summon[bar.MirroredElemTypes =:= ((Child[Int] | Child[String]) *: EmptyTuple)]

val qux = summon[Mirror.Of[Option[Int | String]]]
summon[qux.MirroredElemTypes =:= (None.type, Some[Int | String])]

val bip = summon[Mirror.Of[Box[Int | String]]]
summon[bip.MirroredElemTypes =:= (Child[Int | String] *: EmptyTuple)]

val bap = summon[MirrorK1.Of[[X] =>> Box[X] | Box[Int] | Box[String]]]
summon[bap.MirroredElemTypes[Boolean] =:= ((Child[Boolean] | Child[Int] | Child[String]) *: EmptyTuple)]


def testProducts =
val foo = summon[Mirror.Of[Some[Int] | Some[String]]]
summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]

val bar = summon[Mirror.Of[Child[Int] | Child[String]]]
summon[bar.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]

val qux = summon[Mirror.Of[Some[Int | String]]]
summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]

val bip = summon[Mirror.Of[Child[Int | String]]]
summon[bip.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]

val bap = summon[MirrorK1.Of[[X] =>> Child[X] | Child[Int] | Child[String]]]
summon[bap.MirroredElemTypes[Boolean] =:= ((Boolean | Int | String) *: EmptyTuple)]