Skip to content

An alternative feature to UnsafeNulls: UnsafeJavaReturn #15096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 13 commits into from
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,11 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
.setTyper(new Typer)
.addMode(Mode.ImplicitsEnabled)
.setTyperState(ctx.typerState.fresh(ctx.reporter))
if ctx.settings.YexplicitNulls.value && !Feature.enabledBySetting(nme.unsafeNulls) then
start = start.addMode(Mode.SafeNulls)
if ctx.settings.YexplicitNulls.value then
if !Feature.enabledBySetting(nme.unsafeNulls) then
start = start.addMode(Mode.SafeNulls)
if Feature.enabledBySetting(Feature.unsafeJavaReturn) then
start = start.addMode(Mode.UnsafeJavaReturn)
ctx.initialize()(using start) // re-initialize the base context with start

// `this` must be unchecked for safe initialization because by being passed to setRun during
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object Feature:
val symbolLiterals = deprecated("symbolLiterals")
val fewerBraces = experimental("fewerBraces")
val saferExceptions = experimental("saferExceptions")
val unsafeJavaReturn = experimental("unsafeJavaReturn")

/** Is `feature` enabled by by a command-line setting? The enabling setting is
*
Expand Down
20 changes: 14 additions & 6 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import scala.io.Codec
import collection.mutable
import printing._
import config.{JavaPlatform, SJSPlatform, Platform, ScalaSettings}
import config.Feature
import classfile.ReusableDataReader
import StdNames.nme

Expand Down Expand Up @@ -642,12 +643,19 @@ object Contexts {
def setProfiler(profiler: Profiler): this.type = updateStore(profilerLoc, profiler)
def setNotNullInfos(notNullInfos: List[NotNullInfo]): this.type = updateStore(notNullInfosLoc, notNullInfos)
def setImportInfo(importInfo: ImportInfo): this.type =
importInfo.mentionsFeature(nme.unsafeNulls) match
case Some(true) =>
setMode(this.mode &~ Mode.SafeNulls)
case Some(false) if ctx.settings.YexplicitNulls.value =>
setMode(this.mode | Mode.SafeNulls)
case _ =>
if ctx.settings.YexplicitNulls.value then
importInfo.mentionsFeature(nme.unsafeNulls) match
case Some(true) =>
setMode(this.mode &~ Mode.SafeNulls)
case Some(false) =>
setMode(this.mode | Mode.SafeNulls)
case _ =>
importInfo.mentionsFeature(Feature.unsafeJavaReturn) match
case Some(true) =>
setMode(this.mode | Mode.UnsafeJavaReturn)
case Some(false) =>
setMode(this.mode &~ Mode.UnsafeJavaReturn)
case _ =>
updateStore(importInfoLoc, importInfo)
def setTypeAssigner(typeAssigner: TypeAssigner): this.type = updateStore(typeAssignerLoc, typeAssigner)

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ class Definitions {
@tu lazy val FunctionalInterfaceAnnot: ClassSymbol = requiredClass("java.lang.FunctionalInterface")
@tu lazy val TargetNameAnnot: ClassSymbol = requiredClass("scala.annotation.targetName")
@tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs")
@tu lazy val CanEqualNullAnnot: ClassSymbol = requiredClass("scala.annotation.CanEqualNull")

@tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable")

Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,6 @@ object Mode {
* Type `Null` becomes a subtype of non-primitive value types in TypeComparer.
*/
val RelaxedOverriding: Mode = newMode(30, "RelaxedOverriding")

val UnsafeJavaReturn: Mode = newMode(31, "UnsafeJavaReturn")
}
61 changes: 56 additions & 5 deletions compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package dotty.tools.dotc
package core

import Annotations._
import Contexts._
import Flags._
import Symbols._
import Types._
import transform.SymUtils._

/** Defines operations on nullable types and tree. */
object NullOpsDecorator:
Expand Down Expand Up @@ -42,6 +46,24 @@ object NullOpsDecorator:
if ctx.explicitNulls then strip(self) else self
}

/** Strips `|Null` from the return type of a Java method,
* replacing it with a `@CanEqualNull` annotation
*/
def replaceOrNull(using Context): Type =
// Since this method should only be called on types from Java,
// handling these cases is enough.
def recur(tp: Type): Type = tp match
case tp @ OrType(lhs, rhs) if rhs.isNullType =>
AnnotatedType(recur(lhs), Annotation(defn.CanEqualNullAnnot))
case tp: AndOrType =>
tp.derivedAndOrType(recur(tp.tp1), recur(tp.tp2))
case tp @ AppliedType(tycon, targs) =>
tp.derivedAppliedType(tycon, targs.map(recur))
case mptp: MethodOrPoly =>
mptp.derivedLambdaType(resType = recur(mptp.resType))
case _ => tp
if ctx.explicitNulls then recur(self) else self

/** Is self (after widening and dealiasing) a type of the form `T | Null`? */
def isNullableUnion(using Context): Boolean = {
val stripped = self.stripNull
Expand All @@ -51,10 +73,39 @@ object NullOpsDecorator:

import ast.tpd._

extension (self: Tree)
extension (tree: Tree)

// cast the type of the tree to a non-nullable type
def castToNonNullable(using Context): Tree = self.typeOpt match {
case OrNull(tp) => self.cast(tp)
case _ => self
}
def castToNonNullable(using Context): Tree = tree.typeOpt match
case OrNull(tp) => tree.cast(tp)
case _ => tree

def tryToCastToCanEqualNull(using Context): Tree =
// return the tree directly if not at Typer phase
if !(ctx.explicitNulls && ctx.phase.isTyper) then return tree

val sym = tree.symbol
val tp = tree.tpe

if !ctx.mode.is(Mode.UnsafeJavaReturn)
|| !sym.is(JavaDefined)
|| sym.isNoValue
|| !sym.isTerm
|| tp.isError then
return tree

tree match
case _: Apply if sym.is(Method) =>
val tp2 = tp.replaceOrNull
if tp ne tp2 then
tree.cast(tp2)
else tree
case _: Select | _: Ident if !sym.is(Method) =>
val tpw = tp.widen
val tp2 = tpw.replaceOrNull
if tpw ne tp2 then
tree.cast(tp2)
else tree
case _ => tree

end NullOpsDecorator
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -756,8 +756,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
}
compareTypeBounds
case tp2: AnnotatedType if tp2.isRefining =>
(tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || tp1.isBottomType) &&
recur(tp1, tp2.parent)
// `CanEqualNull` is a special refining annotation.
// An annotated type is equivalent to the original type.
(tp1.derivesAnnotWith(tp2.annot.sameAnnotation)
|| tp2.annot.matches(defn.CanEqualNullAnnot)
|| tp1.isBottomType)
&& recur(tp1, tp2.parent)
case ClassInfo(pre2, cls2, _, _, _) =>
def compareClassInfo = tp1 match {
case ClassInfo(pre1, cls1, _, _, _) =>
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import reporting._
import transform.TypeUtils._
import transform.SymUtils._
import Nullables._
import NullOpsDecorator._
import config.Feature

import collection.mutable
Expand Down Expand Up @@ -908,7 +909,7 @@ trait Applications extends Compatibility {
def simpleApply(fun1: Tree, proto: FunProto)(using Context): Tree =
methPart(fun1).tpe match {
case funRef: TermRef =>
val app = ApplyTo(tree, fun1, funRef, proto, pt)
val app = ApplyTo(tree, fun1, funRef, proto, pt).tryToCastToCanEqualNull
convertNewGenericArray(
widenEnumCase(
postProcessByNameArgs(funRef, app).computeNullable(),
Expand Down
19 changes: 16 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
}

/** Is an `CanEqual[cls1, cls2]` instance assumed for predefined classes `cls1`, cls2`? */
def canComparePredefinedClasses(cls1: ClassSymbol, cls2: ClassSymbol): Boolean =
def canComparePredefinedClasses(cls1: ClassSymbol, cls2: ClassSymbol)(using Context): Boolean =

def cmpWithBoxed(cls1: ClassSymbol, cls2: ClassSymbol) =
cls2 == defn.NothingClass
Expand All @@ -164,15 +164,17 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
cmpWithBoxed(cls2, cls1)
else if ctx.mode.is(Mode.SafeNulls) then
// If explicit nulls is enabled, and unsafeNulls is not enabled,
// and the types don't have `@CanEqualNull` annotation,
// we want to disallow comparison between Object and Null.
// If we have to check whether a variable with a non-nullable type has null value
// (for example, a NotNull java method returns null for some reasons),
// we can still cast it to a nullable type then compare its value.
// we can still use `eq/ne null` or cast it to a nullable type then compare its value.
//
// Example:
// val x: String = null.asInstanceOf[String]
// if (x == null) {} // error: x is non-nullable
// if (x.asInstanceOf[String|Null] == null) {} // ok
// if (x eq null) {} // ok
cls1 == defn.NullClass && cls1 == cls2
else if cls1 == defn.NullClass then
cls1 == cls2 || cls2.derivesFrom(defn.ObjectClass)
Expand All @@ -187,9 +189,20 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
* interpret.
*/
def canComparePredefined(tp1: Type, tp2: Type) =
// In explicit nulls, when one of type has `@CanEqualNull` annotation,
// we use unsafe nulls semantic to check, which allows reference types
// to be compared with `Null`.
// Example:
// val s1: String = ???
// s1 == null // error
// val s2: String @CanEqualNull = ???
// s2 == null // ok
val checkCtx = if ctx.explicitNulls
&& (tp1.hasAnnotation(defn.CanEqualNullAnnot) || tp2.hasAnnotation(defn.CanEqualNullAnnot))
then ctx.retractMode(Mode.SafeNulls) else ctx
tp1.classSymbols.exists(cls1 =>
tp2.classSymbols.exists(cls2 =>
canComparePredefinedClasses(cls1, cls2)))
canComparePredefinedClasses(cls1, cls2)(using checkCtx)))

formal.argTypes match
case args @ (arg1 :: arg2 :: Nil) =>
Expand Down
8 changes: 5 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
ref(ownType).withSpan(tree.span)
case _ =>
tree.withType(ownType)
val tree2 = toNotNullTermRef(tree1, pt)
val tree2 = toNotNullTermRef(tree1, pt).tryToCastToCanEqualNull
checkLegalValue(tree2, pt)
tree2

Expand Down Expand Up @@ -646,7 +646,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def typeSelectOnTerm(using Context): Tree =
val qual = typedExpr(tree.qualifier, shallowSelectionProto(tree.name, pt, this))
typedSelect(tree, pt, qual).withSpan(tree.span).computeNullable()
val sel = typedSelect(tree, pt, qual).withSpan(tree.span).computeNullable()
if pt != AssignProto then sel.tryToCastToCanEqualNull else sel

def javaSelectOnType(qual: Tree)(using Context) =
// semantic name conversion for `O$` in java code
Expand Down Expand Up @@ -3679,7 +3680,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}
simplify(typed(etaExpand(tree, wtp, arity), pt), pt, locked)
else if (wtp.paramInfos.isEmpty && isAutoApplied(tree.symbol))
readaptSimplified(tpd.Apply(tree, Nil))
val app = tpd.Apply(tree, Nil).tryToCastToCanEqualNull
readaptSimplified(app)
else if (wtp.isImplicitMethod)
err.typeMismatch(tree, pt)
else
Expand Down
1 change: 1 addition & 0 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ class CompilationTests {
compileFilesInDir("tests/explicit-nulls/pos-separate", explicitNullsOptions),
compileFilesInDir("tests/explicit-nulls/pos-patmat", explicitNullsOptions and "-Xfatal-warnings"),
compileFilesInDir("tests/explicit-nulls/unsafe-common", explicitNullsOptions and "-language:unsafeNulls"),
compileFilesInDir("tests/explicit-nulls/unsafe-java", explicitNullsOptions),
compileFile("tests/explicit-nulls/pos-special/i14682.scala", explicitNullsOptions and "-Ysafe-init"),
compileFile("tests/explicit-nulls/pos-special/i14947.scala", explicitNullsOptions and "-Ytest-pickler" and "-Xprint-types"),
)
Expand Down
22 changes: 22 additions & 0 deletions library/src/scala/annotation/CanEqualNull.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package scala.annotation

/** An annotation makes reference types comparable to `null` in explicit nulls.
* `CanEqualNull` is a special refining annotation. An annotated type is equivalent to the original type.
*
* For example:
* ```scala
* val s1: String = ???
* s1 == null // error
* val s2: String @CanEqualNull = ???
* s2 == null // ok
*
* // String =:= String @CanEqualNull
* val s3: String = s2
* val s4: String @CanEqualNull = s1
*
* val ss: Array[String @CanEqualNull] = ???
* ss.map(_ == null)
* ```
*/
@experimental
final class CanEqualNull extends RefiningAnnotation
5 changes: 5 additions & 0 deletions library/src/scala/runtime/stdLibPatches/language.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ object language:
@compileTimeOnly("`saferExceptions` can only be used at compile time in import statements")
object saferExceptions

/** Experimental support for unsafe Java return in explicit nulls
*/
@compileTimeOnly("`unsafeJavaReturn` can only be used at compile time in import statements")
object unsafeJavaReturn

end experimental

/** The deprecated object contains features that are no longer officially suypported in Scala.
Expand Down
4 changes: 2 additions & 2 deletions project/MiMaFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ object MiMaFilters {
ProblemFilters.exclude[MissingClassProblem]("scala.runtime.TupleMirror"),
ProblemFilters.exclude[MissingTypesProblem]("scala.Tuple$package$EmptyTuple$"), // we made the empty tuple a case object
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.Scala3RunTime.nnFail"),
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.Scala3RunTime.nnFail"),
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.LazyVals.getOffsetStatic"), // Added for #14780
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.LazyVals.getOffsetStatic"), // Added for #14780
ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language.3.2-migration"),
ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language.3.2"),
Expand All @@ -28,6 +26,8 @@ object MiMaFilters {
ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.typeRef"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.termRef"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#TypeTreeModule.ref"),
ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language#experimental.unsafeJavaReturn"),
ProblemFilters.exclude[MissingClassProblem]("scala.runtime.stdLibPatches.language$experimental$unsafeJavaReturn$"),

ProblemFilters.exclude[MissingClassProblem]("scala.annotation.since"),
)
Expand Down
14 changes: 14 additions & 0 deletions tests/explicit-nulls/unsafe-java/JavaStatic.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import language.experimental.unsafeJavaReturn

import java.math.MathContext, MathContext._

val x: MathContext = DECIMAL32
val y: MathContext = MathContext.DECIMAL32

import java.io.File

val s: String = File.separator
import java.time.ZoneId

val zids: java.util.Set[String] = ZoneId.getAvailableZoneIds
val zarr: Array[String] = ZoneId.getAvailableZoneIds.toArray(Array.empty[String | Null])
11 changes: 11 additions & 0 deletions tests/explicit-nulls/unsafe-java/UnaryCall.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import scala.language.experimental.unsafeJavaReturn

import java.lang.reflect.Method

def getMethods(f: String): List[Method] =
val clazz = Class.forName(f)
val methods = clazz.getMethods
if methods == null then List()
else methods.toList

def getClass(o: AnyRef): Class[?] = o.getClass
7 changes: 7 additions & 0 deletions tests/explicit-nulls/unsafe-java/java-chain/J.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class J1 {
J2 getJ2() { return new J2(); }
}

class J2 {
J1 getJ1() { return new J1(); }
}
6 changes: 6 additions & 0 deletions tests/explicit-nulls/unsafe-java/java-chain/S.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import scala.language.experimental.unsafeJavaReturn

def f = {
val j: J2 = new J2()
j.getJ1().getJ2().getJ1().getJ2().getJ1().getJ2()
}
Loading