diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 573143a0e2cf..88c024dae041 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -40,7 +40,11 @@ object desugar { def derivedType(sym: Symbol)(implicit ctx: Context) = sym.typeRef } - class DerivedFromParamTree extends DerivedTypeTree { + /** A type tree that computes its type from an existing parameter. + * @param suffix String difference between existing parameter (call it `P`) and parameter owning the + * DerivedTypeTree (call it `O`). We have: `O.name == P.name + suffix`. + */ + class DerivedFromParamTree(suffix: String) extends DerivedTypeTree { /** Make sure that for all enclosing module classes their companion lasses * are completed. Reason: We need the constructor of such companion classes to @@ -58,12 +62,16 @@ object desugar { /** Return info of original symbol, where all references to siblings of the * original symbol (i.e. sibling and original symbol have the same owner) - * are rewired to same-named parameters or accessors in the scope enclosing + * are rewired to like-named* parameters or accessors in the scope enclosing * the current scope. The current scope is the scope owned by the defined symbol * itself, that's why we have to look one scope further out. If the resulting * type is an alias type, dealias it. This is necessary because the * accessor of a type parameter is a private type alias that cannot be accessed * from subclasses. + * + * (*) like-named means: + * + * parameter name == reference name ++ suffix */ def derivedType(sym: Symbol)(implicit ctx: Context) = { val relocate = new TypeMap { @@ -71,11 +79,11 @@ object desugar { def apply(tp: Type) = tp match { case tp: NamedType if tp.symbol.exists && (tp.symbol.owner eq originalOwner) => val defctx = ctx.outersIterator.dropWhile(_.scope eq ctx.scope).next() - var local = defctx.denotNamed(tp.name).suchThat(_.isParamOrAccessor).symbol + var local = defctx.denotNamed(tp.name ++ suffix).suchThat(_.isParamOrAccessor).symbol if (local.exists) (defctx.owner.thisType select local).dealias else { def msg = - s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope}" + s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope.toList}" if (ctx.reporter.errorsReported) ErrorType(msg) else throw new java.lang.Error(msg) } @@ -88,14 +96,20 @@ object desugar { } /** A type definition copied from `tdef` with a rhs typetree derived from it */ - def derivedTypeParam(tdef: TypeDef) = + def derivedTypeParam(tdef: TypeDef, suffix: String = ""): TypeDef = cpy.TypeDef(tdef)( - rhs = new DerivedFromParamTree() withPos tdef.rhs.pos watching tdef) + name = tdef.name ++ suffix, + rhs = new DerivedFromParamTree(suffix).withPos(tdef.rhs.pos).watching(tdef) + ) + + /** A derived type definition watching `sym` */ + def derivedTypeParam(sym: TypeSymbol)(implicit ctx: Context): TypeDef = + TypeDef(sym.name, new DerivedFromParamTree("").watching(sym)).withFlags(TypeParam) /** A value definition copied from `vdef` with a tpt typetree derived from it */ def derivedTermParam(vdef: ValDef) = cpy.ValDef(vdef)( - tpt = new DerivedFromParamTree() withPos vdef.tpt.pos watching vdef) + tpt = new DerivedFromParamTree("") withPos vdef.tpt.pos watching vdef) // ----- Desugar methods ------------------------------------------------- @@ -317,8 +331,8 @@ object desugar { } def anyRef = ref(defn.AnyRefAlias.typeRef) - val derivedTparams = constrTparams map derivedTypeParam - val derivedVparamss = constrVparamss nestedMap derivedTermParam + val derivedTparams = constrTparams.map(derivedTypeParam(_)) + val derivedVparamss = constrVparamss.nestedMap(derivedTermParam(_)) val arity = constrVparamss.head.length val classTycon: Tree = new TypeRefTree // watching is set at end of method @@ -419,9 +433,8 @@ object desugar { // ev1: Eq[T1$1, T1$2], ..., evn: Eq[Tn$1, Tn$2]]) // : Eq[C[T1$1, ..., Tn$1], C[T1$2, ..., Tn$2]] = Eq def eqInstance = { - def append(tdef: TypeDef, str: String) = cpy.TypeDef(tdef)(name = tdef.name ++ str) - val leftParams = derivedTparams.map(append(_, "$1")) - val rightParams = derivedTparams.map(append(_, "$2")) + val leftParams = constrTparams.map(derivedTypeParam(_, "$1")) + val rightParams = constrTparams.map(derivedTypeParam(_, "$2")) val subInstances = (leftParams, rightParams).zipped.map((param1, param2) => appliedRef(ref(defn.EqType), List(param1, param2))) DefDef( @@ -456,19 +469,16 @@ object desugar { // For all other classes, the parent is AnyRef. val companions = if (isCaseClass) { - def extractType(t: Tree): Tree = t match { - case Apply(t1, _) => extractType(t1) - case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts) - case Select(t1, nme.CONSTRUCTOR) => extractType(t1) - case New(t1) => t1 - case t1 => t1 - } // The return type of the `apply` method - val applyResultTpt = - if (isEnumCase) - if (parents.isEmpty) enumClassTypeRef - else parents.map(extractType).reduceLeft(AndTypeTree) - else TypeTree() + val (applyResultTpt, widenDefs) = + if (!isEnumCase) + (TypeTree(), Nil) + else if (parents.isEmpty || enumClass.typeParams.isEmpty) + (enumClassTypeRef, Nil) + else { + val tparams = enumClass.typeParams.map(derivedTypeParam) + enumApplyResult(cdef, parents, tparams, appliedRef(enumClassRef, tparams)) + } val parent = if (constrTparams.nonEmpty || @@ -479,11 +489,13 @@ object desugar { // todo: also use anyRef if constructor has a dependent method type (or rule that out)! (constrVparamss :\ (if (isEnumCase) applyResultTpt else classTypeRef)) ( (vparams, restpe) => Function(vparams map (_.tpt), restpe)) + def widenedCreatorExpr = + (creatorExpr /: widenDefs)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil)) val applyMeths = if (mods is Abstract) Nil else - DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, creatorExpr) - .withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: Nil + DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, widenedCreatorExpr) + .withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: widenDefs val unapplyMeth = { val unapplyParam = makeSyntheticParameter(tpt = classTypeRef) val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name) diff --git a/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala index bd9e24f76dc7..4d1dda5b8930 100644 --- a/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala +++ b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala @@ -120,6 +120,52 @@ object DesugarEnums { TypeTree(), creator) } + /** The return type of an enum case apply method and any widening methods in which + * the apply's right hand side will be wrapped. For parents of the form + * + * extends E(args) with T1(args1) with ... TN(argsN) + * + * and type parameters `tparams` the generated widen method is + * + * def C$to$E[tparams](x$1: E[tparams] with T1 with ... TN) = x$1 + * + * @param cdef The case definition + * @param parents The declared parents of the enum case + * @param tparams The type parameters of the enum case + * @param appliedEnumRef The enum class applied to `tparams`. + */ + def enumApplyResult( + cdef: TypeDef, + parents: List[Tree], + tparams: List[TypeDef], + appliedEnumRef: Tree)(implicit ctx: Context): (Tree, List[DefDef]) = { + + def extractType(t: Tree): Tree = t match { + case Apply(t1, _) => extractType(t1) + case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts) + case Select(t1, nme.CONSTRUCTOR) => extractType(t1) + case New(t1) => t1 + case t1 => t1 + } + + val parentTypes = parents.map(extractType) + parentTypes.head match { + case parent: RefTree if parent.name == enumClass.name => + // need a widen method to compute correct type parameters for enum base class + val widenParamType = (appliedEnumRef /: parentTypes.tail)(AndTypeTree) + val widenParam = makeSyntheticParameter(tpt = widenParamType) + val widenDef = DefDef( + name = s"${cdef.name}$$to$$${enumClass.name}".toTermName, + tparams = tparams, + vparamss = (widenParam :: Nil) :: Nil, + tpt = TypeTree(), + rhs = Ident(widenParam.name)) + (TypeTree(), widenDef :: Nil) + case _ => + (parentTypes.reduceLeft(AndTypeTree), Nil) + } + } + /** A pair consisting of * - the next enum tag * - scaffolding containing the necessary definitions for singleton enum cases diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index cd6952ed16ab..7452fa9179e7 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -227,6 +227,12 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { this } + /** Install the derived type tree as a dependency on `sym` */ + def watching(sym: Symbol): this.type = { + pushAttachment(OriginalSymbol, sym) + this + } + /** A hook to ensure that all necessary symbols are completed so that * OriginalSymbol attachments are propagated to this tree */ @@ -240,7 +246,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { * from the symbol in this type. These type trees have marker trees * TypeRefOfSym or InfoOfSym as their originals. */ - val References = new Property.Key[List[Tree]] + val References = new Property.Key[List[DerivedTypeTree]] /** Property key for TypeTrees marked with TypeRefOfSym or InfoOfSym * which contains the symbol of the original tree from which this diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 6201afd271b3..e9dcd406716d 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -228,11 +228,7 @@ class Namer { typer: Typer => /** Record `sym` as the symbol defined by `tree` */ def recordSym(sym: Symbol, tree: Tree)(implicit ctx: Context): Symbol = { - val refs = tree.attachmentOrElse(References, Nil) - if (refs.nonEmpty) { - tree.removeAttachment(References) - refs foreach (_.pushAttachment(OriginalSymbol, sym)) - } + for (refs <- tree.removeAttachment(References); ref <- refs) ref.watching(sym) tree.pushAttachment(SymOfTree, sym) sym } diff --git a/tests/pos/i2663.scala b/tests/pos/i2663.scala new file mode 100644 index 000000000000..5c69eb5f7b4a --- /dev/null +++ b/tests/pos/i2663.scala @@ -0,0 +1,30 @@ +trait Tr +enum Foo[T](x: T) { + case Bar[T](y: T) extends Foo(y) + case Bas[T](y: Int) extends Foo(y) + case Bam[T](y: String) extends Foo(y) with Tr + case Baz[S, T](y: String) extends Foo(y) with Tr +} +object Test { + import Foo._ + val bar: Foo[Boolean] = Bar(true) + val bas: Foo[Int] = Bas(1) + val bam: Foo[String] & Tr = Bam("") + val baz: Foo[String] & Tr = Baz("") +} + +enum Foo2[S <: T, T](x1: S, x2: T) { + case Bar[T](y: T) extends Foo2(y, y) + case Bas[T](y: Int) extends Foo2(y, y) + case Bam[T](y: String) extends Foo2(y, y) with Tr + case Baz[S, T](y: String) extends Foo2(y, y) with Tr +} +object Test2 { + import Foo2._ + val bar: Foo2[Boolean, Boolean] = Bar(true) + val bas: Foo2[Int, Int] = Bas(1) + val bam: Foo2[String, String] & Tr = Bam("") + val baz: Foo2[String, String] & Tr = Baz("") +} + +