From 83cdfc9d519bbda4db23ded8f334fe704393d54c Mon Sep 17 00:00:00 2001 From: odersky Date: Fri, 8 Jul 2022 20:59:42 +0200 Subject: [PATCH 1/3] Fix two problems related to match types as array elements 1. The erasure of an array of matchtypes should sometimes be Object instead of Object[] 2. Classtags of matchtypes can be created only if all alternatives produce the same classtag. About 1: If a matchtype with alternative types A_1, ... A_n is an array element, it should be treated in the same way as the type ? <: A_1 | ... | A_n. It's an _unknown_ subtype of A_1 | ... | A_n. That can cause the erasure of the underlying array to be Object. Fixes #15618 --- .../dotty/tools/dotc/core/TypeErasure.scala | 3 +++ .../dotty/tools/dotc/typer/Synthesizer.scala | 25 ++++++++++++++++--- tests/neg/i15618.check | 18 +++++++++++++ tests/neg/i15618.scala | 23 +++++++++++++++++ tests/run/i15618.check | 2 ++ tests/run/i15618.scala | 24 ++++++++++++++++++ 6 files changed, 91 insertions(+), 4 deletions(-) create mode 100644 tests/neg/i15618.check create mode 100644 tests/neg/i15618.scala create mode 100644 tests/run/i15618.check create mode 100644 tests/run/i15618.scala diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index 6e7bb0e7a0d4..41993d9fb578 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -328,6 +328,9 @@ object TypeErasure { isGenericArrayElement(tp.alias, isScala2) case tp: TypeBounds => !fitsInJVMArray(tp.hi) + case tp: MatchType => + val alts = tp.alternatives + alts.nonEmpty && !fitsInJVMArray(alts.reduce(OrType(_, _, soft = true))) case tp: TypeProxy => isGenericArrayElement(tp.translucentSuperType, isScala2) case tp: AndType => diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 171e171f33f1..5258c31005ef 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -40,11 +40,28 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): val classTag = ref(defn.ClassTagModule) val tag = if defn.SpecialClassTagClasses.contains(sym) then - classTag.select(sym.name.toTermName) + classTag.select(sym.name.toTermName).withSpan(span) else - val clsOfType = escapeJavaArray(erasure(tp)) - classTag.select(nme.apply).appliedToType(tp).appliedTo(clsOf(clsOfType)) - withNoErrors(tag.withSpan(span)) + def clsOfType(tp: Type): Type = + val tp1 = tp.dealias + if tp1.isMatch then + val matchTp = tp1.underlyingIterator.collect { + case mt: MatchType => mt + }.next + matchTp.alternatives.map(clsOfType) match + case ct1 :: cts if cts.forall(ct1 == _) => ct1 + case _ => NoType + else + escapeJavaArray(erasure(tp)) + val ctype = clsOfType(tp) + if ctype.exists then + classTag.select(nme.apply) + .appliedToType(tp) + .appliedTo(clsOf(ctype)) + .withSpan(span) + else + EmptyTree + withNoErrors(tag) case tp => EmptyTreeNoError else EmptyTreeNoError case _ => EmptyTreeNoError diff --git a/tests/neg/i15618.check b/tests/neg/i15618.check new file mode 100644 index 000000000000..0853da26c27a --- /dev/null +++ b/tests/neg/i15618.check @@ -0,0 +1,18 @@ +-- Error: tests/neg/i15618.scala:17:44 --------------------------------------------------------------------------------- +17 | def toArray: Array[ScalaType[T]] = Array() // error + | ^ + | No ClassTag available for ScalaType[T] + | + | where: T is a type in class Tensor with bounds <: DType + | + | + | Note: a match type could not be fully reduced: + | + | trying to reduce ScalaType[T] + | failed since selector T + | does not match case Float16 => Float + | and cannot be shown to be disjoint from it either. + | Therefore, reduction cannot advance to the remaining cases + | + | case Float32 => Float + | case Int32 => Int diff --git a/tests/neg/i15618.scala b/tests/neg/i15618.scala new file mode 100644 index 000000000000..fd38c8c48f6b --- /dev/null +++ b/tests/neg/i15618.scala @@ -0,0 +1,23 @@ +sealed abstract class DType +sealed class Float16 extends DType +sealed class Float32 extends DType +sealed class Int32 extends DType + +object Float16 extends Float16 +object Float32 extends Float32 +object Int32 extends Int32 + +type ScalaType[U <: DType] <: Int | Float = U match + case Float16 => Float + case Float32 => Float + case Int32 => Int + +class Tensor[T <: DType](dtype: T): + def toSeq: Seq[ScalaType[T]] = Seq() + def toArray: Array[ScalaType[T]] = Array() // error + +@main +def Test = + val t = Tensor(Float32) // Tensor[Float32] + println(t.toSeq.headOption) // works, Seq[Float] + println(t.toArray.headOption) // ClassCastException diff --git a/tests/run/i15618.check b/tests/run/i15618.check new file mode 100644 index 000000000000..8e26b1641990 --- /dev/null +++ b/tests/run/i15618.check @@ -0,0 +1,2 @@ +None +None diff --git a/tests/run/i15618.scala b/tests/run/i15618.scala new file mode 100644 index 000000000000..8149be15b6ba --- /dev/null +++ b/tests/run/i15618.scala @@ -0,0 +1,24 @@ +sealed abstract class DType +sealed class Float16 extends DType +sealed class Float32 extends DType +sealed class Int32 extends DType + +object Float16 extends Float16 +object Float32 extends Float32 +object Int32 extends Int32 + +type ScalaType[U <: DType] <: Int | Float = U match + case Float16 => Float + case Float32 => Float + case Int32 => Int + +abstract class Tensor[T <: DType]: + def toArray: Array[ScalaType[T]] + +object FloatTensor extends Tensor[Float16]: + def toArray: Array[Float] = Array(1, 2, 3) + +@main +def Test = + val t = FloatTensor: Tensor[Float16] // Tensor[Float32] + println(t.toArray.headOption) // was ClassCastException From b4ec7791e3d6a49f7d82bdec86309d06e0aa2185 Mon Sep 17 00:00:00 2001 From: odersky Date: Fri, 8 Jul 2022 22:18:47 +0200 Subject: [PATCH 2/3] Fix check file --- tests/run/i15618.check | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/run/i15618.check b/tests/run/i15618.check index 8e26b1641990..a38595cb9fcb 100644 --- a/tests/run/i15618.check +++ b/tests/run/i15618.check @@ -1,2 +1 @@ -None -None +Some(1.0) From 23ab8009197f344a1c5dae8508595cb1081fe3b5 Mon Sep 17 00:00:00 2001 From: odersky Date: Sat, 9 Jul 2022 10:55:56 +0200 Subject: [PATCH 3/3] Refactor classtag synthesis --- .../src/dotty/tools/dotc/core/Types.scala | 12 ++-- .../dotty/tools/dotc/typer/Synthesizer.scala | 66 ++++++++----------- tests/run/i15618.check | 2 +- tests/run/i15618.scala | 6 +- 4 files changed, 40 insertions(+), 46 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 241621631c41..67959712e445 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -428,11 +428,13 @@ object Types { /** Is this a match type or a higher-kinded abstraction of one? */ - def isMatch(using Context): Boolean = stripped match { - case _: MatchType => true - case tp: HKTypeLambda => tp.resType.isMatch - case tp: AppliedType => tp.isMatchAlias - case _ => false + def isMatch(using Context): Boolean = underlyingMatchType.exists + + def underlyingMatchType(using Context): Type = stripped match { + case tp: MatchType => tp + case tp: HKTypeLambda => tp.resType.underlyingMatchType + case tp: AppliedType if tp.isMatchAlias => tp.superType.underlyingMatchType + case _ => NoType } /** Is this a higher-kinded type lambda with given parameter variances? */ diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 5258c31005ef..67bf0e83baf0 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -28,43 +28,35 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): private type SpecialHandlers = List[(ClassSymbol, SpecialHandler)] val synthesizedClassTag: SpecialHandler = (formal, span) => - formal.argInfos match - case arg :: Nil => - if isFullyDefined(arg, ForceDegree.all) then - arg match - case defn.ArrayOf(elemTp) => - val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span) - if etag.tpe.isError then EmptyTreeNoError else withNoErrors(etag.select(nme.wrap)) - case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) => - val sym = tp.typeSymbol - val classTag = ref(defn.ClassTagModule) - val tag = - if defn.SpecialClassTagClasses.contains(sym) then - classTag.select(sym.name.toTermName).withSpan(span) - else - def clsOfType(tp: Type): Type = - val tp1 = tp.dealias - if tp1.isMatch then - val matchTp = tp1.underlyingIterator.collect { - case mt: MatchType => mt - }.next - matchTp.alternatives.map(clsOfType) match - case ct1 :: cts if cts.forall(ct1 == _) => ct1 - case _ => NoType - else - escapeJavaArray(erasure(tp)) - val ctype = clsOfType(tp) - if ctype.exists then - classTag.select(nme.apply) - .appliedToType(tp) - .appliedTo(clsOf(ctype)) - .withSpan(span) - else - EmptyTree - withNoErrors(tag) - case tp => EmptyTreeNoError - else EmptyTreeNoError - case _ => EmptyTreeNoError + val tag = formal.argInfos match + case arg :: Nil if isFullyDefined(arg, ForceDegree.all) => + arg match + case defn.ArrayOf(elemTp) => + val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span) + if etag.tpe.isError then EmptyTree else etag.select(nme.wrap) + case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) => + val sym = tp.typeSymbol + val classTagModul = ref(defn.ClassTagModule) + if defn.SpecialClassTagClasses.contains(sym) then + classTagModul.select(sym.name.toTermName).withSpan(span) + else + def clsOfType(tp: Type): Type = tp.dealias.underlyingMatchType match + case matchTp: MatchType => + matchTp.alternatives.map(clsOfType) match + case ct1 :: cts if cts.forall(ct1 == _) => ct1 + case _ => NoType + case _ => + escapeJavaArray(erasure(tp)) + val ctype = clsOfType(tp) + if ctype.exists then + classTagModul.select(nme.apply) + .appliedToType(tp) + .appliedTo(clsOf(ctype)) + .withSpan(span) + else EmptyTree + case _ => EmptyTree + case _ => EmptyTree + (tag, Nil) end synthesizedClassTag val synthesizedTypeTest: SpecialHandler = diff --git a/tests/run/i15618.check b/tests/run/i15618.check index a38595cb9fcb..aeb2d5e2398d 100644 --- a/tests/run/i15618.check +++ b/tests/run/i15618.check @@ -1 +1 @@ -Some(1.0) +Some(1) diff --git a/tests/run/i15618.scala b/tests/run/i15618.scala index 8149be15b6ba..a92bef156b7a 100644 --- a/tests/run/i15618.scala +++ b/tests/run/i15618.scala @@ -15,10 +15,10 @@ type ScalaType[U <: DType] <: Int | Float = U match abstract class Tensor[T <: DType]: def toArray: Array[ScalaType[T]] -object FloatTensor extends Tensor[Float16]: - def toArray: Array[Float] = Array(1, 2, 3) +object IntTensor extends Tensor[Int32]: + def toArray: Array[Int] = Array(1, 2, 3) @main def Test = - val t = FloatTensor: Tensor[Float16] // Tensor[Float32] + val t = IntTensor: Tensor[Int32] println(t.toArray.headOption) // was ClassCastException