diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index 6e7bb0e7a0d4..41993d9fb578 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -328,6 +328,9 @@ object TypeErasure { isGenericArrayElement(tp.alias, isScala2) case tp: TypeBounds => !fitsInJVMArray(tp.hi) + case tp: MatchType => + val alts = tp.alternatives + alts.nonEmpty && !fitsInJVMArray(alts.reduce(OrType(_, _, soft = true))) case tp: TypeProxy => isGenericArrayElement(tp.translucentSuperType, isScala2) case tp: AndType => diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 2b51d64a2f27..5292c0506c36 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -419,11 +419,13 @@ object Types { /** Is this a match type or a higher-kinded abstraction of one? */ - def isMatch(using Context): Boolean = stripped match { - case _: MatchType => true - case tp: HKTypeLambda => tp.resType.isMatch - case tp: AppliedType => tp.isMatchAlias - case _ => false + def isMatch(using Context): Boolean = underlyingMatchType.exists + + def underlyingMatchType(using Context): Type = stripped match { + case tp: MatchType => tp + case tp: HKTypeLambda => tp.resType.underlyingMatchType + case tp: AppliedType if tp.isMatchAlias => tp.superType.underlyingMatchType + case _ => NoType } /** Is this a higher-kinded type lambda with given parameter variances? */ diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 8b0fa88ad5a9..b8bd98fd4056 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -28,26 +28,35 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): private type SpecialHandlers = List[(ClassSymbol, SpecialHandler)] val synthesizedClassTag: SpecialHandler = (formal, span) => - formal.argInfos match - case arg :: Nil => - if isFullyDefined(arg, ForceDegree.all) then - arg match - case defn.ArrayOf(elemTp) => - val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span) - if etag.tpe.isError then EmptyTreeNoError else withNoErrors(etag.select(nme.wrap)) - case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) => - val sym = tp.typeSymbol - val classTag = ref(defn.ClassTagModule) - val tag = - if defn.SpecialClassTagClasses.contains(sym) then - classTag.select(sym.name.toTermName) - else - val clsOfType = escapeJavaArray(erasure(tp)) - classTag.select(nme.apply).appliedToType(tp).appliedTo(clsOf(clsOfType)) - withNoErrors(tag.withSpan(span)) - case tp => EmptyTreeNoError - else EmptyTreeNoError - case _ => EmptyTreeNoError + val tag = formal.argInfos match + case arg :: Nil if isFullyDefined(arg, ForceDegree.all) => + arg match + case defn.ArrayOf(elemTp) => + val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span) + if etag.tpe.isError then EmptyTree else etag.select(nme.wrap) + case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) => + val sym = tp.typeSymbol + val classTagModul = ref(defn.ClassTagModule) + if defn.SpecialClassTagClasses.contains(sym) then + classTagModul.select(sym.name.toTermName).withSpan(span) + else + def clsOfType(tp: Type): Type = tp.dealias.underlyingMatchType match + case matchTp: MatchType => + matchTp.alternatives.map(clsOfType) match + case ct1 :: cts if cts.forall(ct1 == _) => ct1 + case _ => NoType + case _ => + escapeJavaArray(erasure(tp)) + val ctype = clsOfType(tp) + if ctype.exists then + classTagModul.select(nme.apply) + .appliedToType(tp) + .appliedTo(clsOf(ctype)) + .withSpan(span) + else EmptyTree + case _ => EmptyTree + case _ => EmptyTree + (tag, Nil) end synthesizedClassTag val synthesizedTypeTest: SpecialHandler = diff --git a/tests/neg/i15618.check b/tests/neg/i15618.check new file mode 100644 index 000000000000..0853da26c27a --- /dev/null +++ b/tests/neg/i15618.check @@ -0,0 +1,18 @@ +-- Error: tests/neg/i15618.scala:17:44 --------------------------------------------------------------------------------- +17 | def toArray: Array[ScalaType[T]] = Array() // error + | ^ + | No ClassTag available for ScalaType[T] + | + | where: T is a type in class Tensor with bounds <: DType + | + | + | Note: a match type could not be fully reduced: + | + | trying to reduce ScalaType[T] + | failed since selector T + | does not match case Float16 => Float + | and cannot be shown to be disjoint from it either. + | Therefore, reduction cannot advance to the remaining cases + | + | case Float32 => Float + | case Int32 => Int diff --git a/tests/neg/i15618.scala b/tests/neg/i15618.scala new file mode 100644 index 000000000000..fd38c8c48f6b --- /dev/null +++ b/tests/neg/i15618.scala @@ -0,0 +1,23 @@ +sealed abstract class DType +sealed class Float16 extends DType +sealed class Float32 extends DType +sealed class Int32 extends DType + +object Float16 extends Float16 +object Float32 extends Float32 +object Int32 extends Int32 + +type ScalaType[U <: DType] <: Int | Float = U match + case Float16 => Float + case Float32 => Float + case Int32 => Int + +class Tensor[T <: DType](dtype: T): + def toSeq: Seq[ScalaType[T]] = Seq() + def toArray: Array[ScalaType[T]] = Array() // error + +@main +def Test = + val t = Tensor(Float32) // Tensor[Float32] + println(t.toSeq.headOption) // works, Seq[Float] + println(t.toArray.headOption) // ClassCastException diff --git a/tests/run/i15618.check b/tests/run/i15618.check new file mode 100644 index 000000000000..aeb2d5e2398d --- /dev/null +++ b/tests/run/i15618.check @@ -0,0 +1 @@ +Some(1) diff --git a/tests/run/i15618.scala b/tests/run/i15618.scala new file mode 100644 index 000000000000..a92bef156b7a --- /dev/null +++ b/tests/run/i15618.scala @@ -0,0 +1,24 @@ +sealed abstract class DType +sealed class Float16 extends DType +sealed class Float32 extends DType +sealed class Int32 extends DType + +object Float16 extends Float16 +object Float32 extends Float32 +object Int32 extends Int32 + +type ScalaType[U <: DType] <: Int | Float = U match + case Float16 => Float + case Float32 => Float + case Int32 => Int + +abstract class Tensor[T <: DType]: + def toArray: Array[ScalaType[T]] + +object IntTensor extends Tensor[Int32]: + def toArray: Array[Int] = Array(1, 2, 3) + +@main +def Test = + val t = IntTensor: Tensor[Int32] + println(t.toArray.headOption) // was ClassCastException