Skip to content

Backport #15625: Fix two problems related to match types as array elements #15761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ object TypeErasure {
isGenericArrayElement(tp.alias, isScala2)
case tp: TypeBounds =>
!fitsInJVMArray(tp.hi)
case tp: MatchType =>
val alts = tp.alternatives
alts.nonEmpty && !fitsInJVMArray(alts.reduce(OrType(_, _, soft = true)))
case tp: TypeProxy =>
isGenericArrayElement(tp.translucentSuperType, isScala2)
case tp: AndType =>
Expand Down
12 changes: 7 additions & 5 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,13 @@ object Types {

/** Is this a match type or a higher-kinded abstraction of one?
*/
def isMatch(using Context): Boolean = stripped match {
case _: MatchType => true
case tp: HKTypeLambda => tp.resType.isMatch
case tp: AppliedType => tp.isMatchAlias
case _ => false
def isMatch(using Context): Boolean = underlyingMatchType.exists

def underlyingMatchType(using Context): Type = stripped match {
case tp: MatchType => tp
case tp: HKTypeLambda => tp.resType.underlyingMatchType
case tp: AppliedType if tp.isMatchAlias => tp.superType.underlyingMatchType
case _ => NoType
}

/** Is this a higher-kinded type lambda with given parameter variances? */
Expand Down
49 changes: 29 additions & 20 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,35 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
private type SpecialHandlers = List[(ClassSymbol, SpecialHandler)]

val synthesizedClassTag: SpecialHandler = (formal, span) =>
formal.argInfos match
case arg :: Nil =>
if isFullyDefined(arg, ForceDegree.all) then
arg match
case defn.ArrayOf(elemTp) =>
val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span)
if etag.tpe.isError then EmptyTreeNoError else withNoErrors(etag.select(nme.wrap))
case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) =>
val sym = tp.typeSymbol
val classTag = ref(defn.ClassTagModule)
val tag =
if defn.SpecialClassTagClasses.contains(sym) then
classTag.select(sym.name.toTermName)
else
val clsOfType = escapeJavaArray(erasure(tp))
classTag.select(nme.apply).appliedToType(tp).appliedTo(clsOf(clsOfType))
withNoErrors(tag.withSpan(span))
case tp => EmptyTreeNoError
else EmptyTreeNoError
case _ => EmptyTreeNoError
val tag = formal.argInfos match
case arg :: Nil if isFullyDefined(arg, ForceDegree.all) =>
arg match
case defn.ArrayOf(elemTp) =>
val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span)
if etag.tpe.isError then EmptyTree else etag.select(nme.wrap)
case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) =>
val sym = tp.typeSymbol
val classTagModul = ref(defn.ClassTagModule)
if defn.SpecialClassTagClasses.contains(sym) then
classTagModul.select(sym.name.toTermName).withSpan(span)
else
def clsOfType(tp: Type): Type = tp.dealias.underlyingMatchType match
case matchTp: MatchType =>
matchTp.alternatives.map(clsOfType) match
case ct1 :: cts if cts.forall(ct1 == _) => ct1
case _ => NoType
case _ =>
escapeJavaArray(erasure(tp))
val ctype = clsOfType(tp)
if ctype.exists then
classTagModul.select(nme.apply)
.appliedToType(tp)
.appliedTo(clsOf(ctype))
.withSpan(span)
else EmptyTree
case _ => EmptyTree
case _ => EmptyTree
(tag, Nil)
end synthesizedClassTag

val synthesizedTypeTest: SpecialHandler =
Expand Down
18 changes: 18 additions & 0 deletions tests/neg/i15618.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- Error: tests/neg/i15618.scala:17:44 ---------------------------------------------------------------------------------
17 | def toArray: Array[ScalaType[T]] = Array() // error
| ^
| No ClassTag available for ScalaType[T]
|
| where: T is a type in class Tensor with bounds <: DType
|
|
| Note: a match type could not be fully reduced:
|
| trying to reduce ScalaType[T]
| failed since selector T
| does not match case Float16 => Float
| and cannot be shown to be disjoint from it either.
| Therefore, reduction cannot advance to the remaining cases
|
| case Float32 => Float
| case Int32 => Int
23 changes: 23 additions & 0 deletions tests/neg/i15618.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
sealed abstract class DType
sealed class Float16 extends DType
sealed class Float32 extends DType
sealed class Int32 extends DType

object Float16 extends Float16
object Float32 extends Float32
object Int32 extends Int32

type ScalaType[U <: DType] <: Int | Float = U match
case Float16 => Float
case Float32 => Float
case Int32 => Int

class Tensor[T <: DType](dtype: T):
def toSeq: Seq[ScalaType[T]] = Seq()
def toArray: Array[ScalaType[T]] = Array() // error

@main
def Test =
val t = Tensor(Float32) // Tensor[Float32]
println(t.toSeq.headOption) // works, Seq[Float]
println(t.toArray.headOption) // ClassCastException
1 change: 1 addition & 0 deletions tests/run/i15618.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Some(1)
24 changes: 24 additions & 0 deletions tests/run/i15618.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
sealed abstract class DType
sealed class Float16 extends DType
sealed class Float32 extends DType
sealed class Int32 extends DType

object Float16 extends Float16
object Float32 extends Float32
object Int32 extends Int32

type ScalaType[U <: DType] <: Int | Float = U match
case Float16 => Float
case Float32 => Float
case Int32 => Int

abstract class Tensor[T <: DType]:
def toArray: Array[ScalaType[T]]

object IntTensor extends Tensor[Int32]:
def toArray: Array[Int] = Array(1, 2, 3)

@main
def Test =
val t = IntTensor: Tensor[Int32]
println(t.toArray.headOption) // was ClassCastException