Skip to content

Commit d54b1b0

Browse files
committed
Abstract over paramVariance instead of paramVarianceSign
1 parent 29082e1 commit d54b1b0

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

compiler/src/dotty/tools/dotc/core/ParamInfo.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package dotty.tools.dotc.core
33
import Names.Name
44
import Contexts.Context
55
import Types.Type
6+
import Variances.{Variance, varianceToInt}
67

78
/** A common super trait of Symbol and LambdaParam.
89
* Used to capture the attributes of type parameters which can be implemented as either.
@@ -35,7 +36,13 @@ trait ParamInfo {
3536
def paramInfoOrCompleter(implicit ctx: Context): Type
3637

3738
/** The variance of the type parameter */
38-
def paramVarianceSign(implicit ctx: Context): Int
39+
def paramVariance(implicit ctx: Context): Variance
40+
41+
/** The variance of the type parameter, as a number -1, 0, +1.
42+
* Bivariant is mapped to 1, i.e. it is treated like Covariant.
43+
*/
44+
final def paramVarianceSign(implicit ctx: Context): Int =
45+
varianceToInt(paramVariance)
3946

4047
/** A type that refers to the parameter */
4148
def paramRef(implicit ctx: Context): Type

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import ast.tpd
2222
import tpd.{Tree, TreeProvider, TreeOps}
2323
import ast.TreeTypeMap
2424
import Constants.Constant
25+
import Variances.{Variance, varianceFromInt}
2526
import reporting.diagnostic.Message
2627
import collection.mutable
2728
import io.AbstractFile
@@ -699,7 +700,7 @@ object Symbols {
699700
def paramInfo(implicit ctx: Context): Type = denot.info
700701
def paramInfoAsSeenFrom(pre: Type)(implicit ctx: Context): Type = pre.memberInfo(this)
701702
def paramInfoOrCompleter(implicit ctx: Context): Type = denot.infoOrCompleter
702-
def paramVarianceSign(implicit ctx: Context): Int = denot.variance
703+
def paramVariance(implicit ctx: Context): Variance = varianceFromInt(denot.variance)
703704
def paramRef(implicit ctx: Context): TypeRef = denot.typeRef
704705

705706
// -------- Printing --------------------------------------------------------

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3458,14 +3458,14 @@ object Types {
34583458

34593459
private def setVariances(tparams: List[LambdaParam], vs: List[Variance]): Unit =
34603460
if tparams.nonEmpty then
3461-
tparams.head.setVariance(vs.head)
3461+
tparams.head.givenVariance = vs.head
34623462
setVariances(tparams.tail, vs.tail)
34633463

34643464
val isVariant = variances.nonEmpty
34653465
if isVariant then setVariances(typeParams, variances)
34663466

34673467
def givenVariances =
3468-
if isVariant then typeParams.map(_.paramVariance)
3468+
if isVariant then typeParams.map(_.givenVariance)
34693469
else Nil
34703470

34713471
override def computeHash(bs: Binders): Int =
@@ -3479,7 +3479,7 @@ object Types {
34793479
&& isVariant == that.isVariant
34803480
&& (!isVariant
34813481
|| typeParams.corresponds(that.typeParams)((x, y) =>
3482-
x.paramVariance == y.paramVariance))
3482+
x.givenVariance == y.givenVariance))
34833483
&& {
34843484
val bs1 = new BinderPairs(this, that, bs)
34853485
paramInfos.equalElements(that.paramInfos, bs1) &&
@@ -3620,12 +3620,17 @@ object Types {
36203620
def paramInfo(implicit ctx: Context): tl.PInfo = tl.paramInfos(n)
36213621
def paramInfoAsSeenFrom(pre: Type)(implicit ctx: Context): tl.PInfo = paramInfo
36223622
def paramInfoOrCompleter(implicit ctx: Context): Type = paramInfo
3623-
def paramVarianceSign(implicit ctx: Context): Int = tl.paramNames(n).variance
36243623
def paramRef(implicit ctx: Context): Type = tl.paramRefs(n)
36253624

36263625
private var myVariance: FlagSet = UndefinedFlags
3627-
def setVariance(v: Variance): Unit = myVariance = v
3628-
def paramVariance: Variance =
3626+
def givenVariance_=(v: Variance): Unit =
3627+
assert(myVariance == UndefinedFlags)
3628+
myVariance = v
3629+
def givenVariance: Variance =
3630+
assert(myVariance != UndefinedFlags)
3631+
myVariance
3632+
3633+
def paramVariance(implicit ctx: Context): Variance =
36293634
if myVariance == UndefinedFlags then
36303635
myVariance = varianceFromInt(tl.paramNames(n).variance)
36313636
myVariance

0 commit comments

Comments
 (0)