Skip to content

Commit 62c2a1e

Browse files
authored
Merge pull request #1958 from dotty-staging/add-enum
Add "enum" construct
2 parents 2556c83 + 30d8d87 commit 62c2a1e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+834
-184
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 89 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import reporting.diagnostic.messages._
1414

1515
object desugar {
1616
import untpd._
17+
import DesugarEnums._
1718

1819
/** Tags a .withFilter call generated by desugaring a for expression.
1920
* Such calls can alternatively be rewritten to use filter.
@@ -263,7 +264,9 @@ object desugar {
263264
val className = checkNotReservedName(cdef).asTypeName
264265
val impl @ Template(constr0, parents, self, _) = cdef.rhs
265266
val mods = cdef.mods
266-
val companionMods = mods.withFlags((mods.flags & AccessFlags).toCommonFlags)
267+
val companionMods = mods
268+
.withFlags((mods.flags & AccessFlags).toCommonFlags)
269+
.withMods(mods.mods.filter(!_.isInstanceOf[Mod.EnumCase]))
267270

268271
val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true) match {
269272
case meth: DefDef => (meth, Nil)
@@ -288,17 +291,31 @@ object desugar {
288291
}
289292

290293
val isCaseClass = mods.is(Case) && !mods.is(Module)
294+
val isEnum = mods.hasMod[Mod.Enum]
295+
val isEnumCase = isLegalEnumCase(cdef)
291296
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
292297
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
293298

294-
val constrTparams = constr1.tparams map toDefParam
299+
lazy val reconstitutedTypeParams = reconstitutedEnumTypeParams(cdef.pos.startPos)
300+
301+
val originalTparams =
302+
if (isEnumCase && parents.isEmpty) {
303+
if (constr1.tparams.nonEmpty) {
304+
if (reconstitutedTypeParams.nonEmpty)
305+
ctx.error(em"case with type parameters needs extends clause", constr1.tparams.head.pos)
306+
constr1.tparams
307+
}
308+
else reconstitutedTypeParams
309+
}
310+
else constr1.tparams
311+
val originalVparamss = constr1.vparamss
312+
val constrTparams = originalTparams.map(toDefParam)
295313
val constrVparamss =
296-
if (constr1.vparamss.isEmpty) { // ensure parameter list is non-empty
297-
if (isCaseClass)
298-
ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
314+
if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
315+
if (isCaseClass) ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
299316
ListOfNil
300317
}
301-
else constr1.vparamss.nestedMap(toDefParam)
318+
else originalVparamss.nestedMap(toDefParam)
302319
val constr = cpy.DefDef(constr1)(tparams = constrTparams, vparamss = constrVparamss)
303320

304321
// Add constructor type parameters and evidence implicit parameters
@@ -312,21 +329,24 @@ object desugar {
312329
stat
313330
}
314331

315-
val derivedTparams = constrTparams map derivedTypeParam
332+
val derivedTparams =
333+
if (isEnumCase) constrTparams else constrTparams map derivedTypeParam
316334
val derivedVparamss = constrVparamss nestedMap derivedTermParam
317335
val arity = constrVparamss.head.length
318336

319-
var classTycon: Tree = EmptyTree
337+
val classTycon: Tree = new TypeRefTree // watching is set at end of method
320338

321-
// a reference to the class type, with all parameters given.
322-
val classTypeRef/*: Tree*/ = {
323-
// -language:keepUnions difference: classTypeRef needs type annotation, otherwise
324-
// infers Ident | AppliedTypeTree, which
325-
// renders the :\ in companions below untypable.
326-
classTycon = (new TypeRefTree) withPos cdef.pos.startPos // watching is set at end of method
327-
val tparams = impl.constr.tparams
328-
if (tparams.isEmpty) classTycon else AppliedTypeTree(classTycon, tparams map refOfDef)
329-
}
339+
def appliedRef(tycon: Tree) =
340+
(if (constrTparams.isEmpty) tycon
341+
else AppliedTypeTree(tycon, constrTparams map refOfDef))
342+
.withPos(cdef.pos.startPos)
343+
344+
// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
345+
val classTypeRef = appliedRef(classTycon)
346+
// a reference to `enumClass`, with type parameters coming from the constructor
347+
lazy val enumClassTypeRef =
348+
if (reconstitutedTypeParams.isEmpty) enumClassRef
349+
else appliedRef(enumClassRef)
330350

331351
// new C[Ts](paramss)
332352
lazy val creatorExpr = New(classTypeRef, constrVparamss nestedMap refOfDef)
@@ -374,7 +394,9 @@ object desugar {
374394
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
375395
.withMods(synthetic) :: Nil
376396
}
377-
copyMeths ::: productElemMeths.toList
397+
398+
val enumTagMeths = if (isEnumCase) enumTagMeth(CaseKind.Class)._1 :: Nil else Nil
399+
copyMeths ::: enumTagMeths ::: productElemMeths.toList
378400
}
379401
else Nil
380402

@@ -387,8 +409,12 @@ object desugar {
387409

388410
// Case classes and case objects get a ProductN parent
389411
var parents1 = parents
412+
if (isEnumCase && parents.isEmpty)
413+
parents1 = enumClassTypeRef :: Nil
390414
if (mods.is(Case) && arity <= Definitions.MaxTupleArity)
391-
parents1 = parents1 :+ productConstr(arity)
415+
parents1 = parents1 :+ productConstr(arity) // TODO: This also adds Product0 to caes objects. Do we want that?
416+
if (isEnum)
417+
parents1 = parents1 :+ ref(defn.EnumType)
392418

393419
// The thicket which is the desugared version of the companion object
394420
// synthetic object C extends parentTpt { defs }
@@ -410,17 +436,26 @@ object desugar {
410436
// For all other classes, the parent is AnyRef.
411437
val companions =
412438
if (isCaseClass) {
439+
// The return type of the `apply` method
440+
val applyResultTpt =
441+
if (isEnumCase)
442+
if (parents.isEmpty) enumClassTypeRef
443+
else parents.reduceLeft(AndTypeTree)
444+
else TypeTree()
445+
413446
val parent =
414447
if (constrTparams.nonEmpty ||
415448
constrVparamss.length > 1 ||
416449
mods.is(Abstract) ||
417450
constr.mods.is(Private)) anyRef
451+
else
418452
// todo: also use anyRef if constructor has a dependent method type (or rule that out)!
419-
else (constrVparamss :\ classTypeRef) ((vparams, restpe) => Function(vparams map (_.tpt), restpe))
453+
(constrVparamss :\ (if (isEnumCase) applyResultTpt else classTypeRef)) (
454+
(vparams, restpe) => Function(vparams map (_.tpt), restpe))
420455
val applyMeths =
421456
if (mods is Abstract) Nil
422457
else
423-
DefDef(nme.apply, derivedTparams, derivedVparamss, TypeTree(), creatorExpr)
458+
DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, creatorExpr)
424459
.withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: Nil
425460
val unapplyMeth = {
426461
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
@@ -464,15 +499,15 @@ object desugar {
464499
else cpy.ValDef(self)(tpt = selfType).withMods(self.mods | SelfName)
465500
}
466501

467-
val cdef1 = {
468-
val originalTparams = constr1.tparams.toIterator
469-
val originalVparams = constr1.vparamss.toIterator.flatten
470-
val tparamAccessors = derivedTparams.map(_.withMods(originalTparams.next.mods))
502+
val cdef1 = addEnumFlags {
503+
val originalTparamsIt = originalTparams.toIterator
504+
val originalVparamsIt = originalVparamss.toIterator.flatten
505+
val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt.next.mods))
471506
val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
472507
val vparamAccessors = derivedVparamss match {
473508
case first :: rest =>
474-
first.map(_.withMods(originalVparams.next.mods | caseAccessor)) ++
475-
rest.flatten.map(_.withMods(originalVparams.next.mods))
509+
first.map(_.withMods(originalVparamsIt.next.mods | caseAccessor)) ++
510+
rest.flatten.map(_.withMods(originalVparamsIt.next.mods))
476511
case _ =>
477512
Nil
478513
}
@@ -503,23 +538,26 @@ object desugar {
503538
*/
504539
def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = {
505540
val moduleName = checkNotReservedName(mdef).asTermName
506-
val tmpl = mdef.impl
541+
val impl = mdef.impl
507542
val mods = mdef.mods
543+
lazy val isEnumCase = isLegalEnumCase(mdef)
508544
if (mods is Package)
509-
PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, tmpl).withMods(mods &~ Package) :: Nil)
545+
PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, impl).withMods(mods &~ Package) :: Nil)
546+
else if (isEnumCase)
547+
expandEnumModule(moduleName, impl, mods, mdef.pos)
510548
else {
511549
val clsName = moduleName.moduleClassName
512550
val clsRef = Ident(clsName)
513551
val modul = ValDef(moduleName, clsRef, New(clsRef, Nil))
514552
.withMods(mods | ModuleCreationFlags | mods.flags & AccessFlags)
515553
.withPos(mdef.pos)
516-
val ValDef(selfName, selfTpt, _) = tmpl.self
517-
val selfMods = tmpl.self.mods
518-
if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), tmpl.self.pos)
519-
val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), tmpl.self.rhs)
554+
val ValDef(selfName, selfTpt, _) = impl.self
555+
val selfMods = impl.self.mods
556+
if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), impl.self.pos)
557+
val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), impl.self.rhs)
520558
.withMods(selfMods)
521-
.withPos(tmpl.self.pos orElse tmpl.pos.startPos)
522-
val clsTmpl = cpy.Template(tmpl)(self = clsSelf, body = tmpl.body)
559+
.withPos(impl.self.pos orElse impl.pos.startPos)
560+
val clsTmpl = cpy.Template(impl)(self = clsSelf, body = impl.body)
523561
val cls = TypeDef(clsName, clsTmpl)
524562
.withMods(mods.toTypeFlags & RetainedModuleClassFlags | ModuleClassCreationFlags)
525563
Thicket(modul, classDef(cls).withPos(mdef.pos))
@@ -542,11 +580,23 @@ object desugar {
542580
/** val p1, ..., pN: T = E
543581
* ==>
544582
* makePatDef[[val p1: T1 = E]]; ...; makePatDef[[val pN: TN = E]]
583+
*
584+
* case e1, ..., eN
585+
* ==>
586+
* expandSimpleEnumCase([case e1]); ...; expandSimpleEnumCase([case eN])
545587
*/
546-
def patDef(pdef: PatDef)(implicit ctx: Context): Tree = {
588+
def patDef(pdef: PatDef)(implicit ctx: Context): Tree = flatTree {
547589
val PatDef(mods, pats, tpt, rhs) = pdef
548-
val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt))
549-
flatTree(pats1 map (makePatDef(pdef, mods, _, rhs)))
590+
if (mods.hasMod[Mod.EnumCase] && enumCaseIsLegal(pdef))
591+
pats map {
592+
case id: Ident =>
593+
expandSimpleEnumCase(id.name.asTermName, mods,
594+
Position(pdef.pos.start, id.pos.end, id.pos.start))
595+
}
596+
else {
597+
val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt))
598+
pats1 map (makePatDef(pdef, mods, _, rhs))
599+
}
550600
}
551601

552602
/** If `pat` is a variable pattern,
@@ -923,7 +973,7 @@ object desugar {
923973
case (gen: GenFrom) :: (rest @ (GenFrom(_, _) :: _)) =>
924974
val cont = makeFor(mapName, flatMapName, rest, body)
925975
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont))
926-
case (enum @ GenFrom(pat, rhs)) :: (rest @ GenAlias(_, _) :: _) =>
976+
case (GenFrom(pat, rhs)) :: (rest @ GenAlias(_, _) :: _) =>
927977
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
928978
val pats = valeqs map { case GenAlias(pat, _) => pat }
929979
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
@@ -1024,7 +1074,6 @@ object desugar {
10241074
List(CaseDef(Ident(nme.DEFAULT_EXCEPTION_NAME), EmptyTree, Apply(handler, Ident(nme.DEFAULT_EXCEPTION_NAME)))),
10251075
finalizer)
10261076
}
1027-
10281077
}
10291078
}.withPos(tree.pos)
10301079

0 commit comments

Comments
 (0)