Skip to content

Commit dcb296e

Browse files
committed
Fix #7554: Implement TypeTest interface
Using tests from: https://gist.github.com/Blaisorblade/a0eebb6a4f35344e48c4c60dc2a14ce6
1 parent 6cd3a9d commit dcb296e

16 files changed

+519
-9
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,10 @@ class Definitions {
669669
@tu lazy val ClassTagModule: Symbol = ClassTagClass.companionModule
670670
@tu lazy val ClassTagModule_apply: Symbol = ClassTagModule.requiredMethod(nme.apply)
671671

672+
@tu lazy val TypeTestClass: ClassSymbol = ctx.requiredClass("scala.reflect.TypeTest")
673+
@tu lazy val TypeTestModule: Symbol = TypeTestClass.companionModule
674+
@tu lazy val TypeTestModule_identity: Symbol = TypeTestModule.requiredMethod(nme.identity)
675+
672676
@tu lazy val QuotedExprClass: ClassSymbol = ctx.requiredClass("scala.quoted.Expr")
673677
@tu lazy val QuotedExprModule: Symbol = QuotedExprClass.companionModule
674678
@tu lazy val QuotedExprModule_nullExpr: Symbol = QuotedExprModule.requiredMethod(nme.nullExpr)

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,36 @@ trait Implicits { self: Typer =>
718718
EmptyTree
719719
}
720720

721+
lazy val synthesizedTypeTest: SpecialHandler =
722+
(formal, span) => implicit ctx => formal.argInfos match {
723+
case arg1 :: arg2 :: Nil if !defn.isBottomClass(arg2.typeSymbol) =>
724+
val tp1 = fullyDefinedType(arg1, "TypeTest argument", span)
725+
val tp2 = fullyDefinedType(arg2, "TypeTest argument", span)
726+
val sym2 = tp2.typeSymbol
727+
if tp1 <:< tp2 then
728+
ref(defn.TypeTestModule_identity).appliedToType(tp2).withSpan(span)
729+
else if sym2 == defn.AnyValClass || sym2 == defn.AnyRefAlias || sym2 == defn.ObjectClass then
730+
EmptyTree
731+
else
732+
// Generate SAM: (s: <tp1>) => if s.isInstanceOf[s.type & <tp2>] then Some(s.asInstanceOf[s.type & <tp2>]) else None
733+
def body(args: List[Tree]): Tree = {
734+
val arg :: Nil = args
735+
val t = arg.tpe & tp2
736+
If(
737+
arg.select(defn.Any_isInstanceOf).appliedToType(t),
738+
ref(defn.SomeClass.companionModule.termRef).select(nme.apply)
739+
.appliedToType(t)
740+
.appliedTo(arg.select(nme.asInstanceOf_).appliedToType(t)),
741+
ref(defn.NoneModule))
742+
}
743+
val tpe = MethodType(List(nme.s))(_ => List(tp1), mth => defn.OptionClass.typeRef.appliedTo(mth.newParamRef(0) & tp2))
744+
val meth = ctx.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe, coord = span)
745+
val typeTestType = defn.TypeTestClass.typeRef.appliedTo(List(tp1, tp2))
746+
Closure(meth, tss => body(tss.head).changeOwner(ctx.owner, meth), targetType = typeTestType).withSpan(span)
747+
case _ =>
748+
EmptyTree
749+
}
750+
721751
/** Synthesize the tree for `'[T]` for an implicit `scala.quoted.Type[T]`.
722752
* `T` is deeply dealiased to avoid references to local type aliases.
723753
*/
@@ -1094,6 +1124,7 @@ trait Implicits { self: Typer =>
10941124
if (mySpecialHandlers == null)
10951125
mySpecialHandlers = List(
10961126
defn.ClassTagClass -> synthesizedClassTag,
1127+
defn.TypeTestClass -> synthesizedTypeTest,
10971128
defn.QuotedTypeClass -> synthesizedTypeTag,
10981129
defn.EqlClass -> synthesizedEql,
10991130
defn.TupledFunctionClass -> synthesizedTupleFunction,

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -736,14 +736,15 @@ class Typer extends Namer
736736
*/
737737
def tryWithClassTag(tree: Typed, pt: Type)(implicit ctx: Context): Tree = tree.tpt.tpe.dealias match {
738738
case tref: TypeRef if !tref.symbol.isClass && !ctx.isAfterTyper && !(tref =:= pt) =>
739-
require(ctx.mode.is(Mode.Pattern))
740-
inferImplicit(defn.ClassTagClass.typeRef.appliedTo(tref),
741-
EmptyTree, tree.tpt.span)(ctx.retractMode(Mode.Pattern)) match {
742-
case SearchSuccess(clsTag, _, _) =>
743-
typed(untpd.Apply(untpd.TypedSplice(clsTag), untpd.TypedSplice(tree.expr)), pt)
744-
case _ =>
745-
tree
739+
def withTag(tpe: Type): Option[Tree] = {
740+
inferImplicit(tpe, EmptyTree, tree.tpt.span)(ctx.retractMode(Mode.Pattern)) match
741+
case SearchSuccess(typeTest, _, _) =>
742+
Some(typed(untpd.Apply(untpd.TypedSplice(typeTest), untpd.TypedSplice(tree.expr)), pt))
743+
case _ =>
744+
None
746745
}
746+
withTag(defn.TypeTestClass.typeRef.appliedTo(pt, tref)).orElse(
747+
withTag(defn.ClassTagClass.typeRef.appliedTo(tref))).getOrElse(tree)
747748
case _ => tree
748749
}
749750

@@ -1580,8 +1581,8 @@ class Typer extends Namer
15801581
val body1 = typed(tree.body, pt)
15811582
body1 match {
15821583
case UnApply(fn, Nil, arg :: Nil)
1583-
if fn.symbol.exists && fn.symbol.owner == defn.ClassTagClass && !body1.tpe.isError =>
1584-
// A typed pattern `x @ (e: T)` with an implicit `ctag: ClassTag[T]`
1584+
if fn.symbol.exists && (fn.symbol.owner == defn.ClassTagClass || fn.symbol.owner.derivesFrom(defn.TypeTestClass)) && !body1.tpe.isError =>
1585+
// A typed pattern `x @ (e: T)` with an implicit `ctag: ClassTag[T]` or `ctag: TypeTest[T]`
15851586
// was rewritten to `x @ ctag(e)` by `tryWithClassTag`.
15861587
// Rewrite further to `ctag(x @ e)`
15871588
tpd.cpy.UnApply(body1)(fn, Nil,
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
---
2+
layout: doc-page
3+
title: "TypeTest"
4+
---
5+
6+
TypeTest
7+
--------
8+
9+
`TypeTest` provides a replacement for `ClassTag.unapply` where the type of the argument is generalized.
10+
`TypeTest.unapply` will return `Some(x.asInstanceOf[Y])` if `x` conforms to `Y`, otherwise it returns `None`.
11+
12+
```scala
13+
trait TypeTest[-S, T] extends Serializable {
14+
def unapply(s: S): Option[s.type & T]
15+
}
16+
```
17+
18+
Just like `ClassTag` used to do, it can be used to perform type checks in patterns.
19+
20+
```scala
21+
type X
22+
type Y <: X
23+
given TypeTest[X, Y] = ...
24+
(x: X) match {
25+
case y: Y => ... // safe checked downcast
26+
case _ => ...
27+
}
28+
```
29+
30+
31+
Examples
32+
--------
33+
34+
Given the following abstract definition of `Peano` numbers that provides `TypeTest[Nat, Zero]` and `TypeTest[Nat, Succ]`
35+
36+
```scala
37+
trait Peano {
38+
type Nat
39+
type Zero <: Nat
40+
type Succ <: Nat
41+
def safeDiv(m: Nat, n: Succ): (Nat, Nat)
42+
val Zero: Zero
43+
val Succ: SuccExtractor
44+
trait SuccExtractor {
45+
def apply(nat: Nat): Succ
46+
def unapply(nat: Succ): Option[Nat]
47+
}
48+
given TypeTest[Nat, Zero] = typeTestOfZero
49+
protected def typeTestOfZero: TypeTest[Nat, Zero]
50+
given TypeTest[Nat, Succ]
51+
protected def typeTestOfSucc: TypeTest[Nat, Succ]
52+
```
53+
54+
it will be possible to write the following program
55+
56+
```scala
57+
val peano: Peano = ...
58+
import peano.{_, given}
59+
def divOpt(m: Nat, n: Nat): Option[(Nat, Nat)] = {
60+
n match {
61+
case Zero => None
62+
case s @ Succ(_) => Some(safeDiv(m, s))
63+
}
64+
}
65+
val two = Succ(Succ(Zero))
66+
val five = Succ(Succ(Succ(two)))
67+
println(divOpt(five, two))
68+
```
69+
70+
Note that without the `TypeTest[Nat, Succ]` the pattern `Succ.unapply(nat: Succ)` would be unchecked.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package scala.reflect
2+
3+
/** A `TypeTest[S, T] contains the logic needed to know at runtime if a value of
4+
* type `S` can be downcasted to `T`.
5+
*
6+
* If a pattern match is performed on a term of type `s: S` that is uncheckable with `s.isInstanceOf[T]` and
7+
* the pattern are of the form:
8+
* - `t: T`
9+
* - `t @ X()` where the `X.unapply` has takes an argument of type `T`
10+
* then a given instance of `TypeTest[S, T]` is summoned and used to perform the test.
11+
*/
12+
@scala.annotation.implicitNotFound(msg = "No TypeTest available for [${S}, ${T}]")
13+
trait TypeTest[-S, T] extends Serializable {
14+
15+
/** A TypeTest[S, T] can serve as an extractor that matches only S of type T.
16+
*
17+
* The compiler tries to turn unchecked type tests in pattern matches into checked ones
18+
* by wrapping a `(_: T)` type pattern as `tt(_: T)`, where `tt` is the `TypeTest[S, T]` instance.
19+
* Type tests necessary before calling other extractors are treated similarly.
20+
* `SomeExtractor(...)` is turned into `tt(SomeExtractor(...))` if `T` in `SomeExtractor.unapply(x: T)`
21+
* is uncheckable, but we have an instance of `TypeTest[S, T]`.
22+
*/
23+
def unapply(x: S): Option[x.type & T]
24+
25+
}
26+
27+
object TypeTest {
28+
29+
def identity[T]: TypeTest[T, T] = Some(_)
30+
31+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import scala.reflect.ClassTag
2+
3+
object IsInstanceOfClassTag {
4+
def safeCast[T: ClassTag](x: Any): Option[T] = {
5+
x match {
6+
case x: T => Some(x) // TODO error: deprecation waring
7+
case _ => None
8+
}
9+
}
10+
11+
def main(args: Array[String]): Unit = {
12+
safeCast[List[String]](List[Int](1)) match {
13+
case None =>
14+
case Some(xs) =>
15+
xs.head.substring(0)
16+
}
17+
18+
safeCast[List[_]](List[Int](1)) match {
19+
case None =>
20+
case Some(xs) =>
21+
xs.head.substring(0) // error
22+
}
23+
}
24+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import scala.reflect.TypeTest
2+
3+
object IsInstanceOfClassTag {
4+
def safeCast[T](x: Any)(using TypeTest[Any, T]): Option[T] = {
5+
x match {
6+
case x: T => Some(x)
7+
case _ => None
8+
}
9+
}
10+
11+
def main(args: Array[String]): Unit = {
12+
safeCast[List[String]](List[Int](1)) match { // error
13+
case None =>
14+
case Some(xs) =>
15+
}
16+
17+
safeCast[List[_]](List[Int](1)) match {
18+
case None =>
19+
case Some(xs) =>
20+
}
21+
}
22+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import scala.reflect.TypeTest
2+
3+
trait R {
4+
type Nat
5+
type Succ <: Nat
6+
type Idx
7+
given TypeTest[Nat, Succ] = typeTestOfSucc
8+
protected def typeTestOfSucc: TypeTest[Nat, Succ]
9+
def n: Nat
10+
def one: Succ
11+
}
12+
13+
object RI extends R {
14+
type Nat = Int
15+
type Succ = Int
16+
type Idx = Int
17+
protected def typeTestOfSucc: TypeTest[Nat, Succ] = new {
18+
def unapply(x: Int): Option[x.type & Succ] =
19+
if x > 0 then Some(x) else None
20+
}
21+
def n: Nat = 4
22+
def one: Succ = 1
23+
}
24+
25+
object Test {
26+
val r1: R = RI
27+
val r2: R = RI
28+
29+
r1.n match {
30+
case n: r2.Nat => // error: the type test for Test.r2.Nat cannot be checked at runtime
31+
case n: r1.Idx => // error: the type test for Test.r1.Idx cannot be checked at runtime
32+
case n: r1.Succ => // Ok
33+
case n: r1.Nat => // Ok
34+
}
35+
36+
r1.one match {
37+
case n: r2.Nat => // error: the type test for Test.r2.Nat cannot be checked at runtime
38+
case n: r1.Idx => // error: the type test for Test.r1.Idx cannot be checked at runtime
39+
case n: r1.Nat => // Ok
40+
}
41+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import scala.reflect.TypeTest
2+
3+
object Test {
4+
def main(args: Array[String]): Unit = {
5+
val p1: T = T1
6+
val p2: T = T1
7+
8+
(p1.y: p1.X) match {
9+
case x: p2.Y => // error: unchecked
10+
case x: p1.Y =>
11+
case _ =>
12+
}
13+
}
14+
15+
}
16+
17+
trait T {
18+
type X
19+
type Y <: X
20+
def x: X
21+
def y: Y
22+
given TypeTest[X, Y] = typeTestOfY
23+
protected def typeTestOfY: TypeTest[X, Y]
24+
}
25+
26+
object T1 extends T {
27+
type X = Boolean
28+
type Y = true
29+
def x: X = false
30+
def y: Y = true
31+
protected def typeTestOfY: TypeTest[X, Y] = new {
32+
def unapply(x: X): Option[x.type & Y] = x match
33+
case x: (true & x.type) => Some(x)
34+
case _ => None
35+
}
36+
37+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import scala.reflect.TypeTest
2+
3+
object Test {
4+
def test[S, T](using TypeTest[S, T]): Unit = ()
5+
val a: A = ???
6+
7+
test[Any, Any]
8+
test[Int, Int]
9+
10+
test[Int, Any]
11+
test[String, Any]
12+
test[String, AnyRef]
13+
14+
test[Any, Int]
15+
test[Any, String]
16+
test[Any, Some[_]]
17+
test[Any, Array[Int]]
18+
test[Seq[Int], List[Int]]
19+
20+
test[Any, Some[Int]] // error
21+
test[Any, a.X] // error
22+
test[a.X, a.Y] // error
23+
24+
}
25+
26+
class A {
27+
type X
28+
type Y <: X
29+
}

tests/neg/type-test-syntesize.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import scala.reflect.TypeTest
2+
3+
object Test {
4+
def test[S, T](using x: TypeTest[S, T]): Unit = ()
5+
6+
test[Any, AnyRef] // error
7+
test[Any, AnyVal] // error
8+
test[Any, Object] // error
9+
10+
test[Any, Nothing] // error
11+
test[AnyRef, Nothing] // error
12+
test[AnyVal, Nothing] // error
13+
test[Null, Nothing] // error
14+
test[Unit, Nothing] // error
15+
test[Int, Nothing] // error
16+
test[8, Nothing] // error
17+
test[List[_], Nothing] // error
18+
test[Nothing, Nothing] // error
19+
}

0 commit comments

Comments
 (0)