diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index a152ec3ed981..b87686996ebd 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -1,4 +1,5 @@ -package dotty.tools.dotc +package dotty.tools +package dotc package transform import core._ @@ -7,6 +8,7 @@ import MegaPhase._ import SymUtils._ import NullOpsDecorator._ import ast.Trees._ +import ast.untpd import reporting._ import dotty.tools.dotc.util.Spans.Span @@ -103,78 +105,73 @@ class ExpandSAMs extends MiniPhase: * ``` */ private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree = { - /** An extractor for match, either contained in a block or standalone. */ - object PartialFunctionRHS { - def unapply(tree: Tree): Option[Match] = tree match { - case Block(Nil, expr) => unapply(expr) - case m: Match => Some(m) - case _ => None - } - } - val closureDef(anon @ DefDef(_, List(List(param)), _, _)) = tree - anon.rhs match { - case PartialFunctionRHS(pf) => - val anonSym = anon.symbol - val anonTpe = anon.tpe.widen - val parents = List( - defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType), - defn.SerializableType) - val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.span) - - def overrideSym(sym: Symbol) = sym.copy( - owner = pfSym, - flags = Synthetic | Method | Final | Override, - info = tpe.memberInfo(sym), - coord = tree.span).asTerm.entered - val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) - val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) - - def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = { - val selector = tree.selector - val selectorTpe = selector.tpe.widen - val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD, Synthetic | Case, selectorTpe) - val defaultCase = - CaseDef( - Bind(defaultSym, Underscore(selectorTpe)), - EmptyTree, - defaultValue) - val unchecked = selector.annotated(New(ref(defn.UncheckedAnnot.typeRef))) - cpy.Match(tree)(unchecked, cases :+ defaultCase) - .subst(param.symbol :: Nil, pfParam :: Nil) - // Needed because a partial function can be written as: - // param => param match { case "foo" if foo(param) => param } - // And we need to update all references to 'param' - } - - def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = { - val tru = Literal(Constant(true)) - def translateCase(cdef: CaseDef) = - cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn) - val paramRef = paramRefss.head.head - val defaultValue = Literal(Constant(false)) - translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue) - } - - def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = { - val List(paramRef, defaultRef) = paramRefss(1) - def translateCase(cdef: CaseDef) = - cdef.changeOwner(anonSym, applyOrElseFn) - val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef) - translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue) - } - - val constr = newConstructor(pfSym, Synthetic, Nil, Nil).entered - val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn)))) - val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn)))) - val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef)) - cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil)) + // The right hand side from which to construct the partial function. This is always a Match. + // If the original rhs is already a Match (possibly in braces), return that. + // Otherwise construct a match `x match case _ => rhs` where `x` is the parameter of the closure. + def partialFunRHS(tree: Tree): Match = tree match + case m: Match => m + case Block(Nil, expr) => partialFunRHS(expr) case _ => - val found = tpe.baseType(defn.Function1) - report.error(TypeMismatch(found, tpe), tree.srcPos) - tree + Match(ref(param.symbol), + CaseDef(untpd.Ident(nme.WILDCARD).withType(param.symbol.info), EmptyTree, tree) :: Nil) + + val pfRHS = partialFunRHS(anon.rhs) + val anonSym = anon.symbol + val anonTpe = anon.tpe.widen + val parents = List( + defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType), + defn.SerializableType) + val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.span) + + def overrideSym(sym: Symbol) = sym.copy( + owner = pfSym, + flags = Synthetic | Method | Final | Override, + info = tpe.memberInfo(sym), + coord = tree.span).asTerm.entered + val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) + val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) + + def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = { + val selector = tree.selector + val selectorTpe = selector.tpe.widen + val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD, Synthetic | Case, selectorTpe) + val defaultCase = + CaseDef( + Bind(defaultSym, Underscore(selectorTpe)), + EmptyTree, + defaultValue) + val unchecked = selector.annotated(New(ref(defn.UncheckedAnnot.typeRef))) + cpy.Match(tree)(unchecked, cases :+ defaultCase) + .subst(param.symbol :: Nil, pfParam :: Nil) + // Needed because a partial function can be written as: + // param => param match { case "foo" if foo(param) => param } + // And we need to update all references to 'param' + } + + def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = { + val tru = Literal(Constant(true)) + def translateCase(cdef: CaseDef) = + cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn) + val paramRef = paramRefss.head.head + val defaultValue = Literal(Constant(false)) + translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue) + } + + def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = { + val List(paramRef, defaultRef) = paramRefss(1) + def translateCase(cdef: CaseDef) = + cdef.changeOwner(anonSym, applyOrElseFn) + val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef) + translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue) } + + val constr = newConstructor(pfSym, Synthetic, Nil, Nil).entered + val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn)))) + val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn)))) + val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef)) + cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil)) } private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match { diff --git a/tests/neg/i4241.scala b/tests/neg/i4241.scala deleted file mode 100644 index 3d93a44a015a..000000000000 --- a/tests/neg/i4241.scala +++ /dev/null @@ -1,12 +0,0 @@ -class Test { - def test: Unit = { - val a: PartialFunction[Int, Int] = { case x => x } - val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case _ => 2 } - val c: PartialFunction[Int, Int] = x => { x match { case y => y } } - val d: PartialFunction[Int, Int] = x => { { x match { case y => y } } } - - val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case y => y } } // error - val f: PartialFunction[Int, Int] = x => x // error - val g: PartialFunction[Int, String] = { x => x.toString } // error - } -} diff --git a/tests/run/i4241.scala b/tests/run/i4241.scala new file mode 100644 index 000000000000..c55cb5be475f --- /dev/null +++ b/tests/run/i4241.scala @@ -0,0 +1,24 @@ +object Test extends App { + val a: PartialFunction[Int, Int] = { case x => x } + val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case 2 => 2 } + val c: PartialFunction[Int, Int] = x => { x match { case 1 => 1 } } + val d: PartialFunction[Int, Int] = x => { { x match { case 1 => 1 } } } + + val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case 1 => 1 } } + val f: PartialFunction[Int, Int] = x => x + val g: PartialFunction[Int, String] = { x => x.toString } + val h: PartialFunction[Int, String] = _.toString + assert(a.isDefinedAt(2)) + assert(b.isDefinedAt(2)) + assert(!b.isDefinedAt(3)) + assert(c.isDefinedAt(1)) + assert(!c.isDefinedAt(2)) + assert(d.isDefinedAt(1)) + assert(!d.isDefinedAt(2)) + assert(e.isDefinedAt(2)) + assert(f.isDefinedAt(2)) + assert(g.isDefinedAt(2)) + assert(h.isDefinedAt(2)) +} + +