Skip to content

Fix comparing AnyVal | Null to Null and selecting in UnsafeNulls #13837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -853,18 +853,23 @@ object Types {
def goAnd(l: Type, r: Type) =
go(l).meet(go(r), pre, safeIntersection = ctx.base.pendingMemberSearches.contains(name))

def goOr(tp: OrType) = tp match {
case OrNull(tp1) if Nullables.unsafeNullsEnabled =>
// Selecting `name` from a type `T | Null` is like selecting `name` from `T`, if
// unsafeNulls is enabled. This can throw at runtime, but we trade soundness for usability.
tp1.findMember(name, pre.stripNull, required, excluded)
case _ =>
def goOr(tp: OrType) =
inline def searchAfterJoin =
// we need to keep the invariant that `pre <: tp`. Branch `union-types-narrow-prefix`
// achieved that by narrowing `pre` to each alternative, but it led to merge errors in
// lots of places. The present strategy is instead of widen `tp` using `join` to be a
// supertype of `pre`.
go(tp.join)
}

if Nullables.unsafeNullsEnabled then tp match
case OrNull(tp1) if tp1 <:< defn.ObjectType =>
// Selecting `name` from a type `T | Null` is like selecting `name` from `T`, if
// unsafeNulls is enabled and T is a subtype of AnyRef.
// This can throw at runtime, but we trade soundness for usability.
tp1.findMember(name, pre.stripNull, required, excluded)
case _ =>
searchAfterJoin
else searchAfterJoin

val recCount = ctx.base.findMemberCount
if (recCount >= Config.LogPendingFindMemberThreshold)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ trait TypeAssigner {
val qualType = qual.tpe.widenIfUnstable
def kind = if tree.isType then "type" else "value"
val foundWithoutNull = qualType match
case OrNull(qualType1) =>
case OrNull(qualType1) if qualType1 <:< defn.ObjectType =>
val name = tree.name
val pre = maybeSkolemizePrefix(qualType1, name)
reallyExists(qualType1.findMember(name, pre))
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ class Typer extends Namer
val qual = typedExpr(tree.qualifier, shallowSelectionProto(tree.name, pt, this))
val qual1 = if Nullables.unsafeNullsEnabled then
qual.tpe match {
case OrNull(tpe1) =>
case OrNull(tpe1) if tpe1 <:< defn.ObjectType =>
qual.cast(AndType(qual.tpe, tpe1))
case tp =>
if tp.isNullType
Expand Down
13 changes: 13 additions & 0 deletions tests/explicit-nulls/neg/AnyValOrNullSelect.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
case class MyVal(i: Int) extends AnyVal:
def printVal: Unit =
println(i)

class Test:
val v: MyVal | Null = MyVal(1)

def f1 =
v.printVal // error: value printVal is not a member of MyVal | Null

def f1 =
import scala.language.unsafeNulls
v.printVal // error: value printVal is not a member of MyVal | Null
36 changes: 36 additions & 0 deletions tests/explicit-nulls/pos/AnyValOrNull.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
case class MyVal(i: Boolean) extends AnyVal

class Test1:

def test1 =
val v: AnyVal | Null = null
if v == null then
println("null")

def test2 =
val v: Int | Null = 1
if v != null then
println(v)

def test3 =
val v: MyVal | Null = MyVal(false)
if v != null then
println(v)

class Test2:
import scala.language.unsafeNulls

def test1 =
val v: AnyVal | Null = null
if v == null then
println("null")

def test2 =
val v: Int | Null = 1
if v != null then
println(v)

def test3 =
val v: MyVal | Null = MyVal(false)
if v != null then
println(v)