diff --git a/library/src/scala/internal/quoted/Matcher.scala b/library/src/scala/internal/quoted/Matcher.scala index afe1bb3ee5b7..945d2182ee6a 100644 --- a/library/src/scala/internal/quoted/Matcher.scala +++ b/library/src/scala/internal/quoted/Matcher.scala @@ -4,6 +4,135 @@ import scala.annotation.internal.sharable import scala.quoted._ +/** Matches a quoted tree against a quoted pattern tree. + * A quoted pattern tree may have type and term holes in addition to normal terms. + * + * + * Semantics: + * + * We use `'{..}` for expression, `'[..]` for types and `⟨..⟩` for patterns nested in expressions. + * The semantics are defined as a list of reduction rules that are tried one by one until one matches. + * + * Operations: + * - `s =?= p` checks if a scrutinee `s` matches the pattern `p` while accumulating extracted parts of the code. + * - `isColosedUnder(x1, .., xn)('{e})` returns true if and only if all the references in `e` to names defined in the patttern are contained in the set `{x1, ... xn}`. + * - `lift(x1, .., xn)('{e})` returns `(y1, ..., yn) => [xi = $yi]'{e}` where `yi` is an `Expr` of the type of `xi`. + * - `withEnv(x1 -> y1, ..., xn -> yn)(matching)` evaluates mathing recording that `xi` is equivalent to `yi`. + * - `matched` denotes that the the match succedded and `matched('{e})` denotes that a matech succeded and extracts `'{e}` + * - `&&&` matches if both sides match. Concatenates the extracted expressions of both sides. + * + * Note: that not all quoted terms bellow are valid expressions + * + * ```scala + * /* Term hole */ + * '{ e } =?= '{ hole[T] } && typeOf('{e}) <:< T && isColosedUnder()('{e}) ===> matched('{e}) + * + * /* Higher order term hole */ + * '{ e } =?= '{ hole[(T1, ..., Tn) => T](x1, ..., xn) } && isColosedUnder(x1, ... xn)('{e}) ===> matched(lift(x1, ..., xn)('{e})) + * + * /* Match literal */ + * '{ lit } =?= '{ lit } ===> matched + * + * /* Match type ascription (a) */ + * '{ e: T } =?= '{ p } ===> '{e} =?= '{p} + * + * /* Match type ascription (b) */ + * '{ e } =?= '{ p: P } ===> '{e} =?= '{p} + * + * /* Match selection */ + * '{ e.x } =?= '{ p.x } ===> '{e} =?= '{p} + * + * /* Match reference */ + * '{ x } =?= '{ x } ===> matched + * + * /* Match application */ + * '{e0(e1, ..., en)} =?= '{p0(p1, ..., p2)} ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} &&& ... %% '{en} =?= '{pn} + * + * /* Match type application */ + * '{e[T1, ..., Tn]} =?= '{p[P1, ..., Pn]} ===> '{e} =?= '{p} &&& '[T1] =?= '{P1} &&& ... %% '[Tn] =?= '[Pn] + * + * /* Match block flattening */ + * '{ {e0; e1; ...; en}; em } =?= '{ {p0; p1; ...; pm}; em } ===> '{ e0; {e1; ...; en; em} } =?= '{ p0; {p1; ...; pm; em} } + * + * /* Match block */ + * '{ e1; e2 } =?= '{ p1; p2 } ===> '{e1} =?= '{p1} &&& '{e2} =?= '{p2} + * + * /* Match def block */ + * '{ e1; e2 } =?= '{ p1; p2 } ===> withEnv(symOf(e1) -> symOf(p1))('{e1} =?= '{p1} &&& '{e2} =?= '{p2}) + * + * /* Match if */ + * '{ if e0 then e1 else e2 } =?= '{ if p0 then p1 else p2 } ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2} + * + * /* Match while */ + * '{ while e0 do e1 } =?= '{ while p0 do p1 } ===> '{e0} =?= '{p0} &&& '{e1} =?= '{p1} + * + * /* Match assign */ + * '{ e0 = e1 } =?= '{ p0 = p1 } && '{e0} =?= '{p0} ===> '{e1} =?= '{p1} + * + * /* Match new */ + * '{ new T } =?= '{ new T } ===> matched + * + * /* Match this */ + * '{ C.this } =?= '{ C.this } ===> matched + * + * /* Match super */ + * '{ e.super } =?= '{ p.super } ===> '{e} =?= '{p} + * + * /* Match varargs */ + * '{ e: _* } =?= '{ p: _* } ===> '{e} =?= '{p} + * + * /* Match val */ + * '{ val x: T = e1; e2 } =?= '{ val y: P = p1; p2 } ===> withEnv(x -> y)('[T] =?= '[P] &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2}) + * + * /* Match def */ + * '{ def x0(x1: T1, ..., xn: Tn): T0 = e1; e2 } =?= '{ def y0(y1: P1, ..., yn: Pn): P0 = p1; p2 } ===> withEnv(x0 -> y0, ..., xn -> yn)('[T0] =?= '[P0] &&& ... &&& '[Tn] =?= '[Pn] &&& '{e1} =?= '{p1} &&& '{e2} =?= '{p2}) + * + * /* Match match */ + * '{ e0 match { case u1 => e1; ...; case un => en } } =?= '{ p0 match { case q1 => p1; ...; case qn => pn } } ===> + * '{e0} =?= '{p0} &&& ... &&& '{en} =?= '{pn} &&& '⟨u1⟩ =?= '⟨q1⟩ &&& ... &&& '⟨un⟩ =?= '⟨qn⟩ + * + * /* Match try */ + * '{ try e0 catch { case u1 => e1; ...; case un => en } finally ef } =?= '{ try p0 catch { case q1 => p1; ...; case qn => pn } finally pf } ===> '{e0} =?= '{p0} &&& ... &&& '{en} =?= '{pn} &&& '⟨u1⟩ =?= '⟨q1⟩ &&& ... &&& '⟨un⟩ =?= '⟨qn⟩ &&& '{ef} =?= '{pf} + * + * // Types + * + * /* Match type */ + * '[T] =?= '[P] && T <:< P ===> matched + * + * /* Match applied type */ + * '[ T0[T1, ..., Tn] ] =?= '[ P0[P1, ..., Pn] ] ===> '[T0] =?= '[P0] &&& ... &&& '[Tn] =?= '[Pn] + * + * /* Match annot (a) */ + * '[T @annot] =?= '[P] ===> '[T] =?= '[P] + * + * /* Match annot (b) */ + * '[T] =?= '[P @annot] ===> '[T] =?= '[P] + * + * // Patterns + * + * /* Match pattern whildcard */ + * '⟨ _ ⟩ =?= '⟨ _ ⟩ ===> matched + * + * /* Match pattern bind */ + * '⟨ x @ e ⟩ =?= '⟨ y @ p ⟩ ===> withEnv(x -> y)('⟨e⟩ =?= '⟨p⟩) + * + * /* Match pattern unapply */ + * '⟨ e0(e1, ..., en)(using i1, ..., im ) ⟩ =?= '⟨ p0(p1, ..., pn)(using q1, ..., 1m) ⟩ ===> '⟨e0⟩ =?= '⟨p0⟩ &&& ... &&& '⟨en⟩ =?= '⟨pn⟩ &&& '{i1} =?= '{q1} &&& ... &&& '{im} =?= '{qm} + * + * /* Match pattern alternatives */ + * '⟨ e1 | ... | en ⟩ =?= '⟨ p1 | ... | pn ⟩ ===> '⟨e1⟩ =?= '⟨p1⟩ &&& ... &&& '⟨en⟩ =?= '⟨pn⟩ + * + * /* Match pattern type test */ + * '⟨ e: T ⟩ =?= '⟨ p: U ⟩ ===> '⟨e⟩ =?= '⟨p⟩ &&& '[T] =?= [U] + * + * /* Match pattern ref */ + * '⟨ `x` ⟩ =?= '⟨ `x` ⟩ ===> matched + * + * /* Match pattern ref splice */ + * '⟨ `x` ⟩ =?= '⟨ hole ⟩ ===> matched('{`x`}) + * + * ``` + */ private[quoted] object Matcher { class QuoteMatcher[QCtx <: QuoteContext & Singleton](using val qctx: QCtx) { @@ -83,15 +212,15 @@ private[quoted] object Matcher { case annot => annot.symbol.owner == internal.Definitions_InternalQuoted_fromAboveAnnot } - /** Check that all trees match with `mtch` and concatenate the results with && */ + /** Check that all trees match with `mtch` and concatenate the results with &&& */ private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match { - case (x :: xs, y :: ys) => mtch(x, y) && matchLists(xs, ys)(mtch) + case (x :: xs, y :: ys) => mtch(x, y) &&& matchLists(xs, ys)(mtch) case (Nil, Nil) => matched case _ => notMatched } private extension treeListOps on (scrutinees: List[Tree]) { - /** Check that all trees match with =?= and concatenate the results with && */ + /** Check that all trees match with =?= and concatenate the results with &&& */ def =?= (patterns: List[Tree])(using Context, Env): Matching = matchLists(scrutinees, patterns)(_ =?= _) } @@ -108,6 +237,7 @@ private[quoted] object Matcher { */ def =?= (pattern0: Tree)(using Context, Env): Matching = { + /* Match block flattening */ // TODO move to cases /** Normalize the tree */ def normalize(tree: Tree): Tree = tree match { case Block(Nil, expr) => normalize(expr) @@ -129,6 +259,7 @@ private[quoted] object Matcher { (scrutinee, pattern) match { + /* Term hole */ // Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree case (scrutinee @ Typed(s, tpt1), Typed(TypeApply(patternHole, tpt :: Nil), tpt2)) if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole && @@ -136,13 +267,16 @@ private[quoted] object Matcher { tpt2.tpe.derivesFrom(defn.RepeatedParamClass) => matched(scrutinee.seal) + /* Term hole */ // Match a scala.internal.Quoted.patternHole and return the scrutinee tree case (ClosedPatternTerm(scrutinee), TypeApply(patternHole, tpt :: Nil)) if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole && scrutinee.tpe <:< tpt.tpe => matched(scrutinee.seal) + /* Higher order term hole */ // Matches an open term and wraps it into a lambda that provides the free variables + // TODO do not encode with `hole`. Maybe use `higherOrderHole[(T1, ..., Tn) => R]((x1: T1, ..., xn: Tn)): R` case (scrutinee, pattern @ Apply(Select(TypeApply(patternHole, List(Inferred())), "apply"), args0 @ IdentArgs(args))) if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole => def bodyFn(lambdaArgs: List[Tree]): Tree = { @@ -164,34 +298,47 @@ private[quoted] object Matcher { // Match two equivalent trees // + /* Match literal */ case (Literal(constant1), Literal(constant2)) if constant1 == constant2 => matched + /* Match type ascription (a) */ case (Typed(expr1, _), pattern) => expr1 =?= pattern + /* Match type ascription (b) */ case (scrutinee, Typed(expr2, _)) => scrutinee =?= expr2 - case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) => - matched + /* Match selection */ case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol => qual1 =?= qual2 + /* Match reference */ + // TODO could be subsumed by the next case + case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) => + matched + + /* Match reference */ case (_: Ref, _: Ref) if scrutinee.symbol == pattern.symbol => matched + /* Match application */ + // TODO may not need to check the symbol (done in fn1 =?= fn2) case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) => - fn1 =?= fn2 && args1 =?= args2 + fn1 =?= fn2 &&& args1 =?= args2 + /* Match type application */ + // TODO may not need to check the symbol (done in fn1 =?= fn2) case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) => - fn1 =?= fn2 && args1 =?= args2 + fn1 =?= fn2 &&& args1 =?= args2 case (Block(stats1, expr1), Block(binding :: stats2, expr2)) if isTypeBinding(binding) => qctx.tasty.internal.Context_GADT_addToConstraint(summon[Context])(binding.symbol :: Nil) - matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block(stats1, expr1) =?= Block(stats2, expr2) + matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) &&& Block(stats1, expr1) =?= Block(stats2, expr2) + /* Match block */ case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) => val newEnv = (stat1, stat2) match { case (stat1: Definition, stat2: Definition) => @@ -200,48 +347,62 @@ private[quoted] object Matcher { summon[Env] } withEnv(newEnv) { - stat1 =?= stat2 && Block(stats1, expr1) =?= Block(stats2, expr2) + stat1 =?= stat2 &&& Block(stats1, expr1) =?= Block(stats2, expr2) } case (scrutinee, Block(typeBindings, expr2)) if typeBindings.forall(isTypeBinding) => val bindingSymbols = typeBindings.map(_.symbol) qctx.tasty.internal.Context_GADT_addToConstraint(summon[Context])(bindingSymbols) - bindingSymbols.foldRight(scrutinee =?= expr2)((x, acc) => matched(new SymBinding(x, hasFromAboveAnnotation(x))) && acc) + bindingSymbols.foldRight(scrutinee =?= expr2)((x, acc) => matched(new SymBinding(x, hasFromAboveAnnotation(x))) &&& acc) + /* Match if */ case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) => - cond1 =?= cond2 && thenp1 =?= thenp2 && elsep1 =?= elsep2 + cond1 =?= cond2 &&& thenp1 =?= thenp2 &&& elsep1 =?= elsep2 + /* Match while */ + case (While(cond1, body1), While(cond2, body2)) => + cond1 =?= cond2 &&& body1 =?= body2 + + /* Match assign */ case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) => val lhsMatch = if ((lhs1 =?= lhs2).isMatch) matched else notMatched - lhsMatch && rhs1 =?= rhs2 - - case (While(cond1, body1), While(cond2, body2)) => - cond1 =?= cond2 && body1 =?= body2 + // TODO lhs1 =?= lhs2 &&& rhs1 =?= rhs2 + lhsMatch &&& rhs1 =?= rhs2 + /* Match new */ case (New(tpt1), New(tpt2)) if tpt1.tpe.typeSymbol == tpt2.tpe.typeSymbol => matched + /* Match this */ case (This(_), This(_)) if scrutinee.symbol == pattern.symbol => matched + /* Match super */ case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 => qual1 =?= qual2 + /* Match varargs */ case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size => elems1 =?= elems2 + /* Match type */ + // TODO remove this? case (scrutinee: TypeTree, pattern: TypeTree) if scrutinee.tpe <:< pattern.tpe => matched + /* Match applied type */ + // TODO remove this? case (Applied(tycon1, args1), Applied(tycon2, args2)) => - tycon1 =?= tycon2 && args1 =?= args2 + tycon1 =?= tycon2 &&& args1 =?= args2 + /* Match val */ case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() => def rhsEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol) - tpt1 =?= tpt2 && treeOptMatches(rhs1, rhs2)(using summon[Context], rhsEnv) + tpt1 =?= tpt2 &&& treeOptMatches(rhs1, rhs2)(using summon[Context], rhsEnv) + /* Match def */ case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) => def rhsEnv = val oldEnv: Env = summon[Env] @@ -251,23 +412,28 @@ private[quoted] object Matcher { oldEnv ++ newEnv typeParams1 =?= typeParams2 - && matchLists(paramss1, paramss2)(_ =?= _) - && tpt1 =?= tpt2 - && withEnv(rhsEnv)(rhs1 =?= rhs2) + &&& matchLists(paramss1, paramss2)(_ =?= _) + &&& tpt1 =?= tpt2 + &&& withEnv(rhsEnv)(rhs1 =?= rhs2) case (Closure(_, tpt1), Closure(_, tpt2)) => // TODO match tpt1 with tpt2? matched + /* Match match */ case (Match(scru1, cases1), Match(scru2, cases2)) => - scru1 =?= scru2 && matchLists(cases1, cases2)(caseMatches) + scru1 =?= scru2 &&& matchLists(cases1, cases2)(caseMatches) + /* Match try */ case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) => - body1 =?= body2 && matchLists(cases1, cases2)(caseMatches) && treeOptMatches(finalizer1, finalizer2) + body1 =?= body2 &&& matchLists(cases1, cases2)(caseMatches) &&& treeOptMatches(finalizer1, finalizer2) // Ignore type annotations + // TODO remove this + /* Match annot (a) */ case (Annotated(tpt, _), _) => tpt =?= pattern + /* Match annot (b) */ case (_, Annotated(tpt, _)) => scrutinee =?= tpt @@ -336,9 +502,9 @@ private[quoted] object Matcher { private def caseMatches(scrutinee: CaseDef, pattern: CaseDef)(using Context, Env): Matching = { val (caseEnv, patternMatch) = patternsMatches(scrutinee.pattern, pattern.pattern) withEnv(caseEnv) { - patternMatch && - treeOptMatches(scrutinee.guard, pattern.guard) && - scrutinee.rhs =?= pattern.rhs + patternMatch + &&& treeOptMatches(scrutinee.guard, pattern.guard) + &&& scrutinee.rhs =?= pattern.rhs } } @@ -354,24 +520,30 @@ private[quoted] object Matcher { * `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. */ private def patternsMatches(scrutinee: Tree, pattern: Tree)(using Context, Env): (Env, Matching) = (scrutinee, pattern) match { + /* Match pattern ref splice */ case (v1: Term, Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil)) if patternHole.symbol.owner == summon[Context].requiredModule("scala.runtime.quoted.Matcher") => (summon[Env], matched(v1.seal)) + /* Match pattern whildcard */ case (Ident("_"), Ident("_")) => (summon[Env], matched) + /* Match pattern bind */ case (Bind(name1, body1), Bind(name2, body2)) => val bindEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol) patternsMatches(body1, body2)(using summon[Context], bindEnv) + /* Match pattern unapply */ case (Unapply(fun1, implicits1, patterns1), Unapply(fun2, implicits2, patterns2)) => val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2) - (patEnv, patternsMatches(fun1, fun2)._2 && implicits1 =?= implicits2 && patternsMatch) + (patEnv, patternsMatches(fun1, fun2)._2 &&& implicits1 =?= implicits2 &&& patternsMatch) + /* Match pattern alternatives */ case (Alternatives(patterns1), Alternatives(patterns2)) => foldPatterns(patterns1, patterns2) + /* Match pattern type test */ case (Typed(Ident("_"), tpt1), Typed(Ident("_"), tpt2)) => (summon[Env], tpt1 =?= tpt2) @@ -403,7 +575,7 @@ private[quoted] object Matcher { if (patterns1.size != patterns2.size) (summon[Env], notMatched) else patterns1.zip(patterns2).foldLeft((summon[Env], matched)) { (acc, x) => val (env, res) = patternsMatches(x._1, x._2)(using summon[Context], acc._1) - (env, acc._2 && res) + (env, acc._2 &&& res) } } @@ -425,7 +597,7 @@ private[quoted] object Matcher { def (self: Matching) asOptionOfTuple: Option[Tuple] = self /** Concatenates the contents of two successful matchings or return a `notMatched` */ - def (self: Matching) && (that: => Matching): Matching = self match { + def (self: Matching) &&& (that: => Matching): Matching = self match { case Some(x) => that match { case Some(y) => Some(x ++ y)