diff --git a/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala b/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala index 8d01d2415340..8fc9f02c1e38 100644 --- a/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala +++ b/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala @@ -19,10 +19,14 @@ import NameKinds.OuterSelectName import StdNames.* import config.Feature import inlines.Inlines.inInlineMethod +import util.Property object FirstTransform { val name: String = "firstTransform" val description: String = "some transformations to put trees into a canonical form" + + /** Attachment key for named argument patterns */ + val WasNamedArg: Property.StickyKey[Unit] = Property.StickyKey() } /** The first tree transform @@ -38,6 +42,7 @@ object FirstTransform { */ class FirstTransform extends MiniPhase with SymTransformer { thisPhase => import ast.tpd.* + import FirstTransform.* override def phaseName: String = FirstTransform.name @@ -156,7 +161,13 @@ class FirstTransform extends MiniPhase with SymTransformer { thisPhase => override def transformOther(tree: Tree)(using Context): Tree = tree match { case tree: Export => EmptyTree - case tree: NamedArg => transformAllDeep(tree.arg) + case tree: NamedArg => + val res = transformAllDeep(tree.arg) + if ctx.mode.is(Mode.Pattern) then + // Need to keep NamedArg status for pattern matcher to work correctly when faced + // with single-element named tuples. + res.pushAttachment(WasNamedArg, ()) + res case tree => if (tree.isType) toTypeTree(tree) else tree } diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 250d4844d2b3..e2505144abda 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -386,9 +386,20 @@ object PatternMatcher { } else letAbstract(get) { getResult => - val selectors = - if (args.tail.isEmpty) ref(getResult) :: Nil - else productSelectors(getResult.info).map(ref(getResult).select(_)) + def isUnaryNamedTupleSelectArg(arg: Tree) = + get.tpe.widenDealias.isNamedTupleType + && arg.removeAttachment(FirstTransform.WasNamedArg).isDefined + // Special case: Normally, we pull out the argument wholesale if + // there is only one. But if the argument is a named argument for + // a single-element named tuple, we have to select the field instead. + // NamedArg trees are eliminated in FirstTransform but for named arguments + // of patterns we add a WasNamedArg attachment, which is used to guide the + // logic here. See i22900.scala for test cases. + val selectors = args match + case arg :: Nil if !isUnaryNamedTupleSelectArg(arg) => + ref(getResult) :: Nil + case _ => + productSelectors(getResult.info).map(ref(getResult).select(_)) matchArgsPlan(selectors, args, onSuccess) } } diff --git a/compiler/src/dotty/tools/dotc/transform/patmat/Space.scala b/compiler/src/dotty/tools/dotc/transform/patmat/Space.scala index 30b892ece470..ab5885e6278c 100644 --- a/compiler/src/dotty/tools/dotc/transform/patmat/Space.scala +++ b/compiler/src/dotty/tools/dotc/transform/patmat/Space.scala @@ -279,7 +279,7 @@ object SpaceEngine { || unappResult <:< ConstantType(Constant(true)) // only for unapply || (unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) // scala2 compatibility || unapplySeqTypeElemTp(unappResult).exists // only for unapplySeq - || isProductMatch(unappResult, argLen) + || isProductMatch(unappResult.stripNamedTuple, argLen) || extractorMemberType(unappResult, nme.isEmpty, NoSourcePosition) <:< ConstantType(Constant(false)) || unappResult.derivesFrom(defn.NonEmptyTupleClass) || unapp.symbol == defn.TupleXXL_unapplySeq // Fixes TupleXXL.unapplySeq which returns Some but declares Option diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index ec07fefc64ab..2d6817f74ff7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -1036,6 +1036,8 @@ trait Checking { pats.forall(recur(_, pt)) case Typed(arg, tpt) => check(pat, pt) && recur(arg, pt) + case NamedArg(name, pat) => + recur(pat, pt) case Ident(nme.WILDCARD) => true case pat: QuotePattern => diff --git a/tests/run/i22900.check b/tests/run/i22900.check new file mode 100644 index 000000000000..f6636139a701 --- /dev/null +++ b/tests/run/i22900.check @@ -0,0 +1,8 @@ +6 +6 +6 +6 +7 +6 +7 +(6) diff --git a/tests/run/i22900.scala b/tests/run/i22900.scala new file mode 100644 index 000000000000..f7786d32d717 --- /dev/null +++ b/tests/run/i22900.scala @@ -0,0 +1,26 @@ +object NameBaseExtractor { + def unapply(x: Int): Some[(someName: Int)] = Some((someName = x + 3)) +} +object NameBaseExtractor2 { + def unapply(x: Int): Some[(someName: Int, age: Int)] = Some((someName = x + 3, age = x + 4)) +} +@main +def Test = + val x1 = 3 match + case NameBaseExtractor(someName = x) => x + println(x1) + val NameBaseExtractor(someName = x2) = 3 + println(x2) + val NameBaseExtractor((someName = x3)) = 3 + println(x3) + + val NameBaseExtractor2(someName = x4, age = x5) = 3 + println(x4) + println(x5) + + val NameBaseExtractor2((someName = x6, age = x7)) = 3 + println(x6) + println(x7) + + val NameBaseExtractor(y1) = 3 + println(y1) diff --git a/tests/run/i22900a.check b/tests/run/i22900a.check new file mode 100644 index 000000000000..a94217f352f9 --- /dev/null +++ b/tests/run/i22900a.check @@ -0,0 +1,3 @@ +3 +6 +3 diff --git a/tests/run/i22900a.scala b/tests/run/i22900a.scala new file mode 100644 index 000000000000..301deeecdf13 --- /dev/null +++ b/tests/run/i22900a.scala @@ -0,0 +1,15 @@ +case class C(someName: Int) + +object NameBaseExtractor3 { + def unapply(x: Int): Some[C] = Some(C(someName = x + 3)) +} + +@main +def Test = { + val C(someName = xx) = C(3) + println(xx) + val NameBaseExtractor3(C(someName = x)) = 3 + println(x) + C(3) match + case C(someName = xx) => println(xx) +} \ No newline at end of file diff --git a/tests/warn/i22899.scala b/tests/warn/i22899.scala new file mode 100644 index 000000000000..ae6544e29286 --- /dev/null +++ b/tests/warn/i22899.scala @@ -0,0 +1,27 @@ +case class CaseClass(a: Int) + +object ProductMatch_CaseClass { + def unapply(int: Int): CaseClass = CaseClass(int) +} + +object ProductMatch_NamedTuple { + def unapply(int: Int): (a: Int) = (a = int) +} + +object NameBasedMatch_CaseClass { + def unapply(int: Int): Some[CaseClass] = Some(CaseClass(int)) +} + +object NameBasedMatch_NamedTuple { + def unapply(int: Int): Some[(a: Int)] = Some((a = int)) +} + +object Test { + val ProductMatch_CaseClass(a = x1) = 1 // ok, was pattern's type (x1 : Int) is more specialized than the right hand side expression's type Int + val ProductMatch_NamedTuple(a = x2) = 2 // ok, was pattern binding uses refutable extractor `org.test.ProductMatch_NamedTuple` + val NameBasedMatch_CaseClass(a = x3) = 3 // ok, was pattern's type (x3 : Int) is more specialized than the right hand side expression's type Int + val NameBasedMatch_NamedTuple(a = x4) = 4 // ok, was pattern's type (x4 : Int) is more specialized than the right hand side expression's type Int + + val CaseClass(a = x5) = CaseClass(5) // ok, was pattern's type (x5 : Int) is more specialized than the right hand side expression's type Int + val (a = x6) = (a = 6) // ok +} \ No newline at end of file