diff --git a/compiler/src/dotty/tools/dotc/CompilationUnit.scala b/compiler/src/dotty/tools/dotc/CompilationUnit.scala index ea0faa19ce6d..0682d171b1b2 100644 --- a/compiler/src/dotty/tools/dotc/CompilationUnit.scala +++ b/compiler/src/dotty/tools/dotc/CompilationUnit.scala @@ -5,11 +5,13 @@ import util.SourceFile import ast.{tpd, untpd} import tpd.{Tree, TreeTraverser} import typer.PrepareInlineable.InlineAccessors +import typer.Nullables import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.core.SymDenotations.ClassDenotation import dotty.tools.dotc.core.Symbols._ import dotty.tools.dotc.transform.SymUtils._ import util.{NoSource, SourceFile} +import util.Spans.Span import core.Decorators._ class CompilationUnit protected (val source: SourceFile) { @@ -42,6 +44,16 @@ class CompilationUnit protected (val source: SourceFile) { suspended = true ctx.run.suspendedUnits += this throw CompilationUnit.SuspendException() + + private var myAssignmentSpans: Map[Int, List[Span]] = null + + /** A map from (name-) offsets of all local variables in this compilation unit + * that can be tracked for being not null to the list of spans of assignments + * to these variables. + */ + def assignmentSpans(given Context): Map[Int, List[Span]] = + if myAssignmentSpans == null then myAssignmentSpans = Nullables.assignmentSpans + myAssignmentSpans } object CompilationUnit { diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 8cdb7ce38e85..7352139ea0e9 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -88,6 +88,12 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] => /** If this is a block, its expression part */ def stripBlock(tree: Tree): Tree = unsplice(tree) match { case Block(_, expr) => stripBlock(expr) + case Inlined(_, _, expr) => stripBlock(expr) + case _ => tree + } + + def stripInlined(tree: Tree): Tree = unsplice(tree) match { + case Inlined(_, _, expr) => stripInlined(expr) case _ => tree } @@ -391,7 +397,9 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => if (fn.symbol.is(Erased) || fn.symbol == defn.InternalQuoted_typeQuote) Pure else exprPurity(fn) case Apply(fn, args) => def isKnownPureOp(sym: Symbol) = - sym.owner.isPrimitiveValueClass || sym.owner == defn.StringClass + sym.owner.isPrimitiveValueClass + || sym.owner == defn.StringClass + || defn.pureMethods.contains(sym) if (tree.tpe.isInstanceOf[ConstantType] && isKnownPureOp(tree.symbol) // A constant expression with pure arguments is pure. || (fn.symbol.isStableMember && !fn.symbol.is(Lazy)) || fn.symbol.isPrimaryConstructor && fn.symbol.owner.isNoInitsClass) // TODO: include in isStable? diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 1827638ba9ad..117f4f1b8697 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -711,11 +711,19 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { class TimeTravellingTreeCopier extends TypedTreeCopier { override def Apply(tree: Tree)(fun: Tree, args: List[Tree])(implicit ctx: Context): Apply = - ta.assignType(untpdCpy.Apply(tree)(fun, args), fun, args) + tree match + case tree: Apply + if (tree.fun eq fun) && (tree.args eq args) + && tree.tpe.isInstanceOf[ConstantType] + && isPureExpr(tree) => tree + case _ => + ta.assignType(untpdCpy.Apply(tree)(fun, args), fun, args) // Note: Reassigning the original type if `fun` and `args` have the same types as before - // does not work here: The computed type depends on the widened function type, not - // the function type itself. A treetransform may keep the function type the + // does not work here in general: The computed type depends on the widened function type, not + // the function type itself. A tree transform may keep the function type the // same but its widened type might change. + // However, we keep constant types of pure expressions. This uses the underlying assumptions + // that pure functions yielding a constant will not change in later phases. override def TypeApply(tree: Tree)(fun: Tree, args: List[Tree])(implicit ctx: Context): TypeApply = ta.assignType(untpdCpy.TypeApply(tree)(fun, args), fun, args) diff --git a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala index 8b053263ed00..d78f1cbd8b3d 100644 --- a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -162,6 +162,7 @@ class ScalaSettings extends Settings.SettingGroup { // Extremely experimental language features val YnoKindPolymorphism: Setting[Boolean] = BooleanSetting("-Yno-kind-polymorphism", "Enable kind polymorphism (see https://dotty.epfl.ch/docs/reference/kind-polymorphism.html). Potentially unsound.") + val YexplicitNulls: Setting[Boolean] = BooleanSetting("-Yexplicit-nulls", "Make reference types non-nullable. Nullable types can be expressed with unions: e.g. String|Null.") /** Area-specific debug output */ val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.") diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 679c43f76a36..11ef4ef2a778 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -15,7 +15,8 @@ import ast.Trees._ import ast.untpd import Flags.GivenOrImplicit import util.{FreshNameCreator, NoSource, SimpleIdentityMap, SourceFile} -import typer.{Implicits, ImportInfo, Inliner, NamerContextOps, SearchHistory, SearchRoot, TypeAssigner, Typer} +import typer.{Implicits, ImportInfo, Inliner, NamerContextOps, SearchHistory, SearchRoot, TypeAssigner, Typer, Nullables} +import Nullables.{NotNullInfo, given} import Implicits.ContextualImplicits import config.Settings._ import config.Config @@ -47,7 +48,11 @@ object Contexts { private val (compilationUnitLoc, store6) = store5.newLocation[CompilationUnit]() private val (runLoc, store7) = store6.newLocation[Run]() private val (profilerLoc, store8) = store7.newLocation[Profiler]() - private val initialStore = store8 + private val (notNullInfosLoc, store9) = store8.newLocation[List[NotNullInfo]]() + private val initialStore = store9 + + /** The current context */ + def curCtx(given ctx: Context): Context = ctx /** A context is passed basically everywhere in dotc. * This is convenient but carries the risk of captured contexts in @@ -207,6 +212,9 @@ object Contexts { /** The current compiler-run profiler */ def profiler: Profiler = store(profilerLoc) + /** The paths currently known to be not null */ + def notNullInfos = store(notNullInfosLoc) + /** The new implicit references that are introduced by this scope */ protected var implicitsCache: ContextualImplicits = null def implicits: ContextualImplicits = { @@ -556,6 +564,7 @@ object Contexts { def setRun(run: Run): this.type = updateStore(runLoc, run) def setProfiler(profiler: Profiler): this.type = updateStore(profilerLoc, profiler) def setFreshNames(freshNames: FreshNameCreator): this.type = updateStore(freshNamesLoc, freshNames) + def setNotNullInfos(notNullInfos: List[NotNullInfo]): this.type = updateStore(notNullInfosLoc, notNullInfos) def setProperty[T](key: Key[T], value: T): this.type = setMoreProperties(moreProperties.updated(key, value)) @@ -587,6 +596,17 @@ object Contexts { def setDebug: this.type = setSetting(base.settings.Ydebug, true) } + given (c: Context) + def addNotNullInfo(info: NotNullInfo) = + c.withNotNullInfos(c.notNullInfos.extendWith(info)) + + def addNotNullRefs(refs: Set[TermRef]) = + c.addNotNullInfo(NotNullInfo(refs, Set())) + + def withNotNullInfos(infos: List[NotNullInfo]): Context = + if c.notNullInfos eq infos then c else c.fresh.setNotNullInfos(infos) + + // TODO: Fix issue when converting ModeChanges and FreshModeChanges to extension givens implicit class ModeChanges(val c: Context) extends AnyVal { final def withModeBits(mode: Mode): Context = if (mode != c.mode) c.fresh.setMode(mode) else c @@ -615,7 +635,9 @@ object Contexts { typeAssigner = TypeAssigner moreProperties = Map.empty source = NoSource - store = initialStore.updated(settingsStateLoc, settingsGroup.defaultState) + store = initialStore + .updated(settingsStateLoc, settingsGroup.defaultState) + .updated(notNullInfosLoc, Nil) typeComparer = new TypeComparer(this) searchHistory = new SearchRoot gadt = EmptyGadtConstraint diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index df5f6194d8ed..d4f78f93155e 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -310,6 +310,10 @@ class Definitions { def ObjectMethods: List[TermSymbol] = List(Object_eq, Object_ne, Object_synchronized, Object_clone, Object_finalize, Object_notify, Object_notifyAll, Object_wait, Object_waitL, Object_waitLI) + /** Methods in Object and Any that do not have a side effect */ + @tu lazy val pureMethods: List[TermSymbol] = List(Any_==, Any_!=, Any_equals, Any_hashCode, + Any_toString, Any_##, Any_getClass, Any_isInstanceOf, Any_typeTest, Object_eq, Object_ne) + @tu lazy val AnyKindClass: ClassSymbol = { val cls = ctx.newCompleteClassSymbol(ScalaPackageClass, tpnme.AnyKind, AbstractFinal | Permanent, Nil) if (!ctx.settings.YnoKindPolymorphism.value) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 58fdc137e6bf..fa6a32d775a8 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -195,6 +195,7 @@ object StdNames { final val ExprApi: N = "ExprApi" final val Mirror: N = "Mirror" final val Nothing: N = "Nothing" + final val NotNull: N = "NotNull" final val Null: N = "Null" final val Object: N = "Object" final val Product: N = "Product" @@ -261,6 +262,7 @@ object StdNames { val MIRROR_PREFIX: N = "$m." val MIRROR_SHORT: N = "$m" val MIRROR_UNTYPED: N = "$m$untyped" + val NOT_NULL: N = "$nn" val REIFY_FREE_PREFIX: N = "free$" val REIFY_FREE_THIS_SUFFIX: N = "$this" val REIFY_FREE_VALUE_SUFFIX: N = "$value" diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index a881b361d553..1b902d20e016 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1838,7 +1838,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w else if (!tp2.exists) tp1 else tp.derivedAndType(tp1, tp2) - /** If some (&-operand of) this type is a supertype of `sub` replace it with `NoType`. + /** If some (&-operand of) `tp` is a supertype of `sub` replace it with `NoType`. */ private def dropIfSuper(tp: Type, sub: Type): Type = if (isSubTypeWhenFrozen(sub, tp)) NoType diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index f0a710ea8b17..dd0467c1ef35 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -81,7 +81,7 @@ trait TypeOps { this: Context => // TODO: Make standalone object. // called which we override to set the `approximated` flag. range(defn.NothingType, pre) else pre - else if ((pre.termSymbol is Package) && !(thiscls is Package)) + else if (pre.termSymbol.is(Package) && !thiscls.is(Package)) toPrefix(pre.select(nme.PACKAGE), cls, thiscls) else toPrefix(pre.baseType(cls).normalizedPrefix, cls.owner, thiscls) diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index fb0f712a3e8f..b5edc3e59dd8 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1068,8 +1068,9 @@ object Types { * instead of `ArrayBuffer[? >: Int | A <: Int & A]` */ def widenUnion(implicit ctx: Context): Type = widen match { - case OrType(tp1, tp2) => - ctx.typeComparer.lub(tp1.widenUnion, tp2.widenUnion, canConstrain = true) match { + case tp @ OrType(tp1, tp2) => + if tp1.isNull || tp2.isNull then tp + else ctx.typeComparer.lub(tp1.widenUnion, tp2.widenUnion, canConstrain = true) match { case union: OrType => union.join case res => res } @@ -1399,6 +1400,11 @@ object Types { case _ => true } + /** Is this (an alias of) the `scala.Null` type? */ + final def isNull(given Context) = + isRef(defn.NullClass) + || classSymbol.name == tpnme.Null // !!! temporary kludge for being able to test without the explicit nulls PR + /** The resultType of a LambdaType, or ExprType, the type itself for others */ def resultType(implicit ctx: Context): Type = this @@ -2293,7 +2299,7 @@ object Types { } /** The singleton type for path prefix#myDesignator. - */ + */ abstract case class TermRef(override val prefix: Type, private var myDesignator: Designator) extends NamedType with SingletonType with ImplicitRef { @@ -2886,6 +2892,24 @@ object Types { else apply(tp1, tp2) } + /** An extractor for `T | Null` or `Null | T`, returning the `T` */ + object OrNull with + private def stripNull(tp: Type)(given Context): Type = tp match + case tp @ OrType(tp1, tp2) => + if tp1.isNull then tp2 + else if tp2.isNull then tp1 + else tp.derivedOrType(stripNull(tp1), stripNull(tp2)) + case tp @ AndType(tp1, tp2) => + tp.derivedAndType(stripNull(tp1), stripNull(tp2)) + case _ => + tp + def apply(tp: Type)(given Context) = + OrType(tp, defn.NullType) + def unapply(tp: Type)(given Context): Option[Type] = + val tp1 = stripNull(tp) + if tp1 ne tp then Some(tp1) else None + end OrNull + // ----- ExprType and LambdaTypes ----------------------------------- // Note: method types are cached whereas poly types are not. The reason diff --git a/compiler/src/dotty/tools/dotc/interactive/Completion.scala b/compiler/src/dotty/tools/dotc/interactive/Completion.scala index d3178ba2bbca..ef573fb7a1df 100644 --- a/compiler/src/dotty/tools/dotc/interactive/Completion.scala +++ b/compiler/src/dotty/tools/dotc/interactive/Completion.scala @@ -207,7 +207,7 @@ object Completion { def addMemberCompletions(qual: Tree)(implicit ctx: Context): Unit = if (!qual.tpe.widenDealias.isBottomType) { addAccessibleMembers(qual.tpe) - if (!mode.is(Mode.Import) && !qual.tpe.isRef(defn.NullClass)) + if (!mode.is(Mode.Import) && !qual.tpe.isNull) // Implicit conversions do not kick in when importing // and for `NullClass` they produce unapplicable completions (for unclear reasons) implicitConversionTargets(qual)(ctx.fresh.setExploreTyperState()) diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index df8a3f8ca449..5f72f9214584 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -740,11 +740,12 @@ object Erasure { override def typedAnnotated(tree: untpd.Annotated, pt: Type)(implicit ctx: Context): Tree = typed(tree.arg, pt) - override def typedStats(stats: List[untpd.Tree], exprOwner: Symbol)(implicit ctx: Context): List[Tree] = { + override def typedStats(stats: List[untpd.Tree], exprOwner: Symbol)(implicit ctx: Context): (List[Tree], Context) = { val stats1 = if (takesBridges(ctx.owner)) new Bridges(ctx.owner.asClass, erasurePhase).add(stats) else stats - super.typedStats(stats1, exprOwner).filter(!_.isEmpty) + val (stats2, finalCtx) = super.typedStats(stats1, exprOwner) + (stats2.filter(!_.isEmpty), finalCtx) } override def adapt(tree: Tree, pt: Type, locked: TypeVars)(implicit ctx: Context): Tree = diff --git a/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala b/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala index 18a31b12cbf2..8050581ed544 100644 --- a/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala +++ b/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala @@ -160,8 +160,9 @@ class FirstTransform extends MiniPhase with InfoTransformer { thisPhase => constToLiteral(tree) override def transformIf(tree: If)(implicit ctx: Context): Tree = - tree.cond match { - case Literal(Constant(c: Boolean)) => if (c) tree.thenp else tree.elsep + tree.cond.tpe match { + case ConstantType(Constant(c: Boolean)) if isPureExpr(tree.cond) => + if (c) tree.thenp else tree.elsep case _ => tree } diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index d20105ae2586..73af47c275bb 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -434,9 +434,9 @@ class TreeChecker extends Phase with SymTransformer { } } - override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = + override def typedCase(tree: untpd.CaseDef, sel: Tree, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = withPatSyms(tpd.patVars(tree.pat.asInstanceOf[tpd.Tree])) { - super.typedCase(tree, selType, pt) + super.typedCase(tree, sel, selType, pt) } override def typedClosure(tree: untpd.Closure, pt: Type)(implicit ctx: Context): Tree = { @@ -466,7 +466,7 @@ class TreeChecker extends Phase with SymTransformer { * is that we should be able to pull out an expression as an initializer * of a helper value without having to do a change owner traversal of the expression. */ - override def typedStats(trees: List[untpd.Tree], exprOwner: Symbol)(implicit ctx: Context): List[Tree] = { + override def typedStats(trees: List[untpd.Tree], exprOwner: Symbol)(implicit ctx: Context): (List[Tree], Context) = { for (tree <- trees) tree match { case tree: untpd.DefTree => checkOwner(tree) case _: untpd.Thicket => assert(false, i"unexpanded thicket $tree in statement sequence $trees%\n%") diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index cc3641f1c94f..2f04f7f17a8f 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -22,6 +22,7 @@ import NameKinds.DefaultGetterName import ProtoTypes._ import Inferencing._ import transform.TypeUtils._ +import Nullables.given import collection.mutable import config.Printers.{overload, typr, unapp} @@ -864,8 +865,9 @@ trait Applications extends Compatibility { if (proto.allArgTypesAreCurrent()) new ApplyToTyped(tree, fun1, funRef, proto.unforcedTypedArgs, pt) else - new ApplyToUntyped(tree, fun1, funRef, proto, pt)(argCtx(tree)) - convertNewGenericArray(app.result) + new ApplyToUntyped(tree, fun1, funRef, proto, pt)( + given fun1.nullableInArgContext(given argCtx(tree))) + convertNewGenericArray(app.result).computeNullable() case _ => handleUnexpectedFunType(tree, fun1) } diff --git a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala index 9b4eef28369a..0fd7174dc774 100644 --- a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala +++ b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala @@ -11,6 +11,7 @@ import Constants._ import Names._ import StdNames._ import Contexts._ +import Nullables.{CompareNull, TrackedRef} object ConstFold { @@ -19,15 +20,17 @@ object ConstFold { /** If tree is a constant operation, replace with result. */ def apply[T <: Tree](tree: T)(implicit ctx: Context): T = finish(tree) { tree match { + case CompareNull(TrackedRef(ref), testEqual) + if ctx.settings.YexplicitNulls.value && ctx.notNullInfos.impliesNotNull(ref) => + // TODO maybe drop once we have general Nullability? + Constant(!testEqual) case Apply(Select(xt, op), yt :: Nil) => - xt.tpe.widenTermRefExpr.normalized match { + xt.tpe.widenTermRefExpr.normalized match case ConstantType(x) => - yt.tpe.widenTermRefExpr match { + yt.tpe.widenTermRefExpr match case ConstantType(y) => foldBinop(op, x, y) case _ => null - } case _ => null - } case Select(xt, op) => xt.tpe.widenTermRefExpr match { case ConstantType(x) => foldUnop(op, x) diff --git a/compiler/src/dotty/tools/dotc/typer/Docstrings.scala b/compiler/src/dotty/tools/dotc/typer/Docstrings.scala index a7fe1ac69b9c..7bedf6033b45 100644 --- a/compiler/src/dotty/tools/dotc/typer/Docstrings.scala +++ b/compiler/src/dotty/tools/dotc/typer/Docstrings.scala @@ -33,7 +33,7 @@ object Docstrings { expandComment(sym).map { expanded => val typedUsecases = expanded.usecases.map { usecase => ctx.typer.enterSymbol(ctx.typer.createSymbol(usecase.untpdCode)) - ctx.typer.typedStats(usecase.untpdCode :: Nil, owner) match { + ctx.typer.typedStats(usecase.untpdCode :: Nil, owner)._1 match { case List(df: tpd.DefDef) => usecase.typed(df) case _ => diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 632b5be4ae53..9b04ce1e8b78 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -24,6 +24,7 @@ import ErrorReporting.errorTree import dotty.tools.dotc.tastyreflect.ReflectionImpl import dotty.tools.dotc.util.{SimpleIdentityMap, SimpleIdentitySet, SourceFile, SourcePosition} import dotty.tools.dotc.parsing.Parsers.Parser +import Nullables.given import collection.mutable import reporting.trace @@ -1064,10 +1065,10 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { errorTree(tree, em"""cannot reduce inline if | its condition ${tree.cond} | is not a constant value""") - else { + else + cond1.computeNullableDeeply() val if1 = untpd.cpy.If(tree)(cond = untpd.TypedSplice(cond1)) super.typedIf(if1, pt) - } } override def typedApply(tree: untpd.Apply, pt: Type)(implicit ctx: Context): Tree = diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 3d8583b78c6e..e45205a4457a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -559,7 +559,9 @@ class Namer { typer: Typer => case _ => () } - /** Create top-level symbols for statements and enter them into symbol table */ + /** Create top-level symbols for statements and enter them into symbol table + * @return A context that reflects all imports in `stats`. + */ def index(stats: List[Tree])(implicit ctx: Context): Context = { // module name -> (stat, moduleCls | moduleVal) @@ -1345,11 +1347,10 @@ class Namer { typer: Typer => // We also drop the @Repeated annotation here to avoid leaking it in method result types // (see run/inferred-repeated-result). def widenRhs(tp: Type): Type = { - val tp1 = tp.widenTermRefExpr match { + val tp1 = tp.widenTermRefExpr.simplified match case ctp: ConstantType if isInlineVal => ctp case ref: TypeRef if ref.symbol.is(ModuleClass) => tp - case _ => tp.widenUnion - } + case tp => tp.widenUnion tp1.dropRepeatedAnnot } diff --git a/compiler/src/dotty/tools/dotc/typer/Nullables.scala b/compiler/src/dotty/tools/dotc/typer/Nullables.scala new file mode 100644 index 000000000000..d3b2d146d9c1 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/typer/Nullables.scala @@ -0,0 +1,360 @@ +package dotty.tools +package dotc +package typer + +import core._ +import Types._, Contexts._, Symbols._, Decorators._, Constants._ +import annotation.tailrec +import StdNames.nme +import util.Property +import Names.Name +import util.Spans.Span +import Flags.Mutable +import collection.mutable + +/** Operations for implementing a flow analysis for nullability */ +object Nullables with + import ast.tpd._ + + /** A set of val or var references that are known to be not null, plus a set of + * variable references that are not known (anymore) to be not null + */ + case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]) + assert((asserted & retracted).isEmpty) + + def isEmpty = this eq NotNullInfo.empty + + def retractedInfo = NotNullInfo(Set(), retracted) + + /** The sequential combination with another not-null info */ + def seq(that: NotNullInfo): NotNullInfo = + if this.isEmpty then that + else if that.isEmpty then this + else NotNullInfo( + this.asserted.union(that.asserted).diff(that.retracted), + this.retracted.union(that.retracted).diff(that.asserted)) + + /** The alternative path combination with another not-null info. Used to merge + * the nullability info of the two branches of an if. + */ + def alt(that: NotNullInfo): NotNullInfo = + NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted)) + + object NotNullInfo with + val empty = new NotNullInfo(Set(), Set()) + def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo = + if asserted.isEmpty && retracted.isEmpty then empty + else new NotNullInfo(asserted, retracted) + end NotNullInfo + + /** A pair of not-null sets, depending on whether a condition is `true` or `false` */ + case class NotNullConditional(ifTrue: Set[TermRef], ifFalse: Set[TermRef]) with + def isEmpty = this eq NotNullConditional.empty + + object NotNullConditional with + val empty = new NotNullConditional(Set(), Set()) + def apply(ifTrue: Set[TermRef], ifFalse: Set[TermRef]): NotNullConditional = + if ifTrue.isEmpty && ifFalse.isEmpty then empty + else new NotNullConditional(ifTrue, ifFalse) + end NotNullConditional + + /** An attachment that represents conditional flow facts established + * by this tree, which represents a condition. + */ + private[typer] val NNConditional = Property.StickyKey[NotNullConditional] + + /** An attachment that represents unconditional flow facts established + * by this tree. + */ + private[typer] val NNInfo = Property.StickyKey[NotNullInfo] + + /** An extractor for null comparisons */ + object CompareNull with + + /** Matches one of + * + * tree == null, tree eq null, null == tree, null eq tree + * tree != null, tree ne null, null != tree, null ne tree + * + * The second boolean result is true for equality tests, false for inequality tests + */ + def unapply(tree: Tree)(given Context): Option[(Tree, Boolean)] = tree match + case Apply(Select(l, _), Literal(Constant(null)) :: Nil) => + testSym(tree.symbol, l) + case Apply(Select(Literal(Constant(null)), _), r :: Nil) => + testSym(tree.symbol, r) + case _ => + None + + private def testSym(sym: Symbol, operand: Tree)(given Context) = + if sym == defn.Any_== || sym == defn.Object_eq then Some((operand, true)) + else if sym == defn.Any_!= || sym == defn.Object_ne then Some((operand, false)) + else None + + end CompareNull + + /** An extractor for null-trackable references */ + object TrackedRef + def unapply(tree: Tree)(given Context): Option[TermRef] = tree.typeOpt match + case ref: TermRef if isTracked(ref) => Some(ref) + case _ => None + end TrackedRef + + /** Is given reference tracked for nullability? + * This is the case if the reference is a path to an immutable val, or if it refers + * to a local mutable variable where all assignments to the variable are _reachable_ + * (in the sense of how it is defined in assignmentSpans). + */ + def isTracked(ref: TermRef)(given Context) = + ref.isStable + || { val sym = ref.symbol + sym.is(Mutable) + && sym.owner.isTerm + && sym.owner.enclosingMethod == curCtx.owner.enclosingMethod + && sym.span.exists + && curCtx.compilationUnit != null // could be null under -Ytest-pickler + && curCtx.compilationUnit.assignmentSpans.contains(sym.span.start) + } + + /** The nullability context to be used after a case that matches pattern `pat`. + * If `pat` is `null`, this will assert that the selector `sel` is not null afterwards. + */ + def afterPatternContext(sel: Tree, pat: Tree)(given ctx: Context) = (sel, pat) match + case (TrackedRef(ref), Literal(Constant(null))) => ctx.addNotNullRefs(Set(ref)) + case _ => ctx + + /** The nullability context to be used for the guard and rhs of a case with + * given pattern `pat`. If the pattern can only match non-null values, this + * will assert that the selector `sel` is not null in these regions. + */ + def caseContext(sel: Tree, pat: Tree)(given ctx: Context): Context = sel match + case TrackedRef(ref) if matchesNotNull(pat) => ctx.addNotNullRefs(Set(ref)) + case _ => ctx + + private def matchesNotNull(pat: Tree)(given Context): Boolean = pat match + case _: Typed | _: UnApply => true + case Alternative(pats) => pats.forall(matchesNotNull) + // TODO: Add constant pattern if the constant type is not nullable + case _ => false + + given (infos: List[NotNullInfo]) + + /** Do the current not-null infos imply that `ref` is not null? + * Not-null infos are as a history where earlier assertions and retractions replace + * later ones (i.e. it records the assignment history in reverse, with most recent first) + */ + @tailrec def impliesNotNull(ref: TermRef): Boolean = infos match + case info :: infos1 => + if info.asserted.contains(ref) then true + else if info.retracted.contains(ref) then false + else impliesNotNull(infos1)(ref) + case _ => + false + + /** Add `info` as the most recent entry to the list of null infos. Assertions + * or retractions in `info` supersede infos in existing entries of `infos`. + */ + def extendWith(info: NotNullInfo) = + if info.isEmpty + || info.asserted.forall(infos.impliesNotNull(_)) + && !info.retracted.exists(infos.impliesNotNull(_)) + then infos + else info :: infos + + given (tree: Tree) + + /* The `tree` with added nullability attachment */ + def withNotNullInfo(info: NotNullInfo): tree.type = + if !info.isEmpty then tree.putAttachment(NNInfo, info) + tree + + /* The nullability info of `tree` */ + def notNullInfo(given Context): NotNullInfo = + stripInlined(tree).getAttachment(NNInfo) match + case Some(info) if !curCtx.erasedTypes => info + case _ => NotNullInfo.empty + + /* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */ + def notNullInfoIf(c: Boolean)(given Context): NotNullInfo = + val cond = tree.notNullConditional + if cond.isEmpty then tree.notNullInfo + else tree.notNullInfo.seq(NotNullInfo(if c then cond.ifTrue else cond.ifFalse, Set())) + + /** The paths that are known to be not null if the condition represented + * by `tree` yields `true` or `false`. Two empty sets if `tree` is not + * a condition. + */ + def notNullConditional(given Context): NotNullConditional = + stripBlock(tree).getAttachment(NNConditional) match + case Some(cond) if !curCtx.erasedTypes => cond + case _ => NotNullConditional.empty + + /** The current context augmented with nullability information of `tree` */ + def nullableContext(given Context): Context = + val info = tree.notNullInfo + if info.isEmpty then curCtx else curCtx.addNotNullInfo(info) + + /** The current context augmented with nullability information, + * assuming the result of the condition represented by `tree` is the same as + * the value of `c`. + */ + def nullableContextIf(c: Boolean)(given Context): Context = + val info = tree.notNullInfoIf(c) + if info.isEmpty then curCtx else curCtx.addNotNullInfo(info) + + /** The context to use for the arguments of the function represented by `tree`. + * This is the current context, augmented with nullability information + * of the left argument, if the application is a boolean `&&` or `||`. + */ + def nullableInArgContext(given Context): Context = tree match + case Select(x, _) if !curCtx.erasedTypes => + if tree.symbol == defn.Boolean_&& then x.nullableContextIf(true) + else if tree.symbol == defn.Boolean_|| then x.nullableContextIf(false) + else curCtx + case _ => curCtx + + /** The `tree` augmented with nullability information in an attachment. + * The following operations lead to nullability info being recorded: + * + * 1. Null tests using `==`, `!=`, `eq`, `ne`, if the compared entity is + * a path (i.e. a stable TermRef) + * 2. Boolean &&, ||, ! + */ + def computeNullable()(given Context): tree.type = + def setConditional(ifTrue: Set[TermRef], ifFalse: Set[TermRef]) = + tree.putAttachment(NNConditional, NotNullConditional(ifTrue, ifFalse)) + if !curCtx.erasedTypes && analyzedOps.contains(tree.symbol.name.toTermName) then + tree match + case CompareNull(TrackedRef(ref), testEqual) => + if testEqual then setConditional(Set(), Set(ref)) + else setConditional(Set(ref), Set()) + case Apply(Select(x, _), y :: Nil) => + val xc = x.notNullConditional + val yc = y.notNullConditional + if !(xc.isEmpty && yc.isEmpty) then + if tree.symbol == defn.Boolean_&& then + setConditional(xc.ifTrue | yc.ifTrue, xc.ifFalse & yc.ifFalse) + else if tree.symbol == defn.Boolean_|| then + setConditional(xc.ifTrue & yc.ifTrue, xc.ifFalse | yc.ifFalse) + case Select(x, _) if tree.symbol == defn.Boolean_! => + val xc = x.notNullConditional + if !xc.isEmpty then + setConditional(xc.ifFalse, xc.ifTrue) + case _ => + tree + + /** Compute nullability information for this tree and all its subtrees */ + def computeNullableDeeply()(given Context): Unit = + new TreeTraverser { + def traverse(tree: Tree)(implicit ctx: Context) = + traverseChildren(tree) + tree.computeNullable() + }.traverse(tree) + + given (tree: Assign) + def computeAssignNullable()(given Context): tree.type = tree.lhs match + case TrackedRef(ref) => + tree.withNotNullInfo(NotNullInfo(Set(), Set(ref))) // TODO: refine with nullability type info + case _ => tree + + private val analyzedOps = Set(nme.EQ, nme.NE, nme.eq, nme.ne, nme.ZAND, nme.ZOR, nme.UNARY_!) + + /** A map from (name-) offsets of all local variables in this compilation unit + * that can be tracked for being not null to the list of spans of assignments + * to these variables. A variable can be tracked if it has only reachable assignments + * An assignment is reachable if the path of tree nodes between the block enclosing + * the variable declaration to the assignment consists only of if-expressions, + * while-expressions, block-expressions and type-ascriptions. + * Only reachable assignments are handled correctly in the nullability analysis. + * Therefore, variables with unreachable assignments can be assumed to be not-null + * only if their type asserts it. + * + * Note: we track the local variables through their offset and not through their name + * because of shadowing. + */ + def assignmentSpans(given Context): Map[Int, List[Span]] = + import ast.untpd._ + + object populate extends UntypedTreeTraverser with + + /** The name offsets of variables that are tracked */ + var tracked: Map[Int, List[Span]] = Map.empty + + /** Map the names of potentially trackable candidate variables in scope to the spans + * of their reachable assignments + */ + val candidates = mutable.Map[Name, List[Span]]() + + /** An assignment to a variable that's not in reachable makes the variable + * ineligible for tracking + */ + var reachable: Set[Name] = Set.empty + + def traverse(tree: Tree)(implicit ctx: Context) = + val savedReachable = reachable + tree match + case Block(stats, expr) => + var shadowed: Set[(Name, List[Span])] = Set.empty + for case (stat: ValDef) <- stats if stat.mods.is(Mutable) do + for prevSpans <- candidates.put(stat.name, Nil) do + shadowed += (stat.name -> prevSpans) + reachable += stat.name + traverseChildren(tree) + for case (stat: ValDef) <- stats if stat.mods.is(Mutable) do + for spans <- candidates.remove(stat.name) do + tracked += (stat.nameSpan.start -> spans) // candidates that survive until here are tracked + candidates ++= shadowed + case Assign(Ident(name), rhs) => + candidates.get(name) match + case Some(spans) => + if reachable.contains(name) then candidates(name) = tree.span :: spans + else candidates -= name + case None => + traverseChildren(tree) + case _: (If | WhileDo | Typed) => + traverseChildren(tree) // assignments to candidate variables are OK here ... + case _ => + reachable = Set.empty // ... but not here + traverseChildren(tree) + reachable = savedReachable + + populate.traverse(curCtx.compilationUnit.untpdTree) + populate.tracked + end assignmentSpans + + /** The initial context to be used for a while expression with given span. + * In this context, all variables that are assigned within the while expression + * have their nullability status retracted, i.e. are not known to be not null. + * While necessary for soundness, this scheme loses precision: Even if + * the initial state of the variable is not null and all assignments to the variable + * in the while expression are also known to be not null, the variable is still + * assumed to be potentially null. The loss of precision is unavoidable during + * normal typing, since we can only do a linear traversal which does not allow + * a fixpoint computation. But it could be mitigated as follows: + * + * - initially, use `whileContext` as computed here + * - when typechecking the while, delay all errors due to a variable being potentially null + * - afterwards, if there are such delayed errors, run the analysis again with + * as a fixpoint computation, reporting all previously delayed errors that remain. + * + * The following code would produce an error in the current analysis, but not in the + * refined analysis: + * + * class Links(val elem: T, val next: Links | Null) + * + * var xs: Links | Null = Links(1, null) + * var ys: Links | Null = xs + * while xs != null + * ys = Links(xs.elem, ys.next) // error in unrefined: ys is potentially null here + * xs = xs.next + */ + def whileContext(whileSpan: Span)(given Context): Context = + def isRetracted(ref: TermRef): Boolean = + val sym = ref.symbol + sym.span.exists + && assignmentSpans.getOrElse(sym.span.start, Nil).exists(whileSpan.contains(_)) + && curCtx.notNullInfos.impliesNotNull(ref) + val retractedVars = curCtx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet + curCtx.addNotNullInfo(NotNullInfo(Set(), retractedVars)) + +end Nullables diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 505da8883248..55b44bd18f54 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -40,6 +40,7 @@ import dotty.tools.dotc.transform.{PCPCheckAndHeal, Staging, TreeMapWithStages} import transform.SymUtils._ import transform.TypeUtils._ import reporting.trace +import Nullables.{NotNullInfo, given} object Typer { @@ -78,8 +79,6 @@ object Typer { */ private[typer] val HiddenSearchFailure = new Property.Key[SearchFailure] } - - class Typer extends Namer with TypeAssigner with Applications @@ -453,6 +452,7 @@ class Typer extends Namer def typeSelectOnTerm(implicit ctx: Context): Tree = typedSelect(tree, pt, typedExpr(tree.qualifier, selectionProto(tree.name, pt, this))) + .computeNullable() def typeSelectOnType(qual: untpd.Tree)(implicit ctx: Context) = typedSelect(untpd.cpy.Select(tree)(qual, tree.name.toTypeName), pt) @@ -634,6 +634,7 @@ class Typer extends Namer else if (isWildcard) tree.expr.withType(tpt.tpe) else typed(tree.expr, tpt.tpe.widenSkolem) assignType(cpy.Typed(tree)(expr1, tpt), underlyingTreeTpe) + .withNotNullInfo(expr1.notNullInfo) } if (untpd.isWildcardStarArg(tree)) { @@ -732,6 +733,7 @@ class Typer extends Namer val lhsBounds = TypeBounds.lower(lhsVal.symbol.info).asSeenFrom(ref.prefix, lhsVal.symbol.owner) assignType(cpy.Assign(tree)(lhs1, typed(tree.rhs, lhsBounds.loBound))) + .computeAssignNullable() } else { val pre = ref.prefix @@ -754,15 +756,19 @@ class Typer extends Namer } } - def typedBlockStats(stats: List[untpd.Tree])(implicit ctx: Context): (Context, List[tpd.Tree]) = - (index(stats), typedStats(stats, ctx.owner)) + def typedBlockStats(stats: List[untpd.Tree])(implicit ctx: Context): (List[tpd.Tree], Context) = + index(stats) + typedStats(stats, ctx.owner) def typedBlock(tree: untpd.Block, pt: Type)(implicit ctx: Context): Tree = { val localCtx = ctx.retractMode(Mode.Pattern) - val (exprCtx, stats1) = typedBlockStats(tree.stats)(given localCtx) - val expr1 = typedExpr(tree.expr, pt.dropIfProto)(exprCtx) + val (stats1, exprCtx) = typedBlockStats(tree.stats)(given localCtx) + val expr1 = typedExpr(tree.expr, pt.dropIfProto)(given exprCtx) ensureNoLocalRefs( - cpy.Block(tree)(stats1, expr1).withType(expr1.tpe), pt, localSyms(stats1)) + cpy.Block(tree)(stats1, expr1) + .withType(expr1.tpe) + .withNotNullInfo(stats1.foldRight(expr1.notNullInfo)(_.notNullInfo.seq(_))), + pt, localSyms(stats1)) } def escapingRefs(block: Tree, localSyms: => List[Symbol])(implicit ctx: Context): collection.Set[NamedType] = { @@ -803,21 +809,31 @@ class Typer extends Namer } } - def typedIf(tree: untpd.If, pt: Type)(implicit ctx: Context): Tree = { - if (tree.isInline) checkInInlineContext("inline if", tree.posd) + def typedIf(tree: untpd.If, pt: Type)(implicit ctx: Context): Tree = + if tree.isInline then checkInInlineContext("inline if", tree.posd) val cond1 = typed(tree.cond, defn.BooleanType) - if (tree.elsep.isEmpty) { - val thenp1 = typed(tree.thenp, defn.UnitType) - val elsep1 = tpd.unitLiteral.withSpan(tree.span.endPos) - cpy.If(tree)(cond1, thenp1, elsep1).withType(defn.UnitType) - } - else { - val thenp1 :: elsep1 :: Nil = harmonic(harmonize, pt)( - (tree.thenp :: tree.elsep :: Nil).map(typed(_, pt.dropIfProto))) - assignType(cpy.If(tree)(cond1, thenp1, elsep1), thenp1, elsep1) - } - } + val result = + if tree.elsep.isEmpty then + val thenp1 = typed(tree.thenp, defn.UnitType)(given cond1.nullableContextIf(true)) + val elsep1 = tpd.unitLiteral.withSpan(tree.span.endPos) + cpy.If(tree)(cond1, thenp1, elsep1).withType(defn.UnitType) + else + val thenp1 :: elsep1 :: Nil = harmonic(harmonize, pt) { + val thenp0 = typed(tree.thenp, pt.dropIfProto)(given cond1.nullableContextIf(true)) + val elsep0 = typed(tree.elsep, pt.dropIfProto)(given cond1.nullableContextIf(false)) + thenp0 :: elsep0 :: Nil + } + assignType(cpy.If(tree)(cond1, thenp1, elsep1), thenp1, elsep1) + + def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo) + def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo) + result.withNotNullInfo( + if result.thenp.tpe.isRef(defn.NothingClass) then elsePathInfo + else if result.elsep.tpe.isRef(defn.NothingClass) then thenPathInfo + else thenPathInfo.alt(elsePathInfo) + ) + end typedIf /** Decompose function prototype into a list of parameter prototypes and a result prototype * tree, using WildcardTypes where a type is not known. @@ -1124,13 +1140,18 @@ class Typer extends Namer // Overridden in InlineTyper for inline matches def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(implicit ctx: Context): Tree = { - val cases1 = harmonic(harmonize, pt)(typedCases(cases, wideSelType, pt.dropIfProto)) + val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto)) .asInstanceOf[List[CaseDef]] assignType(cpy.Match(tree)(sel, cases1), sel, cases1) } - def typedCases(cases: List[untpd.CaseDef], selType: Type, pt: Type)(implicit ctx: Context): List[CaseDef] = - cases.mapconserve(typedCase(_, selType, pt)) + def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType: Type, pt: Type)(implicit ctx: Context): List[CaseDef] = + var caseCtx = ctx + cases.mapconserve { cas => + val case1 = typedCase(cas, sel, wideSelType, pt)(given caseCtx) + caseCtx = Nullables.afterPatternContext(sel, case1.pat) + case1 + } /** - strip all instantiated TypeVars from pattern types. * run/reducable.scala is a test case that shows stripping typevars is necessary. @@ -1157,7 +1178,7 @@ class Typer extends Namer } /** Type a case. */ - def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = { + def typedCase(tree: untpd.CaseDef, sel: Tree, wideSelType: Type, pt: Type)(implicit ctx: Context): CaseDef = { val originalCtx = ctx val gadtCtx: Context = ctx.fresh.setFreshGADTBounds @@ -1170,8 +1191,10 @@ class Typer extends Namer assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1) } - val pat1 = typedPattern(tree.pat, selType)(gadtCtx) - caseRest(pat1)(gadtCtx.fresh.setNewScope) + val pat1 = typedPattern(tree.pat, wideSelType)(gadtCtx) + caseRest(pat1)( + given Nullables.caseContext(sel, pat1)( + given gadtCtx.fresh.setNewScope)) } def typedLabeled(tree: untpd.Labeled)(implicit ctx: Context): Labeled = { @@ -1191,7 +1214,6 @@ class Typer extends Namer caseRest(ctx.fresh.setFreshGADTBounds.setNewScope) } - def typedReturn(tree: untpd.Return)(implicit ctx: Context): Return = { def returnProto(owner: Symbol, locals: Scope): Type = if (owner.isConstructor) defn.UnitType @@ -1238,17 +1260,19 @@ class Typer extends Namer } def typedWhileDo(tree: untpd.WhileDo)(implicit ctx: Context): Tree = { + given whileCtx: Context = Nullables.whileContext(tree.span)(given ctx) val cond1 = if (tree.cond eq EmptyTree) EmptyTree else typed(tree.cond, defn.BooleanType) - val body1 = typed(tree.body, defn.UnitType) + val body1 = typed(tree.body, defn.UnitType)(given cond1.nullableContextIf(true)) assignType(cpy.WhileDo(tree)(cond1, body1)) + .withNotNullInfo(body1.notNullInfo.retractedInfo.seq(cond1.notNullInfoIf(false))) } def typedTry(tree: untpd.Try, pt: Type)(implicit ctx: Context): Try = { val expr2 :: cases2x = harmonic(harmonize, pt) { val expr1 = typed(tree.expr, pt.dropIfProto) - val cases1 = typedCases(tree.cases, defn.ThrowableType, pt.dropIfProto) + val cases1 = typedCases(tree.cases, EmptyTree, defn.ThrowableType, pt.dropIfProto) expr1 :: cases1 } val finalizer1 = typed(tree.finalizer, defn.UnitType) @@ -1291,7 +1315,7 @@ class Typer extends Namer } def typedInlined(tree: untpd.Inlined, pt: Type)(implicit ctx: Context): Tree = { - val (exprCtx, bindings1) = typedBlockStats(tree.bindings) + val (bindings1, exprCtx) = typedBlockStats(tree.bindings) val expansion1 = typed(tree.expansion, pt)(inlineContext(tree.call)(exprCtx)) assignType(cpy.Inlined(tree)(tree.call, bindings1.asInstanceOf[List[MemberDef]], expansion1), bindings1, expansion1) @@ -1530,6 +1554,16 @@ class Typer extends Namer typed(annot, defn.AnnotationClass.typeRef) def typedValDef(vdef: untpd.ValDef, sym: Symbol)(implicit ctx: Context): Tree = { + sym.infoOrCompleter match + case completer: Namer#Completer + if completer.creationContext.notNullInfos ne ctx.notNullInfos => + // The RHS of a val def should know about not null facts established + // in preceding statements (unless the ValDef is completed ahead of time, + // then it is impossible). + vdef.symbol.info = Completer(completer.original)( + given completer.creationContext.withNotNullInfos(ctx.notNullInfos)) + case _ => + val ValDef(name, tpt, _) = vdef completeAnnotations(vdef, sym) if (sym.isOneOf(GivenOrImplicit)) checkImplicitConversionDefOK(sym) @@ -1728,7 +1762,7 @@ class Typer extends Namer else { val dummy = localDummy(cls, impl) val body1 = addAccessorDefs(cls, - typedStats(impl.body, dummy)(ctx.inClassContext(self1.symbol))) + typedStats(impl.body, dummy)(ctx.inClassContext(self1.symbol))._1) checkNoDoubleDeclaration(cls) val impl1 = cpy.Template(impl)(constr1, parents1, Nil, self1, body1) @@ -1848,9 +1882,9 @@ class Typer extends Namer case pid1: RefTree if pkg.exists => if (!pkg.is(Package)) ctx.error(PackageNameAlreadyDefined(pkg), tree.sourcePos) val packageCtx = ctx.packageContext(tree, pkg) - var stats1 = typedStats(tree.stats, pkg.moduleClass)(packageCtx) + var stats1 = typedStats(tree.stats, pkg.moduleClass)(packageCtx)._1 if (!ctx.isAfterTyper) - stats1 = stats1 ++ typedBlockStats(MainProxies.mainProxies(stats1))(packageCtx)._2 + stats1 = stats1 ++ typedBlockStats(MainProxies.mainProxies(stats1))(packageCtx)._1 cpy.PackageDef(tree)(pid1, stats1).withType(pkg.termRef) case _ => // Package will not exist if a duplicate type has already been entered, see `tests/neg/1708.scala` @@ -2167,11 +2201,12 @@ class Typer extends Namer def typedTrees(trees: List[untpd.Tree])(implicit ctx: Context): List[Tree] = trees mapconserve (typed(_)) - def typedStats(stats: List[untpd.Tree], exprOwner: Symbol)(implicit ctx: Context): List[Tree] = { + def typedStats(stats: List[untpd.Tree], exprOwner: Symbol)(implicit ctx: Context): (List[Tree], Context) = { val buf = new mutable.ListBuffer[Tree] val enumContexts = new mutable.HashMap[Symbol, Context] + val initialNotNullInfos = ctx.notNullInfos // A map from `enum` symbols to the contexts enclosing their definitions - @tailrec def traverse(stats: List[untpd.Tree])(implicit ctx: Context): List[Tree] = stats match { + @tailrec def traverse(stats: List[untpd.Tree])(implicit ctx: Context): (List[Tree], Context) = stats match { case (imp: untpd.Import) :: rest => val imp1 = typed(imp) buf += imp1 @@ -2181,7 +2216,14 @@ class Typer extends Namer case Some(xtree) => traverse(xtree :: rest) case none => - typed(mdef) match { + val defCtx = mdef match + // Keep preceding not null facts in the current context only if `mdef` + // cannot be executed out-of-sequence. + case _: ValDef if !mdef.mods.is(Lazy) && ctx.owner.isTerm => + ctx // all preceding statements will have been executed in this case + case _ => + ctx.withNotNullInfos(initialNotNullInfos) + typed(mdef)(given defCtx) match { case mdef1: DefDef if !Inliner.bodyToInline(mdef1.symbol).isEmpty => buf += inlineExpansion(mdef1) // replace body with expansion, because it will be used as inlined body @@ -2206,9 +2248,9 @@ class Typer extends Namer val stat1 = typed(stat)(ctx.exprContext(stat, exprOwner)) checkStatementPurity(stat1)(stat, exprOwner) buf += stat1 - traverse(rest) + traverse(rest)(given stat1.nullableContext) case nil => - buf.toList + (buf.toList, ctx) } val localCtx = { val exprOwnerOpt = if (exprOwner == ctx.owner) None else Some(exprOwner) @@ -2225,9 +2267,10 @@ class Typer extends Namer case _ => stat } - val stats1 = traverse(stats)(localCtx).mapConserve(finalize) + val (stats0, finalCtx) = traverse(stats)(localCtx) + val stats1 = stats0.mapConserve(finalize) if (ctx.owner == exprOwner) checkNoAlphaConflict(stats1) - stats1 + (stats1, finalCtx) } /** Given an inline method `mdef`, the method rewritten so that its body diff --git a/compiler/src/dotty/tools/repl/ReplDriver.scala b/compiler/src/dotty/tools/repl/ReplDriver.scala index acfa33f02765..e181ff042c72 100644 --- a/compiler/src/dotty/tools/repl/ReplDriver.scala +++ b/compiler/src/dotty/tools/repl/ReplDriver.scala @@ -272,7 +272,7 @@ class ReplDriver(settings: Array[String], val vals = info.fields - .filterNot(_.symbol.isOneOf(ParamAccessor | Private | Synthetic | Module)) + .filterNot(_.symbol.isOneOf(ParamAccessor | Private | Synthetic | Artifact | Module)) .filter(_.symbol.name.is(SimpleNameKind)) .sortBy(_.name) diff --git a/compiler/test-resources/repl/i5218 b/compiler/test-resources/repl/i5218 index 3402ed98e286..abe63009ef74 100644 --- a/compiler/test-resources/repl/i5218 +++ b/compiler/test-resources/repl/i5218 @@ -3,4 +3,4 @@ val tuple: (Int, String, Long) = (1,2,3) scala> 0.0 *: tuple val res0: (Double, Int, String, Long) = (0.0,1,2,3) scala> tuple ++ tuple -val res1: Int *: String *: Long *: scala.Tuple.Concat[Unit, tuple.type] = (1,2,3,1,2,3) \ No newline at end of file +val res1: Int *: String *: Long *: tuple.type = (1,2,3,1,2,3) diff --git a/compiler/test/dotc/pos-from-tasty.blacklist b/compiler/test/dotc/pos-from-tasty.blacklist index f3a7568dcf01..1fac00b403ab 100644 --- a/compiler/test/dotc/pos-from-tasty.blacklist +++ b/compiler/test/dotc/pos-from-tasty.blacklist @@ -8,4 +8,8 @@ t3612.scala t802.scala # Matchtype -i7087.scala \ No newline at end of file +i7087.scala + +# Nullability +nullable.scala +notNull.scala diff --git a/compiler/test/dotc/pos-test-pickling.blacklist b/compiler/test/dotc/pos-test-pickling.blacklist index 740ff1108fad..59a9b8498080 100644 --- a/compiler/test/dotc/pos-test-pickling.blacklist +++ b/compiler/test/dotc/pos-test-pickling.blacklist @@ -28,3 +28,6 @@ i5720.scala # Tuples toexproftuple.scala + +# Nullability +nullable.scala diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index 78915d3ce2f8..aefcb8b90d6e 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -59,6 +59,7 @@ class CompilationTests extends ParallelTesting { compileFile("tests/pos-special/typeclass-scaling.scala", defaultOptions.and("-Xmax-inlines", "40")), compileFile("tests/pos-special/indent-colons.scala", defaultOptions.and("-Yindent-colons")), compileFile("tests/pos-special/i7296.scala", defaultOptions.and("-strict", "-deprecation", "-Xfatal-warnings")), + compileFile("tests/pos-special/nullable.scala", defaultOptions.and("-Yexplicit-nulls")), compileDir("tests/pos-special/adhoc-extension", defaultOptions.and("-strict", "-feature", "-Xfatal-warnings")) ).checkCompile() } diff --git a/library/src/dotty/DottyPredef.scala b/library/src/dotty/DottyPredef.scala index 529266cb3922..c934bf0ddd67 100644 --- a/library/src/dotty/DottyPredef.scala +++ b/library/src/dotty/DottyPredef.scala @@ -8,13 +8,13 @@ object DottyPredef { assertFail(message) } - inline final def assert(assertion: => Boolean): Unit = { + inline final def assert(assertion: => Boolean) <: Unit = { if (!assertion) assertFail() } - def assertFail(): Unit = throw new java.lang.AssertionError("assertion failed") - def assertFail(message: => Any): Unit = throw new java.lang.AssertionError("assertion failed: " + message) + def assertFail(): Nothing = throw new java.lang.AssertionError("assertion failed") + def assertFail(message: => Any): Nothing = throw new java.lang.AssertionError("assertion failed: " + message) inline final def implicitly[T](implicit ev: T): T = ev diff --git a/tests/pos-special/nullable.scala b/tests/pos-special/nullable.scala new file mode 100644 index 000000000000..d2d68f0dcfd1 --- /dev/null +++ b/tests/pos-special/nullable.scala @@ -0,0 +1,66 @@ +trait T { def f: Int } +def impossible(x: Any): Unit = + val y = x + +def test: Unit = + val x, x2, x3, x4 = "" + + if x != null then + if x == null then impossible(new T{}) + + if x == null then () + else + if x == null then impossible(new T{}) + + if x == null || { + if x == null then impossible(new T{}) + true + } + then () + + if x != null && { + if x == null then impossible(new T{}) + true + } + then () + + if !(x == null) && { + if x == null then impossible(new T{}) + true + } + then () + + x match + case _: String => + if x == null then impossible(new T{}) + + val y: Any = List(x) + y match + case y1 :: ys => if y == null then impossible(new T{}) + case Some(_) | Seq(_: _*) => if y == null then impossible(new T{}) + + x match + case null => + case _ => if x == null then impossible(new T{}) + + if x == null then return + if x == null then impossible(new T{}) + + if x2 == null then throw AssertionError() + if x2 == null then impossible(new T{}) + + if !(x3 != null) then throw AssertionError() + if x3 == null then impossible(new T{}) + + assert(x4 != null) + if x4 == null then impossible(new T{}) + + class C(val x: Int, val next: C) + var xs: C = C(1, C(2, null)) + while xs != null do + if xs == null then println("?") + // looking at this with -Xprint-frontend -Xprint-types shows that the + // type of `xs == null` is indeed `false`. We cannot currently use this in a test + // since `xs == null` is not technically a pure expression since `xs` is not a path. + // We should test variable tracking once this is integrated with explicit not null types. + xs = xs.next diff --git a/tests/pos/notNull.scala b/tests/pos/notNull.scala new file mode 100644 index 000000000000..3d46fe658948 --- /dev/null +++ b/tests/pos/notNull.scala @@ -0,0 +1,22 @@ +trait Null extends Any +object Test with + def notNull[A](x: A | Null): x.type & A = + assert(x != null) + x.asInstanceOf // TODO: drop the .asInstanceOf when explicit nulls are implemented + + locally { + val x: (Int | Null) = ??? + val y = x; val _: Int | Null = y + } + locally { + val x: Int | Null = ??? + val y = notNull(identity(x)); val yc: Int = y + val z = notNull(x); val zc: Int = z + } + class C { type T } + locally { + val x: C { type T = Int } = new C { type T = Int } + val xnn: x.type & C { type T = Int } = notNull(x) + val y: xnn.T = 33 + val z = y; val zc: Int = z + }