Skip to content

Treat Scala.js pseudo-unions as real unions #11671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 16, 2021
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/backend/sjs/JSDefinitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ final class JSDefinitions()(using Context) {
@threadUnsafe lazy val PseudoUnion_fromTypeConstructorR = PseudoUnionModule.requiredMethodRef("fromTypeConstructor")
def PseudoUnion_fromTypeConstructor(using Context) = PseudoUnion_fromTypeConstructorR.symbol

@threadUnsafe lazy val UnionOpsModuleRef = requiredModuleRef("scala.scalajs.js.internal.UnitOps")

@threadUnsafe lazy val JSArrayType: TypeRef = requiredClassRef("scala.scalajs.js.Array")
def JSArrayClass(using Context) = JSArrayType.symbol.asClass
@threadUnsafe lazy val JSDynamicType: TypeRef = requiredClassRef("scala.scalajs.js.Dynamic")
Expand Down
58 changes: 38 additions & 20 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Symbols._, Types._, Contexts._, Flags._, Names._, StdNames._, Phases._
import Flags.JavaDefined
import Uniques.unique
import TypeOps.makePackageObjPrefixExplicit
import backend.sjs.JSDefinitions
import transform.ExplicitOuter._
import transform.ValueClasses._
import transform.TypeUtils._
Expand Down Expand Up @@ -142,29 +143,31 @@ object TypeErasure {
}
}

private def erasureIdx(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConstructor: Boolean, wildcardOK: Boolean) =
private def erasureIdx(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConstructor: Boolean, isSymbol: Boolean, wildcardOK: Boolean) =
extension (b: Boolean) def toInt = if b then 1 else 0
wildcardOK.toInt
+ (isConstructor.toInt << 1)
+ (semiEraseVCs.toInt << 2)
+ (sourceLanguage.ordinal << 3)
+ (isSymbol.toInt << 1)
+ (isConstructor.toInt << 2)
+ (semiEraseVCs.toInt << 3)
+ (sourceLanguage.ordinal << 4)

private val erasures = new Array[TypeErasure](1 << (SourceLanguage.bits + 3))
private val erasures = new Array[TypeErasure](1 << (SourceLanguage.bits + 4))

for
sourceLanguage <- SourceLanguage.values
semiEraseVCs <- List(false, true)
isConstructor <- List(false, true)
isSymbol <- List(false, true)
wildcardOK <- List(false, true)
do
erasures(erasureIdx(sourceLanguage, semiEraseVCs, isConstructor, wildcardOK)) =
new TypeErasure(sourceLanguage, semiEraseVCs, isConstructor, wildcardOK)
erasures(erasureIdx(sourceLanguage, semiEraseVCs, isConstructor, isSymbol, wildcardOK)) =
new TypeErasure(sourceLanguage, semiEraseVCs, isConstructor, isSymbol, wildcardOK)

/** Produces an erasure function. See the documentation of the class [[TypeErasure]]
* for a description of each parameter.
*/
private def erasureFn(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConstructor: Boolean, wildcardOK: Boolean): TypeErasure =
erasures(erasureIdx(sourceLanguage, semiEraseVCs, isConstructor, wildcardOK))
private def erasureFn(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConstructor: Boolean, isSymbol: Boolean, wildcardOK: Boolean): TypeErasure =
erasures(erasureIdx(sourceLanguage, semiEraseVCs, isConstructor, isSymbol, wildcardOK))

/** The current context with a phase no later than erasure */
def preErasureCtx(using Context) =
Expand All @@ -175,19 +178,19 @@ object TypeErasure {
* @param tp The type to erase.
*/
def erasure(tp: Type)(using Context): Type =
erasureFn(sourceLanguage = SourceLanguage.Scala3, semiEraseVCs = false, isConstructor = false, wildcardOK = false)(tp)(using preErasureCtx)
erasureFn(sourceLanguage = SourceLanguage.Scala3, semiEraseVCs = false, isConstructor = false, isSymbol = false, wildcardOK = false)(tp)(using preErasureCtx)

/** The value class erasure of a Scala type, where value classes are semi-erased to
* ErasedValueType (they will be fully erased in [[ElimErasedValueType]]).
*
* @param tp The type to erase.
*/
def valueErasure(tp: Type)(using Context): Type =
erasureFn(sourceLanguage = SourceLanguage.Scala3, semiEraseVCs = true, isConstructor = false, wildcardOK = false)(tp)(using preErasureCtx)
erasureFn(sourceLanguage = SourceLanguage.Scala3, semiEraseVCs = true, isConstructor = false, isSymbol = false, wildcardOK = false)(tp)(using preErasureCtx)

/** The erasure that Scala 2 would use for this type. */
def scala2Erasure(tp: Type)(using Context): Type =
erasureFn(sourceLanguage = SourceLanguage.Scala2, semiEraseVCs = true, isConstructor = false, wildcardOK = false)(tp)(using preErasureCtx)
erasureFn(sourceLanguage = SourceLanguage.Scala2, semiEraseVCs = true, isConstructor = false, isSymbol = false, wildcardOK = false)(tp)(using preErasureCtx)

/** Like value class erasure, but value classes erase to their underlying type erasure */
def fullErasure(tp: Type)(using Context): Type =
Expand All @@ -197,7 +200,7 @@ object TypeErasure {

def sigName(tp: Type, sourceLanguage: SourceLanguage)(using Context): TypeName = {
val normTp = tp.translateFromRepeated(toArray = sourceLanguage.isJava)
val erase = erasureFn(sourceLanguage, semiEraseVCs = !sourceLanguage.isJava, isConstructor = false, wildcardOK = true)
val erase = erasureFn(sourceLanguage, semiEraseVCs = !sourceLanguage.isJava, isConstructor = false, isSymbol = false, wildcardOK = true)
erase.sigName(normTp)(using preErasureCtx)
}

Expand Down Expand Up @@ -227,7 +230,7 @@ object TypeErasure {
def transformInfo(sym: Symbol, tp: Type)(using Context): Type = {
val sourceLanguage = SourceLanguage(sym)
val semiEraseVCs = !sourceLanguage.isJava // Java sees our value classes as regular classes.
val erase = erasureFn(sourceLanguage, semiEraseVCs, sym.isConstructor, wildcardOK = false)
val erase = erasureFn(sourceLanguage, semiEraseVCs, sym.isConstructor, isSymbol = true, wildcardOK = false)

def eraseParamBounds(tp: PolyType): Type =
tp.derivedLambdaType(
Expand Down Expand Up @@ -446,10 +449,11 @@ import TypeErasure._
* (they will be fully erased in [[ElimErasedValueType]]).
* If false, they are erased like normal classes.
* @param isConstructor Argument forms part of the type of a constructor
* @param isSymbol If true, the type being erased is the info of a symbol.
* @param wildcardOK Wildcards are acceptable (true when using the erasure
* for computing a signature name).
*/
class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConstructor: Boolean, wildcardOK: Boolean) {
class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConstructor: Boolean, isSymbol: Boolean, wildcardOK: Boolean) {

/** The erasure |T| of a type T. This is:
*
Expand Down Expand Up @@ -520,10 +524,22 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
else
erasedGlb(this(tp1), this(tp2), isJava = sourceLanguage.isJava)
case OrType(tp1, tp2) =>
TypeComparer.orType(this(tp1), this(tp2), isErased = true)
if isSymbol && sourceLanguage.isScala2 && ctx.settings.scalajs.value then
// In Scala2Unpickler we unpickle Scala.js pseudo-unions as if they were
// real unions, but we must still erase them as Scala 2 would to emit
// the correct signatures in SJSIR.
// We only do this when `isSymbol` is true since in other situations we
// cannot distinguish a Scala.js pseudo-union from a Scala 3 union that
// has been substituted into a Scala 2 type (e.g., via `asSeenFrom`),
// erasing these unions as if they were pseudo-unions could have an
// impact on overriding relationships so it's best to leave them
// alone (and this doesn't impact the SJSIR we generate).
JSDefinitions.jsdefn.PseudoUnionType
else
TypeComparer.orType(this(tp1), this(tp2), isErased = true)
case tp: MethodType =>
def paramErasure(tpToErase: Type) =
erasureFn(sourceLanguage, semiEraseVCs, isConstructor, wildcardOK)(tpToErase)
erasureFn(sourceLanguage, semiEraseVCs, isConstructor, isSymbol, wildcardOK)(tpToErase)
val (names, formals0) = if (tp.isErasedMethod) (Nil, Nil) else (tp.paramNames, tp.paramInfos)
val formals = formals0.mapConserve(paramErasure)
eraseResult(tp.resultType) match {
Expand Down Expand Up @@ -567,7 +583,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
val defn.ArrayOf(elemtp) = tp
if (classify(elemtp).derivesFrom(defn.NullClass)) JavaArrayType(defn.ObjectType)
else if (isUnboundedGeneric(elemtp) && !sourceLanguage.isJava) defn.ObjectType
else JavaArrayType(erasureFn(sourceLanguage, semiEraseVCs = false, isConstructor, wildcardOK)(elemtp))
else JavaArrayType(erasureFn(sourceLanguage, semiEraseVCs = false, isConstructor, isSymbol, wildcardOK)(elemtp))
}

private def erasePair(tp: Type)(using Context): Type = {
Expand Down Expand Up @@ -608,7 +624,9 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
val genericUnderlying = unbox.info.resultType
val underlying = tp.select(unbox).widen.resultType

val erasedUnderlying = erasure(underlying)
// The underlying part of an ErasedValueType cannot be an ErasedValueType itself
val erase = erasureFn(sourceLanguage, semiEraseVCs = false, isConstructor, isSymbol, wildcardOK)
val erasedUnderlying = erase(underlying)

// Ideally, we would just use `erasedUnderlying` as the erasure of `tp`, but to
// be binary-compatible with Scala 2 we need two special cases for polymorphic
Expand Down Expand Up @@ -646,7 +664,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
// correctly (see SIP-15 and [[Erasure.Boxing.adaptToType]]), so the result type of a
// constructor method should not be semi-erased.
if semiEraseVCs && isConstructor && !tp.isInstanceOf[MethodOrPoly] then
erasureFn(sourceLanguage, semiEraseVCs = false, isConstructor, wildcardOK).eraseResult(tp)
erasureFn(sourceLanguage, semiEraseVCs = false, isConstructor, isSymbol, wildcardOK).eraseResult(tp)
else tp match
case tp: TypeRef =>
val sym = tp.symbol
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import NameKinds.{Scala2MethodNameKinds, SuperAccessorName, ExpandedName}
import util.Spans._
import dotty.tools.dotc.ast.{tpd, untpd}, ast.tpd._
import ast.untpd.Modifiers
import backend.sjs.JSDefinitions
import printing.Texts._
import printing.Printer
import io.AbstractFile
Expand Down Expand Up @@ -675,6 +676,10 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas

def removeSingleton(tp: Type): Type =
if (tp isRef defn.SingletonClass) defn.AnyType else tp
def mapArg(arg: Type) = arg match {
case arg: TypeRef if isBound(arg) => arg.symbol.info
case _ => arg
}
def elim(tp: Type): Type = tp match {
case tp @ RefinedType(parent, name, rinfo) =>
val parent1 = elim(tp.parent)
Expand All @@ -690,12 +695,11 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
}
case tp @ AppliedType(tycon, args) =>
val tycon1 = tycon.safeDealias
def mapArg(arg: Type) = arg match {
case arg: TypeRef if isBound(arg) => arg.symbol.info
case _ => arg
}
if (tycon1 ne tycon) elim(tycon1.appliedTo(args))
else tp.derivedAppliedType(tycon, args.map(mapArg))
case tp: AndOrType =>
// scalajs.js.|.UnionOps has a type parameter upper-bounded by `_ | _`
tp.derivedAndOrType(mapArg(tp.tp1).bounds.hi, mapArg(tp.tp2).bounds.hi)
case _ =>
tp
}
Expand Down Expand Up @@ -777,6 +781,12 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
val tycon = select(pre, sym)
val args = until(end, () => readTypeRef())
if (sym == defn.ByNameParamClass2x) ExprType(args.head)
else if (ctx.settings.scalajs.value && args.length == 2 &&
sym.owner == JSDefinitions.jsdefn.ScalaJSJSPackageClass && sym == JSDefinitions.jsdefn.PseudoUnionClass) {
// Treat Scala.js pseudo-unions as real unions, this requires a
// special-case in erasure, see TypeErasure#eraseInfo.
OrType(args(0), args(1), soft = false)
}
else if (args.nonEmpty) tycon.safeAppliedTo(EtaExpandIfHK(sym.typeParams, args.map(translateTempPoly)))
else if (sym.typeParams.nonEmpty) tycon.EtaExpand(sym.typeParams)
else tycon
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2126,7 +2126,7 @@ import transform.SymUtils._
val addendum =
if (scrutTp != testTp) s" is a subtype of ${testTp.show}"
else " is the same as the tested type"
s"The highlighted type test will always succeed since the scrutinee type ($scrutTp.show)" + addendum
s"The highlighted type test will always succeed since the scrutinee type ${scrutTp.show}" + addendum
}
def explain = ""
}
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dotty.tools
package dotc
package typer

import backend.sjs.JSDefinitions
import core._
import ast.{Trees, TreeTypeMap, untpd, tpd, DesugarEnums}
import util.Spans._
Expand Down Expand Up @@ -634,6 +635,14 @@ trait ImplicitRunInfo:
else pre.member(sym.name.toTermName)
.suchThat(companion => companion.is(Module) && companion.owner == sym.owner)
.symbol)

// The companion of `js.|` defines an implicit conversions from
// `A | Unit` to `js.UndefOrOps[A]`. To keep this conversion in scope
// in Scala 3, where we re-interpret `js.|` as a real union, we inject
// it in the scope of `Unit`.
if t.isRef(defn.UnitClass) && ctx.settings.scalajs.value then
companions += JSDefinitions.jsdefn.UnionOpsModuleRef

if sym.isClass then
for p <- t.parents do companions ++= iscopeRefs(p)
else
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/ImportSuggestions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package dotty.tools
package dotc
package typer

import backend.sjs.JSDefinitions
import core._
import Contexts._, Types._, Symbols._, Names._, Decorators._, ProtoTypes._
import Flags._, SymDenotations._
import NameKinds.FlatName
import NameOps._
import StdNames._
import config.Printers.{implicits, implicitsDetailed}
import util.Spans.Span
import ast.{untpd, tpd}
Expand Down Expand Up @@ -64,6 +66,8 @@ trait ImportSuggestions:
else !root.name.is(FlatName)
&& !root.name.lastPart.contains('$')
&& root.is(ModuleVal, butNot = JavaDefined)
// The implicits in `scalajs.js.|` are implementation details and shouldn't be suggested
&& !(root.name == nme.raw.BAR && ctx.settings.scalajs.value && root == JSDefinitions.jsdefn.PseudoUnionModule)
}

def nestedRoots(site: Type)(using Context): List[Symbol] =
Expand Down
9 changes: 8 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dotty.tools
package dotc
package typer

import backend.sjs.JSDefinitions
import core._
import ast._
import Trees._
Expand Down Expand Up @@ -219,7 +220,13 @@ class Typer extends Namer
denot = denot.filterWithPredicate { mbr =>
mbr.matchesImportBound(if mbr.symbol.is(Given) then imp.givenBound else imp.wildcardBound)
}
if reallyExists(denot) then
def isScalaJsPseudoUnion =
denot.name == tpnme.raw.BAR && ctx.settings.scalajs.value && denot.symbol == JSDefinitions.jsdefn.PseudoUnionClass
// Just like Scala2Unpickler reinterprets Scala.js pseudo-unions
// as real union types, we want references to `A | B` in sources
// to be typed as a real union even if `js.|` has been imported,
// so we ignore that import.
if reallyExists(denot) && !isScalaJsPseudoUnion then
if unimported.isEmpty || !unimported.contains(pre.termSymbol) then
return pre.select(name, denot)
case _ =>
Expand Down
8 changes: 8 additions & 0 deletions library-js/src/scala/scalajs/js/internal/UnitOps.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package scala.scalajs.js.internal

import scala.scalajs.js

/** Under -scalajs, this object is part of the implicit scope of `scala.Unit` */
object UnitOps:
implicit def unitOrOps[A](x: A | Unit): js.UndefOrOps[A] =
new js.UndefOrOps(x)
8 changes: 6 additions & 2 deletions project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ object Build {
settings(
libraryDependencies +=
("org.scala-js" %% "scalajs-library" % scalaJSVersion).withDottyCompat(scalaVersion.value),
Compile / unmanagedSourceDirectories :=
Compile / unmanagedSourceDirectories ++=
(`scala3-library-bootstrapped` / Compile / unmanagedSourceDirectories).value,

// Configure the source maps to point to GitHub for releases
Expand Down Expand Up @@ -1105,9 +1105,13 @@ object Build {
-- "ObjectTest.scala" // compile errors caused by #9588
-- "StackTraceTest.scala" // would require `npm install source-map-support`
-- "UnionTypeTest.scala" // requires the Scala 2 macro defined in Typechecking*.scala
-- "PromiseMock.scala" // TODO: Enable once we use a Scala.js with https://github.com/scala-js/scala-js/pull/4451 in
// and remove copy in tests/sjs-junit
)).get

++ (dir / "js/src/test/require-2.12" ** "*.scala").get
++ (dir / "js/src/test/require-2.12" ** (("*.scala": FileFilter)
-- "JSOptionalTest212.scala" // TODO: Enable once we use a Scala.js with https://github.com/scala-js/scala-js/pull/4451 in
)).get
++ (dir / "js/src/test/require-sam" ** "*.scala").get
++ (dir / "js/src/test/scala-new-collections" ** "*.scala").get
)
Expand Down
16 changes: 16 additions & 0 deletions sbt-dotty/sbt-test/scala2-compat/erasure-scalajs/build.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
lazy val scala2Lib = project.in(file("scala2Lib"))
.enablePlugins(ScalaJSPlugin)
.settings(
// TODO: switch to 2.13.5 once we've upgrade sbt-scalajs to 1.5.0
scalaVersion := "2.13.4"
)

lazy val dottyApp = project.in(file("dottyApp"))
.dependsOn(scala2Lib)
.enablePlugins(ScalaJSPlugin)
.settings(
scalaVersion := sys.props("plugin.scalaVersion"),

scalaJSUseMainModuleInitializer := true,
scalaJSLinkerConfig ~= (_.withCheckIR(true)),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
object Main {
def main(args: Array[String]): Unit = {
val a = new scala2Lib.A
assert(a.foo(1) == "1")
assert(a.foo("") == "1")
assert(a.foo(Array(1)) == "2")

val b = new scala2Lib.B
assert(b.foo(1) == "1")
assert(b.foo("") == "1")
assert(b.foo(Array(1)) == "2")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
addSbtPlugin("ch.epfl.lamp" % "sbt-dotty" % sys.props("plugin.version"))
addSbtPlugin("org.scala-js" % "sbt-scalajs" % sys.props("plugin.scalaJSVersion"))
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Keep synchronized with dottyApp/Api.scala
package scala2Lib

import scala.scalajs.js
import js.|

class A {
def foo(x: Int | String): String = "1"
def foo(x: Array[Int]): String = "2"
}

class B extends js.Object {
def foo(x: Int | String): String = "1"
def foo(x: Array[Int]): String = "2"
}
1 change: 1 addition & 0 deletions sbt-dotty/sbt-test/scala2-compat/erasure-scalajs/test
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
> dottyApp/run
9 changes: 2 additions & 7 deletions sbt-dotty/sbt-test/scala2-compat/erasure/build.sbt
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
lazy val scala2Lib = project.in(file("scala2Lib"))
.settings(
scalaVersion := "2.13.2"
scalaVersion := "2.13.5"
)

lazy val dottyApp = project.in(file("dottyApp"))
.dependsOn(scala2Lib)
.settings(
scalaVersion := sys.props("plugin.scalaVersion"),
// https://github.com/sbt/sbt/issues/5369
projectDependencies := {
projectDependencies.value.map(_.withDottyCompat(scalaVersion.value))
}
scalaVersion := sys.props("plugin.scalaVersion")
)

Loading