Skip to content

Commit 1e13a92

Browse files
committed
Fix #9011: Make single enum values inherit from Product
1 parent 157ad25 commit 1e13a92

File tree

4 files changed

+73
-3
lines changed

4 files changed

+73
-3
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ object DesugarEnums {
124124

125125
/** A creation method for a value of enum type `E`, which is defined as follows:
126126
*
127-
* private def $new(_$ordinal: Int, $name: String) = new E {
127+
* private def $new(_$ordinal: Int, $name: String) = new E with EnumValue {
128128
* def $ordinal = $tag
129129
* override def toString = $name
130130
* $values.register(this)
@@ -135,7 +135,7 @@ object DesugarEnums {
135135
val toStringDef = toStringMeth(Ident(nme.nameDollar))
136136
val creator = New(Template(
137137
constr = emptyConstructor,
138-
parents = enumClassRef :: Nil,
138+
parents = enumClassRef :: scalaDot(str.EnumValue.toTypeName) :: Nil,
139139
derived = Nil,
140140
self = EmptyValDef,
141141
body = List(ordinalDef, toStringDef) ++ registerCall
@@ -286,7 +286,9 @@ object DesugarEnums {
286286
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
287287
val ordinalDef = ordinalMethLit(tag)
288288
val toStringDef = toStringMethLit(name.toString)
289-
val impl1 = cpy.Template(impl)(body = List(ordinalDef, toStringDef) ++ registerCall)
289+
val impl1 = cpy.Template(impl)(
290+
parents = impl.parents :+ scalaDot(str.EnumValue.toTypeName),
291+
body = List(ordinalDef, toStringDef) ++ registerCall)
290292
.withAttachment(ExtendsSingletonMirror, ())
291293
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
292294
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ object StdNames {
3333

3434
final val MODULE_INSTANCE_FIELD = "MODULE$"
3535

36+
final val EnumValue = "EnumValue"
3637
final val Function = "Function"
3738
final val ErasedFunction = "ErasedFunction"
3839
final val ContextFunction = "ContextFunction"

library/src/scala/EnumValue.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package scala
2+
3+
trait EnumValue extends Product:
4+
override def canEqual(that: Any) = true
5+
override def productArity: Int = 0
6+
override def productPrefix: String = toString
7+
override def productElement(n: Int): Any =
8+
throw IndexOutOfBoundsException(n.toString())
9+
override def productElementName(n: Int): String =
10+
throw IndexOutOfBoundsException(n.toString())

tests/run/i9011.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
enum Opt[+T] derives Eq:
2+
case Sm(t: T)
3+
case Nn
4+
5+
import scala.deriving._
6+
import scala.compiletime.{erasedValue, summonInline}
7+
8+
trait Eq[T] {
9+
def eqv(x: T, y: T): Boolean
10+
}
11+
12+
object Eq {
13+
given Eq[Int] {
14+
def eqv(x: Int, y: Int) = x == y
15+
}
16+
17+
inline def summonAll[T <: Tuple]: List[Eq[_]] = inline erasedValue[T] match {
18+
case _: Unit => Nil
19+
case _: (t *: ts) => summonInline[Eq[t]] :: summonAll[ts]
20+
}
21+
22+
def check(elem: Eq[_])(x: Any, y: Any): Boolean =
23+
elem.asInstanceOf[Eq[Any]].eqv(x, y)
24+
25+
def iterator[T](p: T) = p.asInstanceOf[Product].productIterator
26+
27+
def eqSum[T](s: Mirror.SumOf[T], elems: List[Eq[_]]): Eq[T] =
28+
new Eq[T] {
29+
def eqv(x: T, y: T): Boolean = {
30+
val ordx = s.ordinal(x)
31+
(s.ordinal(y) == ordx) && check(elems(ordx))(x, y)
32+
}
33+
}
34+
35+
def eqProduct[T](p: Mirror.ProductOf[T], elems: List[Eq[_]]): Eq[T] =
36+
new Eq[T] {
37+
def eqv(x: T, y: T): Boolean =
38+
iterator(x).zip(iterator(y)).zip(elems.iterator).forall {
39+
case ((x, y), elem) => check(elem)(x, y)
40+
}
41+
}
42+
43+
inline given derived[T](using m: Mirror.Of[T]) as Eq[T] = {
44+
val elemInstances = summonAll[m.MirroredElemTypes]
45+
inline m match {
46+
case s: Mirror.SumOf[T] => eqSum(s, elemInstances)
47+
case p: Mirror.ProductOf[T] => eqProduct(p, elemInstances)
48+
}
49+
}
50+
}
51+
52+
object Test extends App {
53+
import Opt._
54+
val eqoi = summon[Eq[Opt[Int]]]
55+
assert(eqoi.eqv(Sm(23), Sm(23)))
56+
assert(eqoi.eqv(Nn, Nn))
57+
}

0 commit comments

Comments
 (0)