Skip to content

Commit 9652826

Browse files
committed
Fix i11694: extract function type and SAM in union type
1 parent a288432 commit 9652826

File tree

3 files changed

+58
-10
lines changed

3 files changed

+58
-10
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,6 +1107,20 @@ class Typer extends Namer
11071107
newTypeVar(apply(bounds.orElse(TypeBounds.empty)).bounds)
11081108
case _ => mapOver(t)
11091109
}
1110+
def extractInUnion(t: Type): Seq[Type] = t match {
1111+
case t: OrType =>
1112+
extractInUnion(t.tp1) ++ extractInUnion(t.tp2)
1113+
case t: TypeParamRef =>
1114+
extractInUnion(ctx.typerState.constraint.entry(t).bounds.hi)
1115+
case t if defn.isNonRefinedFunction(t) =>
1116+
Seq(t)
1117+
case SAMType(_: MethodType) =>
1118+
Seq(t)
1119+
case _ =>
1120+
Nil
1121+
}
1122+
def defaultResult = (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
1123+
11101124
val pt1 = pt.stripTypeVar.dealias
11111125
if (pt1 ne pt1.dropDependentRefinement)
11121126
&& defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType)
@@ -1115,22 +1129,25 @@ class Typer extends Namer
11151129
i"""Implementation restriction: Expected result type $pt1
11161130
|is a curried dependent context function type. Such types are not yet supported.""",
11171131
tree.srcPos)
1118-
pt1 match {
1132+
1133+
val elems = extractInUnion(pt1)
1134+
if elems.length != 1 then
1135+
// The union type containing multiple function types is ignored
1136+
defaultResult
1137+
else elems.head match {
11191138
case pt1 if defn.isNonRefinedFunction(pt1) =>
11201139
// if expected parameter type(s) are wildcards, approximate from below.
11211140
// if expected result type is a wildcard, approximate from above.
11221141
// this can type the greatest set of admissible closures.
11231142
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
11241143
case SAMType(sam @ MethodTpe(_, formals, restpe)) =>
11251144
(formals,
1126-
if (sam.isResultDependent)
1127-
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
1128-
else
1129-
typeTree(restpe))
1130-
case tp: TypeParamRef =>
1131-
decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree)
1145+
if (sam.isResultDependent)
1146+
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
1147+
else
1148+
typeTree(restpe))
11321149
case _ =>
1133-
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
1150+
defaultResult
11341151
}
11351152
}
11361153

@@ -1375,14 +1392,22 @@ class Typer extends Namer
13751392
}
13761393

13771394
def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = {
1395+
def extractInUnion(t: Type): Seq[Type] = t match {
1396+
case t: OrType =>
1397+
extractInUnion(t.tp1) ++ extractInUnion(t.tp2)
1398+
case SAMType(_) =>
1399+
Seq(t)
1400+
case _ =>
1401+
Nil
1402+
}
13781403
val env1 = tree.env mapconserve (typed(_))
13791404
val meth1 = typedUnadapted(tree.meth)
13801405
val target =
13811406
if (tree.tpt.isEmpty)
13821407
meth1.tpe.widen match {
13831408
case mt: MethodType =>
1384-
pt.stripNull match {
1385-
case pt @ SAMType(sam)
1409+
extractInUnion(pt) match {
1410+
case Seq(pt @ SAMType(sam))
13861411
if !defn.isFunctionType(pt) && mt <:< sam =>
13871412
// SAMs of the form C[?] where C is a class cannot be conversion targets.
13881413
// The resulting class `class $anon extends C[?] {...}` would be illegal,

tests/explicit-nulls/pos/i11694.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = {
2+
val x = new java.util.ArrayList[String]()
3+
val y = x.stream().nn.filter(s => s.nn.length > 0)
4+
}

tests/neg/i11694.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
def test1 = {
2+
def f11: (Int => Int) | Unit = x => x + 1
3+
def f12: Null | (Int => Int) = x => x + 1
4+
5+
def f21: (Int => Int) | Null = x => x + 1
6+
def f22: Null | (Int => Int) = x => x + 1
7+
}
8+
9+
def test2 = {
10+
def f1: (Int => String) | (Int => Int) | Null = x => x + 1 // error
11+
def f2: (Int => String) | Function[String, Int] | Null = x => "" + x // error
12+
def f3: Function[Int, Int] | Function[String, Int] | Null = x => x + 1 // error
13+
}
14+
15+
def test3 = {
16+
import java.util.function.Function
17+
val f1: Function[String, Int] | Unit = x => x.length
18+
val f2: Function[String, Int] | Null = x => x.length
19+
}

0 commit comments

Comments
 (0)