Skip to content

Commit ecfeaa7

Browse files
committed
Improve variance computation and printing
1 parent 0dd1aba commit ecfeaa7

File tree

7 files changed

+31
-26
lines changed

7 files changed

+31
-26
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ object DesugarEnums {
4949
val tparams = enumClass.typeParams
5050
def isGround(tp: Type) = tp.subst(tparams, tparams.map(_ => NoType)) eq tp
5151
val targs = tparams map { tparam =>
52-
if (tparam.variance > 0 && isGround(tparam.info.bounds.lo))
52+
if (tparam.is(Covariant) && isGround(tparam.info.bounds.lo))
5353
tparam.info.bounds.lo
54-
else if (tparam.variance < 0 && isGround(tparam.info.bounds.hi))
54+
else if (tparam.is(Contravariant) && isGround(tparam.info.bounds.hi))
5555
tparam.info.bounds.hi
5656
else {
5757
def problem =
58-
if (tparam.variance == 0) "is non variant"
58+
if (!tparam.isOneOf(VarianceFlags)) "is non variant"
5959
else "has bounds that depend on a type parameter in the same parameter list"
6060
errorType(i"""cannot determine type argument for enum parent $enumClass,
6161
|type parameter $tparam $problem""", ctx.source.atSpan(span))

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import dotty.tools.io.AbstractFile
1212
import Decorators.SymbolIteratorDecorator
1313
import ast._
1414
import Trees.Literal
15+
import Variances.Variance
1516
import annotation.tailrec
1617
import util.SimpleIdentityMap
1718
import util.Stats
@@ -1374,13 +1375,13 @@ object SymDenotations {
13741375
def namedType(implicit ctx: Context): NamedType =
13751376
if (isType) typeRef else termRef
13761377

1377-
/** The variance of this type parameter or type member as an Int, with
1378-
* +1 = Covariant, -1 = Contravariant, 0 = Nonvariant, or not a type parameter
1378+
/** The variance of this type parameter or type member as a subset of
1379+
* {Covariant, Contravariant}
13791380
*/
1380-
final def variance(implicit ctx: Context): Int =
1381-
if (this.is(Covariant)) 1
1382-
else if (this.is(Contravariant)) -1
1383-
else 0
1381+
final def variance(implicit ctx: Context): Variance =
1382+
if is(Covariant) then Covariant
1383+
else if is(Contravariant) then Contravariant
1384+
else EmptyFlags
13841385

13851386
/** The flags to be used for a type parameter owned by this symbol.
13861387
* Overridden by ClassDenotation.

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ object Symbols {
700700
def paramInfo(implicit ctx: Context): Type = denot.info
701701
def paramInfoAsSeenFrom(pre: Type)(implicit ctx: Context): Type = pre.memberInfo(this)
702702
def paramInfoOrCompleter(implicit ctx: Context): Type = denot.infoOrCompleter
703-
def paramVariance(implicit ctx: Context): Variance = varianceFromInt(denot.variance)
703+
def paramVariance(implicit ctx: Context): Variance = denot.variance
704704
def paramRef(implicit ctx: Context): TypeRef = denot.typeRef
705705

706706
// -------- Printing --------------------------------------------------------

compiler/src/dotty/tools/dotc/core/Variances.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ object Variances {
119119
case _ =>
120120
foldOver(x, t)
121121
}
122+
// Note: Normally, we'd need to repeat `traverse` until a fixpoint is reached.
123+
// But since recursive lambdas can only appear in bounds, and bound never have
124+
// structural variances, a single traversal is enough.
122125
traverse((), lam.resType)
123126

124127
/** Does variance `v1` conform to variance `v2`?
@@ -144,15 +147,12 @@ object Variances {
144147
if needsDetailedCheck then tparams1.corresponds(tparams2)(varianceConforms)
145148
else tparams1.hasSameLengthAs(tparams2)
146149

147-
def varianceString(sym: Symbol)(implicit ctx: Context): String =
148-
varianceString(sym.variance)
150+
def varianceSign(sym: Symbol)(implicit ctx: Context): String =
151+
varianceSign(sym.variance)
149152

150-
def varianceString(v: Variance): String =
151-
if (v is Covariant) "covariant"
152-
else if (v is Contravariant) "contravariant"
153-
else "invariant"
153+
def varianceSign(v: Variance): String = varianceSign(varianceToInt(v))
154154

155-
def varianceString(v: Int): String =
155+
def varianceSign(v: Int): String =
156156
if (v > 0) "+"
157157
else if (v < 0) "-"
158158
else ""

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import StdNames.nme
88
import ast.Trees._
99
import typer.Implicits._
1010
import typer.ImportInfo
11-
import Variances.{varianceString, varianceToInt}
11+
import Variances.varianceSign
1212
import util.SourcePosition
1313
import java.lang.Integer.toOctalString
1414
import config.Config.summarizeDepth
@@ -333,7 +333,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
333333
val names =
334334
if lam.isVariantLambda then
335335
lam.paramNames.lazyZip(lam.givenVariances).map((name, v) =>
336-
varianceString(varianceToInt(v)) + name)
336+
varianceSign(v) + name)
337337
else lam.paramNames
338338
(names.mkString("[", ", ", "]"), lam.resType)
339339
case _ =>
@@ -449,7 +449,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
449449

450450
private def dclTextWithInfo(sym: Symbol, info: Option[Type]): Text =
451451
(toTextFlags(sym) ~~ keyString(sym) ~~
452-
(varianceString(sym) ~ nameString(sym)) ~ toTextRHS(info)).close
452+
(varianceSign(sym.variance) ~ nameString(sym)) ~ toTextRHS(info)).close
453453

454454
def toText(sym: Symbol): Text =
455455
(kindString(sym) ~~ {

compiler/src/dotty/tools/dotc/transform/FullParameterization.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ trait FullParameterization {
9898
}
9999
val ctparams = if (abstractOverClass) clazz.typeParams else Nil
100100
val ctnames = ctparams.map(_.name)
101-
val ctvariances = ctparams.map(_.variance)
102101

103102
/** The method result type */
104103
def resultType(mapClassParams: Type => Type) = {

compiler/src/dotty/tools/dotc/typer/VarianceChecker.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ object VarianceChecker {
6565
checkType(bounds.lo)
6666
checkType(bounds.hi)
6767
end checkLambda
68+
69+
private def varianceLabel(v: Variance): String =
70+
if (v is Covariant) "covariant"
71+
else if (v is Contravariant) "contravariant"
72+
else "invariant"
6873
}
6974

7075
class VarianceChecker()(implicit ctx: Context) {
@@ -108,10 +113,10 @@ class VarianceChecker()(implicit ctx: Context) {
108113
if (relative == Bivariant) None
109114
else {
110115
val required = compose(relative, this.variance)
111-
def tvar_s = s"$tvar (${varianceString(tvar.flags)} ${tvar.showLocated})"
116+
def tvar_s = s"$tvar (${varianceLabel(tvar.flags)} ${tvar.showLocated})"
112117
def base_s = s"$base in ${base.owner}" + (if (base.owner.isClass) "" else " in " + base.owner.enclosingClass)
113-
ctx.log(s"verifying $tvar_s is ${varianceString(required)} at $base_s")
114-
ctx.log(s"relative variance: ${varianceString(relative)}")
118+
ctx.log(s"verifying $tvar_s is ${varianceLabel(required)} at $base_s")
119+
ctx.log(s"relative variance: ${varianceLabel(relative)}")
115120
ctx.log(s"current variance: ${this.variance}")
116121
ctx.log(s"owner chain: ${base.ownersIterator.toList}")
117122
if (tvar.isOneOf(required)) None
@@ -129,7 +134,7 @@ class VarianceChecker()(implicit ctx: Context) {
129134
else tp.normalized match {
130135
case tp: TypeRef =>
131136
val sym = tp.symbol
132-
if (sym.variance != 0 && base.isContainedIn(sym.owner)) checkVarianceOfSymbol(sym)
137+
if (sym.isOneOf(VarianceFlags) && base.isContainedIn(sym.owner)) checkVarianceOfSymbol(sym)
133138
else sym.info match {
134139
case MatchAlias(_) => foldOver(status, tp)
135140
case TypeAlias(alias) => this(status, alias)
@@ -160,7 +165,7 @@ class VarianceChecker()(implicit ctx: Context) {
160165
private object Traverser extends TreeTraverser {
161166
def checkVariance(sym: Symbol, pos: SourcePosition) = Validator.validateDefinition(sym) match {
162167
case Some(VarianceError(tvar, required)) =>
163-
def msg = i"${varianceString(tvar.flags)} $tvar occurs in ${varianceString(required)} position in type ${sym.info} of $sym"
168+
def msg = i"${varianceLabel(tvar.flags)} $tvar occurs in ${varianceLabel(required)} position in type ${sym.info} of $sym"
164169
if (ctx.scala2CompatMode &&
165170
(sym.owner.isConstructor || sym.ownersIterator.exists(_.isAllOf(ProtectedLocal))))
166171
ctx.migrationWarning(

0 commit comments

Comments
 (0)