@@ -14,6 +14,7 @@ import reporting.diagnostic.messages._
14
14
15
15
object desugar {
16
16
import untpd ._
17
+ import DesugarEnums ._
17
18
18
19
/** Tags a .withFilter call generated by desugaring a for expression.
19
20
* Such calls can alternatively be rewritten to use filter.
@@ -263,7 +264,9 @@ object desugar {
263
264
val className = checkNotReservedName(cdef).asTypeName
264
265
val impl @ Template (constr0, parents, self, _) = cdef.rhs
265
266
val mods = cdef.mods
266
- val companionMods = mods.withFlags((mods.flags & AccessFlags ).toCommonFlags)
267
+ val companionMods = mods
268
+ .withFlags((mods.flags & AccessFlags ).toCommonFlags)
269
+ .withMods(mods.mods.filter(! _.isInstanceOf [Mod .EnumCase ]))
267
270
268
271
val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true ) match {
269
272
case meth : DefDef => (meth, Nil )
@@ -288,17 +291,31 @@ object desugar {
288
291
}
289
292
290
293
val isCaseClass = mods.is(Case ) && ! mods.is(Module )
294
+ val isEnum = mods.hasMod[Mod .Enum ]
295
+ val isEnumCase = isLegalEnumCase(cdef)
291
296
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
292
297
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
293
298
294
- val constrTparams = constr1.tparams map toDefParam
299
+ lazy val reconstitutedTypeParams = reconstitutedEnumTypeParams(cdef.pos.startPos)
300
+
301
+ val originalTparams =
302
+ if (isEnumCase && parents.isEmpty) {
303
+ if (constr1.tparams.nonEmpty) {
304
+ if (reconstitutedTypeParams.nonEmpty)
305
+ ctx.error(em " case with type parameters needs extends clause " , constr1.tparams.head.pos)
306
+ constr1.tparams
307
+ }
308
+ else reconstitutedTypeParams
309
+ }
310
+ else constr1.tparams
311
+ val originalVparamss = constr1.vparamss
312
+ val constrTparams = originalTparams.map(toDefParam)
295
313
val constrVparamss =
296
- if (constr1.vparamss.isEmpty) { // ensure parameter list is non-empty
297
- if (isCaseClass)
298
- ctx.error(CaseClassMissingParamList (cdef), cdef.namePos)
314
+ if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
315
+ if (isCaseClass) ctx.error(CaseClassMissingParamList (cdef), cdef.namePos)
299
316
ListOfNil
300
317
}
301
- else constr1.vparamss .nestedMap(toDefParam)
318
+ else originalVparamss .nestedMap(toDefParam)
302
319
val constr = cpy.DefDef (constr1)(tparams = constrTparams, vparamss = constrVparamss)
303
320
304
321
// Add constructor type parameters and evidence implicit parameters
@@ -312,21 +329,24 @@ object desugar {
312
329
stat
313
330
}
314
331
315
- val derivedTparams = constrTparams map derivedTypeParam
332
+ val derivedTparams =
333
+ if (isEnumCase) constrTparams else constrTparams map derivedTypeParam
316
334
val derivedVparamss = constrVparamss nestedMap derivedTermParam
317
335
val arity = constrVparamss.head.length
318
336
319
- var classTycon : Tree = EmptyTree
337
+ val classTycon : Tree = new TypeRefTree // watching is set at end of method
320
338
321
- // a reference to the class type, with all parameters given.
322
- val classTypeRef /* : Tree*/ = {
323
- // -language:keepUnions difference: classTypeRef needs type annotation, otherwise
324
- // infers Ident | AppliedTypeTree, which
325
- // renders the :\ in companions below untypable.
326
- classTycon = (new TypeRefTree ) withPos cdef.pos.startPos // watching is set at end of method
327
- val tparams = impl.constr.tparams
328
- if (tparams.isEmpty) classTycon else AppliedTypeTree (classTycon, tparams map refOfDef)
329
- }
339
+ def appliedRef (tycon : Tree ) =
340
+ (if (constrTparams.isEmpty) tycon
341
+ else AppliedTypeTree (tycon, constrTparams map refOfDef))
342
+ .withPos(cdef.pos.startPos)
343
+
344
+ // a reference to the class type bound by `cdef`, with type parameters coming from the constructor
345
+ val classTypeRef = appliedRef(classTycon)
346
+ // a reference to `enumClass`, with type parameters coming from the constructor
347
+ lazy val enumClassTypeRef =
348
+ if (reconstitutedTypeParams.isEmpty) enumClassRef
349
+ else appliedRef(enumClassRef)
330
350
331
351
// new C[Ts](paramss)
332
352
lazy val creatorExpr = New (classTypeRef, constrVparamss nestedMap refOfDef)
@@ -374,7 +394,9 @@ object desugar {
374
394
DefDef (nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree (), creatorExpr)
375
395
.withMods(synthetic) :: Nil
376
396
}
377
- copyMeths ::: productElemMeths.toList
397
+
398
+ val enumTagMeths = if (isEnumCase) enumTagMeth(CaseKind .Class )._1 :: Nil else Nil
399
+ copyMeths ::: enumTagMeths ::: productElemMeths.toList
378
400
}
379
401
else Nil
380
402
@@ -387,8 +409,12 @@ object desugar {
387
409
388
410
// Case classes and case objects get a ProductN parent
389
411
var parents1 = parents
412
+ if (isEnumCase && parents.isEmpty)
413
+ parents1 = enumClassTypeRef :: Nil
390
414
if (mods.is(Case ) && arity <= Definitions .MaxTupleArity )
391
- parents1 = parents1 :+ productConstr(arity)
415
+ parents1 = parents1 :+ productConstr(arity) // TODO: This also adds Product0 to caes objects. Do we want that?
416
+ if (isEnum)
417
+ parents1 = parents1 :+ ref(defn.EnumType )
392
418
393
419
// The thicket which is the desugared version of the companion object
394
420
// synthetic object C extends parentTpt { defs }
@@ -410,17 +436,26 @@ object desugar {
410
436
// For all other classes, the parent is AnyRef.
411
437
val companions =
412
438
if (isCaseClass) {
439
+ // The return type of the `apply` method
440
+ val applyResultTpt =
441
+ if (isEnumCase)
442
+ if (parents.isEmpty) enumClassTypeRef
443
+ else parents.reduceLeft(AndTypeTree )
444
+ else TypeTree ()
445
+
413
446
val parent =
414
447
if (constrTparams.nonEmpty ||
415
448
constrVparamss.length > 1 ||
416
449
mods.is(Abstract ) ||
417
450
constr.mods.is(Private )) anyRef
451
+ else
418
452
// todo: also use anyRef if constructor has a dependent method type (or rule that out)!
419
- else (constrVparamss :\ classTypeRef) ((vparams, restpe) => Function (vparams map (_.tpt), restpe))
453
+ (constrVparamss :\ (if (isEnumCase) applyResultTpt else classTypeRef)) (
454
+ (vparams, restpe) => Function (vparams map (_.tpt), restpe))
420
455
val applyMeths =
421
456
if (mods is Abstract ) Nil
422
457
else
423
- DefDef (nme.apply, derivedTparams, derivedVparamss, TypeTree () , creatorExpr)
458
+ DefDef (nme.apply, derivedTparams, derivedVparamss, applyResultTpt , creatorExpr)
424
459
.withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized )) :: Nil
425
460
val unapplyMeth = {
426
461
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
@@ -464,15 +499,15 @@ object desugar {
464
499
else cpy.ValDef (self)(tpt = selfType).withMods(self.mods | SelfName )
465
500
}
466
501
467
- val cdef1 = {
468
- val originalTparams = constr1.tparams .toIterator
469
- val originalVparams = constr1.vparamss .toIterator.flatten
470
- val tparamAccessors = derivedTparams.map(_.withMods(originalTparams .next.mods))
502
+ val cdef1 = addEnumFlags {
503
+ val originalTparamsIt = originalTparams .toIterator
504
+ val originalVparamsIt = originalVparamss .toIterator.flatten
505
+ val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt .next.mods))
471
506
val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
472
507
val vparamAccessors = derivedVparamss match {
473
508
case first :: rest =>
474
- first.map(_.withMods(originalVparams .next.mods | caseAccessor)) ++
475
- rest.flatten.map(_.withMods(originalVparams .next.mods))
509
+ first.map(_.withMods(originalVparamsIt .next.mods | caseAccessor)) ++
510
+ rest.flatten.map(_.withMods(originalVparamsIt .next.mods))
476
511
case _ =>
477
512
Nil
478
513
}
@@ -503,23 +538,26 @@ object desugar {
503
538
*/
504
539
def moduleDef (mdef : ModuleDef )(implicit ctx : Context ): Tree = {
505
540
val moduleName = checkNotReservedName(mdef).asTermName
506
- val tmpl = mdef.impl
541
+ val impl = mdef.impl
507
542
val mods = mdef.mods
543
+ lazy val isEnumCase = isLegalEnumCase(mdef)
508
544
if (mods is Package )
509
- PackageDef (Ident (moduleName), cpy.ModuleDef (mdef)(nme.PACKAGE , tmpl).withMods(mods &~ Package ) :: Nil )
545
+ PackageDef (Ident (moduleName), cpy.ModuleDef (mdef)(nme.PACKAGE , impl).withMods(mods &~ Package ) :: Nil )
546
+ else if (isEnumCase)
547
+ expandEnumModule(moduleName, impl, mods, mdef.pos)
510
548
else {
511
549
val clsName = moduleName.moduleClassName
512
550
val clsRef = Ident (clsName)
513
551
val modul = ValDef (moduleName, clsRef, New (clsRef, Nil ))
514
552
.withMods(mods | ModuleCreationFlags | mods.flags & AccessFlags )
515
553
.withPos(mdef.pos)
516
- val ValDef (selfName, selfTpt, _) = tmpl .self
517
- val selfMods = tmpl .self.mods
518
- if (! selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType (mdef), tmpl .self.pos)
519
- val clsSelf = ValDef (selfName, SingletonTypeTree (Ident (moduleName)), tmpl .self.rhs)
554
+ val ValDef (selfName, selfTpt, _) = impl .self
555
+ val selfMods = impl .self.mods
556
+ if (! selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType (mdef), impl .self.pos)
557
+ val clsSelf = ValDef (selfName, SingletonTypeTree (Ident (moduleName)), impl .self.rhs)
520
558
.withMods(selfMods)
521
- .withPos(tmpl .self.pos orElse tmpl .pos.startPos)
522
- val clsTmpl = cpy.Template (tmpl )(self = clsSelf, body = tmpl .body)
559
+ .withPos(impl .self.pos orElse impl .pos.startPos)
560
+ val clsTmpl = cpy.Template (impl )(self = clsSelf, body = impl .body)
523
561
val cls = TypeDef (clsName, clsTmpl)
524
562
.withMods(mods.toTypeFlags & RetainedModuleClassFlags | ModuleClassCreationFlags )
525
563
Thicket (modul, classDef(cls).withPos(mdef.pos))
@@ -542,11 +580,23 @@ object desugar {
542
580
/** val p1, ..., pN: T = E
543
581
* ==>
544
582
* makePatDef[[val p1: T1 = E ]]; ...; makePatDef[[val pN: TN = E ]]
583
+ *
584
+ * case e1, ..., eN
585
+ * ==>
586
+ * expandSimpleEnumCase([case e1]); ...; expandSimpleEnumCase([case eN])
545
587
*/
546
- def patDef (pdef : PatDef )(implicit ctx : Context ): Tree = {
588
+ def patDef (pdef : PatDef )(implicit ctx : Context ): Tree = flatTree {
547
589
val PatDef (mods, pats, tpt, rhs) = pdef
548
- val pats1 = if (tpt.isEmpty) pats else pats map (Typed (_, tpt))
549
- flatTree(pats1 map (makePatDef(pdef, mods, _, rhs)))
590
+ if (mods.hasMod[Mod .EnumCase ] && enumCaseIsLegal(pdef))
591
+ pats map {
592
+ case id : Ident =>
593
+ expandSimpleEnumCase(id.name.asTermName, mods,
594
+ Position (pdef.pos.start, id.pos.end, id.pos.start))
595
+ }
596
+ else {
597
+ val pats1 = if (tpt.isEmpty) pats else pats map (Typed (_, tpt))
598
+ pats1 map (makePatDef(pdef, mods, _, rhs))
599
+ }
550
600
}
551
601
552
602
/** If `pat` is a variable pattern,
@@ -923,7 +973,7 @@ object desugar {
923
973
case (gen : GenFrom ) :: (rest @ (GenFrom (_, _) :: _)) =>
924
974
val cont = makeFor(mapName, flatMapName, rest, body)
925
975
Apply (rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont))
926
- case (enum @ GenFrom (pat, rhs)) :: (rest @ GenAlias (_, _) :: _) =>
976
+ case (GenFrom (pat, rhs)) :: (rest @ GenAlias (_, _) :: _) =>
927
977
val (valeqs, rest1) = rest.span(_.isInstanceOf [GenAlias ])
928
978
val pats = valeqs map { case GenAlias (pat, _) => pat }
929
979
val rhss = valeqs map { case GenAlias (_, rhs) => rhs }
@@ -1024,7 +1074,6 @@ object desugar {
1024
1074
List (CaseDef (Ident (nme.DEFAULT_EXCEPTION_NAME ), EmptyTree , Apply (handler, Ident (nme.DEFAULT_EXCEPTION_NAME )))),
1025
1075
finalizer)
1026
1076
}
1027
-
1028
1077
}
1029
1078
}.withPos(tree.pos)
1030
1079
0 commit comments