From 227ab22bd03f7a59ecadc49bf3aa311bdaff73ba Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Fri, 22 Apr 2022 10:31:28 +0200 Subject: [PATCH] compute mirror child types of a union --- .../dotty/tools/dotc/typer/Synthesizer.scala | 40 +++++++++++------ tests/pos/i13493.scala | 43 +++++++++++++++++++ 2 files changed, 71 insertions(+), 12 deletions(-) create mode 100644 tests/pos/i13493.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 6b930f705809..f634efdec33b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -17,6 +17,8 @@ import transform.SyntheticMembers._ import util.Property import annotation.{tailrec, constructorOnly} +import scala.collection.mutable + /** Synthesize terms for special classes */ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): import ast.tpd._ @@ -339,7 +341,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): if acceptable(mirroredType) && cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString))) - def solve(sym: Symbol): Type = sym match + def solve(target: Type)(sym: Symbol): Type = sym match case childClass: ClassSymbol => assert(childClass.isOneOf(Case | Sealed)) if childClass.is(Module) then @@ -350,23 +352,36 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): // Compute the the full child type by solving the subtype constraint // `C[X1, ..., Xn] <: P`, where // - // - P is the current `mirroredType` + // - P is the current `targetPart` // - C is the child class, with type parameters X1, ..., Xn // // Contravariant type parameters are minimized, all other type parameters are maximized. - def instantiate(using Context) = - val poly = constrained(info, untpd.EmptyTree)._1 + def instantiate(targetPart: Type)(using Context) = + val poly = constrained(info) val resType = poly.finalResultType - val target = mirroredType match - case tp: HKTypeLambda => tp.resultType - case tp => tp - resType <:< target + resType <:< targetPart // record constraints val tparams = poly.paramRefs val variances = childClass.typeParams.map(_.paramVarianceSign) val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) => TypeComparer.instanceType(tparam, fromBelow = variance < 0)) resType.substParams(poly, instanceTypes) - instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass)) + + def instantiateAll(using Context): Type = + + // instantiate for each part of a union type, compute lub of the results + def loop(explore: List[Type], acc: mutable.ListBuffer[Type]): Type = explore match + case OrType(tp1, tp2) :: rest => loop(tp1 :: tp2 :: rest, acc ) + case tp :: rest => loop(rest , acc += instantiate(tp)) + case _ => TypeComparer.lub(acc.toList) + + def instantiateLub(tp1: Type, tp2: Type): Type = + loop(tp1 :: tp2 :: Nil, new mutable.ListBuffer[Type]) + + target match + case OrType(tp1, tp2) => instantiateLub(tp1, tp2) + case _ => instantiate(target) + + instantiateAll(using ctx.fresh.setExploreTyperState().setOwner(childClass)) case _ => childClass.typeRef case child => child.termRef @@ -374,12 +389,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): val (monoType, elemsType) = mirroredType match case mirroredType: HKTypeLambda => + val target = mirroredType.resultType val elems = mirroredType.derivedLambdaType( - resType = TypeOps.nestedPairs(cls.children.map(solve)) + resType = TypeOps.nestedPairs(cls.children.map(solve(target))) ) (mkMirroredMonoType(mirroredType), elems) - case _ => - val elems = TypeOps.nestedPairs(cls.children.map(solve)) + case target => + val elems = TypeOps.nestedPairs(cls.children.map(solve(target))) (mirroredType, elems) val mirrorType = diff --git a/tests/pos/i13493.scala b/tests/pos/i13493.scala new file mode 100644 index 000000000000..873ec74e9085 --- /dev/null +++ b/tests/pos/i13493.scala @@ -0,0 +1,43 @@ +import deriving.Mirror + +sealed trait Box[T] +object Box + +case class Child[T](t: T) extends Box[T] + +object MirrorK1: + type Of[F[_]] = Mirror { type MirroredType[A] = F[A] } + +def testSums = + + val foo = summon[Mirror.Of[Option[Int] | Option[String]]] + summon[foo.MirroredElemTypes =:= (None.type, Some[Int] | Some[String])] + + val bar = summon[Mirror.Of[Box[Int] | Box[String]]] + summon[bar.MirroredElemTypes =:= ((Child[Int] | Child[String]) *: EmptyTuple)] + + val qux = summon[Mirror.Of[Option[Int | String]]] + summon[qux.MirroredElemTypes =:= (None.type, Some[Int | String])] + + val bip = summon[Mirror.Of[Box[Int | String]]] + summon[bip.MirroredElemTypes =:= (Child[Int | String] *: EmptyTuple)] + + val bap = summon[MirrorK1.Of[[X] =>> Box[X] | Box[Int] | Box[String]]] + summon[bap.MirroredElemTypes[Boolean] =:= ((Child[Boolean] | Child[Int] | Child[String]) *: EmptyTuple)] + + +def testProducts = + val foo = summon[Mirror.Of[Some[Int] | Some[String]]] + summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)] + + val bar = summon[Mirror.Of[Child[Int] | Child[String]]] + summon[bar.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)] + + val qux = summon[Mirror.Of[Some[Int | String]]] + summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)] + + val bip = summon[Mirror.Of[Child[Int | String]]] + summon[bip.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)] + + val bap = summon[MirrorK1.Of[[X] =>> Child[X] | Child[Int] | Child[String]]] + summon[bap.MirroredElemTypes[Boolean] =:= ((Boolean | Int | String) *: EmptyTuple)]