Skip to content

More consistent results for union types in pickling #15515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1194,10 +1194,6 @@ class TreeUnpickler(reader: TastyReader,
res.withAttachment(SuppressedApplyToNone, ())
else res

def simplifyLub(tree: Tree): Tree =
tree.overwriteType(tree.tpe.simplified)
tree

def readLengthTerm(): Tree = {
val end = readEnd()
val result =
Expand Down Expand Up @@ -1247,25 +1243,23 @@ class TreeUnpickler(reader: TastyReader,
val tpt = ifBefore(end)(readTpt(), EmptyTree)
Closure(Nil, meth, tpt)
case MATCH =>
simplifyLub(
if (nextByte == IMPLICIT) {
readByte()
InlineMatch(EmptyTree, readCases(end))
}
else if (nextByte == INLINE) {
readByte()
InlineMatch(readTerm(), readCases(end))
}
else Match(readTerm(), readCases(end)))
if (nextByte == IMPLICIT) {
readByte()
InlineMatch(EmptyTree, readCases(end))
}
else if (nextByte == INLINE) {
readByte()
InlineMatch(readTerm(), readCases(end))
}
else Match(readTerm(), readCases(end))
case RETURN =>
val from = readSymRef()
val expr = ifBefore(end)(readTerm(), EmptyTree)
Return(expr, Ident(from.termRef))
case WHILE =>
WhileDo(readTerm(), readTerm())
case TRY =>
simplifyLub(
Try(readTerm(), readCases(end), ifBefore(end)(readTerm(), EmptyTree)))
Try(readTerm(), readCases(end), ifBefore(end)(readTerm(), EmptyTree))
case SELECTouter =>
val levels = readNat()
readTerm().outerSelect(levels, SkolemType(readType()))
Expand Down
21 changes: 20 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Pickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class Pickler extends Phase {
private val beforePickling = new mutable.HashMap[ClassSymbol, String]
private val picklers = new mutable.HashMap[ClassSymbol, TastyPickler]

private val typeSimplifier = new TypeSimplifier

/** Drop any elements of this list that are linked module classes of other elements in the list */
private def dropCompanionModuleClasses(clss: List[ClassSymbol])(using Context): List[ClassSymbol] = {
val companionModuleClasses =
Expand Down Expand Up @@ -135,7 +137,7 @@ class Pickler extends Phase {
}
pickling.println("************* entered toplevel ***********")
for ((cls, unpickler) <- unpicklers) {
val unpickled = unpickler.rootTrees
val unpickled = typeSimplifier.transform(unpickler.rootTrees)
testSame(i"$unpickled%\n%", beforePickling(cls), cls)
}
}
Expand All @@ -151,4 +153,21 @@ class Pickler extends Phase {
|
| diff before-pickling.txt after-pickling.txt""".stripMargin)
end testSame

// Overwrite types of If, Match, and Try nodes with simplified types
// to avoid inconsistencies in unsafe nulls
class TypeSimplifier extends TreeMapWithPreciseStatContexts:
override def transform(tree: Tree)(using Context): Tree =
try tree match
case _: If | _: Match | _: Try =>
val newTree = super.transform(tree)
newTree.overwriteType(newTree.tpe.simplified)
newTree
case _ =>
super.transform(tree)
catch
case ex: TypeError =>
report.error(ex, tree.srcPos)
tree
end TypeSimplifier
}
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,10 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
)
case Block(_, Closure(_, _, tpt)) if ExpandSAMs.needsWrapperClass(tpt.tpe) =>
superAcc.withInvalidCurrentClass(super.transform(tree))
case _: If | _: Match | _: Try =>
val newTree = super.transform(tree)
newTree.overwriteType(newTree.tpe.simplified)
newTree
case tree =>
super.transform(tree)
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object Nullables:
ctx.explicitNulls && !ctx.mode.is(Mode.SafeNulls)

private def needNullifyHi(lo: Type, hi: Type)(using Context): Boolean =
ctx.explicitNulls
ctx.mode.is(Mode.SafeNulls)
&& lo.isExactlyNull // only nullify hi if lo is exactly Null type
&& hi.isValueType
// We cannot check if hi is nullable, because it can cause cyclic reference.
Expand Down
2 changes: 1 addition & 1 deletion compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ class CompilationTests {
compileFilesInDir("tests/explicit-nulls/pos", explicitNullsOptions),
compileFilesInDir("tests/explicit-nulls/pos-separate", explicitNullsOptions),
compileFilesInDir("tests/explicit-nulls/pos-patmat", explicitNullsOptions and "-Xfatal-warnings"),
compileFilesInDir("tests/explicit-nulls/pos-pickling", explicitNullsOptions and "-Ytest-pickler" and "-Xprint-types"),
compileFilesInDir("tests/explicit-nulls/unsafe-common", explicitNullsOptions and "-language:unsafeNulls"),
compileFile("tests/explicit-nulls/pos-special/i14682.scala", explicitNullsOptions and "-Ysafe-init"),
compileFile("tests/explicit-nulls/pos-special/i14947.scala", explicitNullsOptions and "-Ytest-pickler" and "-Xprint-types"),
)
}.checkCompile()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class C:
if ??? then g else ""

def f3 =
import scala.language.unsafeNulls
(??? : Boolean) match
case true => g
case _ => ""
Expand Down
59 changes: 59 additions & 0 deletions tests/explicit-nulls/pos-pickling/match-case.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import scala.language.unsafeNulls

def a(): String | Null = ???
val b: String | Null = ???

val i: Int = ???

def f1 = i match {
case 0 => b
case _ => a()
}

def f2 = i match {
case 0 => a()
case _ => b
}

def f3 = i match {
case 0 => a()
case _ => "".trim
}

def f4 = i match {
case 0 => b
case _ => "".trim
}

def g1 = i match {
case 0 => a()
case 1 => ""
case _ => null
}

def g2 = i match {
case 0 => ""
case 1 => null
case _ => b
}

def g3 = i match {
case 0 => null
case 1 => b
case _ => ""
}

def h1(i: Int) = i match
case 0 => 0
case 1 => true
case 2 => i.toString
case _ => null

// This test still fails.
// Even without explicit nulls, the type of Match
// is (0, true, "2"), which is wrong.
// def h2(i: Int) = i match
// case 0 => 0
// case 1 => true
// case 2 => "2"
// case _ => null
35 changes: 35 additions & 0 deletions tests/explicit-nulls/pos-pickling/other-pickling.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import scala.language.unsafeNulls

def associatedFile: String | Null = ???

def source: String = {
val f = associatedFile
if (f != null) f else associatedFile
}

def defines(raw: String): List[String] = {
val ds: List[(Int, Int)] = ???
ds map { case (start, end) => raw.substring(start, end) }
}

abstract class DeconstructorCommon[T >: Null <: AnyRef] {
var field: T = null
def get: this.type = this
def isEmpty: Boolean = field eq null
def isDefined = !isEmpty
def unapply(s: T): this.type ={
field = s
this
}
}

def genBCode =
val bsmArgs: Array[Object | Null] | Null = null
val implMethod = bsmArgs(3).asInstanceOf[Integer].toInt
implMethod

val arrayApply = "a".split(" ")(2)

val globdir: String = if (??? : Boolean) then associatedFile.replaceAll("[\\\\/][^\\\\/]*$", "") else ""

def newInstOfC(c: Class[?]) = c.getConstructor().newInstance()
27 changes: 27 additions & 0 deletions tests/explicit-nulls/pos-pickling/try-catch.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import scala.language.unsafeNulls


def s: String | Null = ???

def tryString = try s catch {
case _: NullPointerException => null
case _ => ""
}

def tryString2 = try s catch {
case _: NullPointerException => ""
case _ => s
}

def loadClass(classLoader: ClassLoader, name: String): Class[?] =
try classLoader.loadClass(name)
catch {
case _ =>
throw new Exception()
}

def loadClass2(classLoader: ClassLoader, name: String): Class[?] =
try classLoader.loadClass(name)
catch {
case _ => null
}