diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 7352139ea0e9..f6aba763da4e 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -861,6 +861,21 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => case _ => None } } + + /** Extractor for not-null assertions. + * A not-null assertion for reference `x` has the form `x.$asInstanceOf$[x.type & T]`. + */ + object AssertNotNull with + def apply(tree: tpd.Tree, tpnn: Type)(given Context): tpd.Tree = + tree.select(defn.Any_typeCast).appliedToType(AndType(tree.tpe, tpnn)) + + def unapply(tree: tpd.TypeApply)(given Context): Option[tpd.Tree] = tree match + case TypeApply(Select(qual: RefTree, nme.asInstanceOfPM), arg :: Nil) => + arg.tpe match + case AndType(ref, _) if qual.tpe eq ref => Some(qual) + case _ => None + case _ => None + end AssertNotNull } object TreeInfo { diff --git a/compiler/src/dotty/tools/dotc/config/Printers.scala b/compiler/src/dotty/tools/dotc/config/Printers.scala index 6a0549501f3d..3536302bd865 100644 --- a/compiler/src/dotty/tools/dotc/config/Printers.scala +++ b/compiler/src/dotty/tools/dotc/config/Printers.scala @@ -30,6 +30,7 @@ object Printers { val lexical: Printer = noPrinter val inlining: Printer = noPrinter val interactiv: Printer = noPrinter + val nullables: Printer = noPrinter val overload: Printer = noPrinter val patmatch: Printer = noPrinter val pickling: Printer = noPrinter diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 4c4ec903c090..a3073e06c25a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -22,7 +22,7 @@ import NameKinds.DefaultGetterName import ProtoTypes._ import Inferencing._ import transform.TypeUtils._ -import Nullables.given +import Nullables.{postProcessByNameArgs, given} import collection.mutable import config.Printers.{overload, typr, unapp} @@ -794,7 +794,7 @@ trait Applications extends Compatibility { /** Subclass of Application for type checking an Apply node with untyped arguments. */ class ApplyToUntyped(app: untpd.Apply, fun: Tree, methRef: TermRef, proto: FunProto, resultType: Type)(implicit ctx: Context) extends TypedApply(app, fun, methRef, proto.args, resultType) { - def typedArg(arg: untpd.Tree, formal: Type): TypedArg = proto.typedArg(arg, formal.widenExpr) + def typedArg(arg: untpd.Tree, formal: Type): TypedArg = proto.typedArg(arg, formal) def treeToArg(arg: Tree): untpd.Tree = untpd.TypedSplice(arg) def typeOfArg(arg: untpd.Tree): Type = proto.typeOfArg(arg) } @@ -868,7 +868,8 @@ trait Applications extends Compatibility { else new ApplyToUntyped(tree, fun1, funRef, proto, pt)( given fun1.nullableInArgContext(given argCtx(tree))) - convertNewGenericArray(app.result).computeNullable() + convertNewGenericArray( + postProcessByNameArgs(funRef, app.result).computeNullable()) case _ => handleUnexpectedFunType(tree, fun1) } @@ -1030,7 +1031,7 @@ trait Applications extends Compatibility { * It is performed during typer as creation of generic arrays needs a classTag. * we rely on implicit search to find one. */ - def convertNewGenericArray(tree: Tree)(implicit ctx: Context): Tree = tree match { + def convertNewGenericArray(tree: Tree)(implicit ctx: Context): Tree = tree match { case Apply(TypeApply(tycon, targs@(targ :: Nil)), args) if tycon.symbol == defn.ArrayConstructor => fullyDefinedType(tree.tpe, "array", tree.span) diff --git a/compiler/src/dotty/tools/dotc/typer/Nullables.scala b/compiler/src/dotty/tools/dotc/typer/Nullables.scala index 0142b23dcdc0..f66f5e996b90 100644 --- a/compiler/src/dotty/tools/dotc/typer/Nullables.scala +++ b/compiler/src/dotty/tools/dotc/typer/Nullables.scala @@ -12,6 +12,8 @@ import util.Spans.Span import Flags._ import NullOpsDecorator._ import collection.mutable +import config.Printers.nullables +import ast.{tpd, untpd} /** Operations for implementing a flow analysis for nullability */ object Nullables with @@ -182,6 +184,13 @@ object Nullables with then infos else info :: infos + /** Retract all references to mutable variables */ + def retractMutables(given Context) = + val mutables = infos.foldLeft(Set[TermRef]())((ms, info) => + ms.union(info.asserted.filter(_.symbol.is(Mutable)))) + infos.extendWith(NotNullInfo(Set(), mutables)) + end notNullInfoOps + given refOps: extension (ref: TermRef) with /** Is the use of a mutable variable out of order @@ -443,4 +452,70 @@ object Nullables with val retractedVars = curCtx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet curCtx.addNotNullInfo(NotNullInfo(Set(), retractedVars)) + /** Post process all arguments to by-name parameters by removing any not-null + * info that was used when typing them. Concretely: + * If an argument corresponds to a call-by-name parameter, drop all + * embedded not-null assertions of the form `x.$asInstanceOf[x.type & T]` + * where `x` is a reference to a mutable variable. If the argument still typechecks + * with the removed assertions and is still compatible with the formal parameter, + * keep it. Otherwise issue an error that the call-by-name argument was typed using + * flow assumptions about mutable variables and suggest that it is enclosed + * in a `byName(...)` call instead. + */ + def postProcessByNameArgs(fn: TermRef, app: Tree)(given ctx: Context): Tree = + fn.widen match + case mt: MethodType if mt.paramInfos.exists(_.isInstanceOf[ExprType]) => + app match + case Apply(fn, args) => + val dropNotNull = new TreeMap with + override def transform(t: Tree)(given Context) = t match + case AssertNotNull(t0) if t0.symbol.is(Mutable) => + nullables.println(i"dropping $t") + transform(t0) + case t: ValDef if !t.symbol.is(Lazy) => super.transform(t) + case t: MemberDef => + // stop here since embedded references to mutable variables would be + // out of order, so they would not asserted ot be not-null anyway. + // @see Nullables.usedOutOfOrder + t + case _ => super.transform(t) + + object retyper extends ReTyper with + override def typedUnadapted(t: untpd.Tree, pt: Type, locked: TypeVars)(implicit ctx: Context): Tree = t match + case t: ValDef if !t.symbol.is(Lazy) => super.typedUnadapted(t, pt, locked) + case t: MemberDef => promote(t) + case _ => super.typedUnadapted(t, pt, locked) + + def postProcess(formal: Type, arg: Tree): Tree = + val arg1 = dropNotNull.transform(arg) + if arg1 eq arg then arg + else + val nestedCtx = ctx.fresh.setNewTyperState() + val arg2 = retyper.typed(arg1, formal)(given nestedCtx) + if nestedCtx.reporter.hasErrors || !(arg2.tpe <:< formal) then + ctx.error(em"""This argument was typed using flow assumptions about mutable variables + |but it is passed to a by-name parameter where such flow assumptions are unsound. + |Wrapping the argument in `byName(...)` fixes the problem by disabling the flow assumptions. + | + |`byName` needs to be imported from the `scala.compiletime` package.""", + arg.sourcePos) + arg + else + nestedCtx.typerState.commit() + arg2 + + def recur(formals: List[Type], args: List[Tree]): List[Tree] = (formals, args) match + case (formal :: formalsRest, arg :: argsRest) => + val arg1 = postProcess(formal.widenExpr.repeatedToSingle, arg) + val argsRest1 = recur( + if formal.isRepeatedParam then formals else formalsRest, + argsRest) + if (arg1 eq arg) && (argsRest1 eq argsRest) then args + else arg1 :: argsRest1 + case _ => args + + tpd.cpy.Apply(app)(fn, recur(mt.paramInfos, args)) + case _ => app + case _ => app + end postProcessByNameArgs end Nullables diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index 9e379c6f7367..cb2bdb20822a 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -322,9 +322,15 @@ object ProtoTypes { * used to avoid repeated typings of trees when backtracking. */ def typedArg(arg: untpd.Tree, formal: Type)(implicit ctx: Context): Tree = { + val wideFormal = formal.widenExpr + val argCtx = + if wideFormal eq formal then ctx + else ctx.withNotNullInfos(ctx.notNullInfos.retractMutables) val locked = ctx.typerState.ownedVars - val targ = cacheTypedArg(arg, typer.typedUnadapted(_, formal, locked), force = true) - typer.adapt(targ, formal, locked) + val targ = cacheTypedArg(arg, + typer.typedUnadapted(_, wideFormal, locked)(given argCtx), + force = true) + typer.adapt(targ, wideFormal, locked) } /** The type of the argument `arg`, or `NoType` if `arg` has not been typed before diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index 5b2c86f2bc09..fe65b48bda91 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -115,6 +115,7 @@ class CompilationTests extends ParallelTesting { compileFilesInDir("tests/neg-custom-args/deprecation", defaultOptions.and("-Xfatal-warnings", "-deprecation")), compileFilesInDir("tests/neg-custom-args/fatal-warnings", defaultOptions.and("-Xfatal-warnings")), compileFilesInDir("tests/neg-custom-args/allow-double-bindings", allowDoubleBindings), + compileFilesInDir("tests/neg-custom-args/explicit-nulls", defaultOptions.and("-Yexplicit-nulls")), compileDir("tests/neg-custom-args/impl-conv", defaultOptions.and("-Xfatal-warnings", "-feature")), compileFile("tests/neg-custom-args/implicit-conversions.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), compileFile("tests/neg-custom-args/implicit-conversions-old.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), diff --git a/library/src/scala/compiletime/package.scala b/library/src/scala/compiletime/package.scala index 28ac8e519637..023183972c88 100644 --- a/library/src/scala/compiletime/package.scala +++ b/library/src/scala/compiletime/package.scala @@ -63,4 +63,7 @@ package object compiletime { * } */ type S[N <: Int] <: Int + + /** Assertion that an argument is by-name. Used for nullability checking. */ + def byName[T](x: => T): T = x } diff --git a/tests/neg-custom-args/explicit-nulls/byname-nullables.check b/tests/neg-custom-args/explicit-nulls/byname-nullables.check new file mode 100644 index 000000000000..bb8969474289 --- /dev/null +++ b/tests/neg-custom-args/explicit-nulls/byname-nullables.check @@ -0,0 +1,28 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/explicit-nulls/byname-nullables.scala:19:24 ----------------------- +19 | if x != null then f(x) // error: f is call-by-name + | ^ + | Found: (x : String | Null) + | Required: String +-- Error: tests/neg-custom-args/explicit-nulls/byname-nullables.scala:43:32 -------------------------------------------- +43 | if x != null then f(identity(x), 1) // error: dropping not null check fails typing + | ^^^^^^^^^^^ + | This argument was typed using flow assumptions about mutable variables + | but it is passed to a by-name parameter where such flow assumptions are unsound. + | Wrapping the argument in `byName(...)` fixes the problem by disabling the flow assumptions. + | + | `byName` needs to be imported from the `scala.compiletime` package. +-- Error: tests/neg-custom-args/explicit-nulls/byname-nullables.scala:68:24 -------------------------------------------- +68 | if x != null then f(x, 1) // error: dropping not null check typechecks OK, but gives incompatible result type + | ^ + | This argument was typed using flow assumptions about mutable variables + | but it is passed to a by-name parameter where such flow assumptions are unsound. + | Wrapping the argument in `byName(...)` fixes the problem by disabling the flow assumptions. + | + | `byName` needs to be imported from the `scala.compiletime` package. +-- [E134] Type Mismatch Error: tests/neg-custom-args/explicit-nulls/byname-nullables.scala:81:22 ----------------------- +81 | if x != null then f(byName(x), 1) // error: none of the overloaded methods match argument types + | ^ + | None of the overloaded alternatives of method f in object Test7 with types + | (x: => String, y: Int): String + | (x: String, y: String): String + | match arguments (String | Null, (1 : Int)) diff --git a/tests/neg-custom-args/explicit-nulls/byname-nullables.scala b/tests/neg-custom-args/explicit-nulls/byname-nullables.scala new file mode 100644 index 000000000000..ada838b269f6 --- /dev/null +++ b/tests/neg-custom-args/explicit-nulls/byname-nullables.scala @@ -0,0 +1,82 @@ +object Test1 with + + def f(x: String) = + x ++ x + + def g() = + var x: String | Null = "abc" + if x != null then f(x) // OK: f is call-by-value + else x + + +object Test2 with + + def f(x: => String) = + x ++ x + + def g() = + var x: String | Null = "abc" + if x != null then f(x) // error: f is call-by-name + else x + +object Test3 with + + def f(x: String, y: String) = x + + def f(x: => String | Null, y: Int) = + x + + def g() = + var x: String | Null = "abc" + if x != null then f(x, 1) // OK: not-null check successfully dropped + else x + +object Test4 with + + def f(x: String, y: String) = x + + def f(x: => String | Null, y: Int) = + x + + def g() = + var x: String | Null = "abc" + if x != null then f(identity(x), 1) // error: dropping not null check fails typing + else x + +object Test5 with + import compiletime.byName + + def f(x: String, y: String) = x + + def f(x: => String | Null, y: Int) = + x + + def g() = + var x: String | Null = "abc" + if x != null then f(byName(identity(x)), 1) // OK, byName avoids the flow typing + else x + +object Test6 with + + def f(x: String, y: String) = x + + def f(x: => String, y: Int) = + x + + def g() = + var x: String | Null = "abc" + if x != null then f(x, 1) // error: dropping not null check typechecks OK, but gives incompatible result type + else x + +object Test7 with + import compiletime.byName + + def f(x: String, y: String) = x + + def f(x: => String, y: Int) = + x + + def g() = + var x: String | Null = "abc" + if x != null then f(byName(x), 1) // error: none of the overloaded methods match argument types + else x