Skip to content

Commit 18bd314

Browse files
committed
Fix fragile transformation of fromProduct when using @unroll
UnrollDefinitions assumed that the body of `fromProduct` had a specific shape which is no longer the case with the dependent case class support introduced in the previous commit. This caused compiler crashes for tests/run/unroll-multiple.scala and tests/run/unroll-caseclass-integration This commit fixes this by directly generating the correct fromProduct in SyntheticMembers. This should also prevent crashes in situations where code is injected into existing trees like the code coverage support or external compiler plugins.
1 parent 4aa59eb commit 18bd314

File tree

4 files changed

+67
-74
lines changed

4 files changed

+67
-74
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ class Definitions {
618618
@tu lazy val Int_== : Symbol = IntClass.requiredMethod(nme.EQ, List(IntType))
619619
@tu lazy val Int_>= : Symbol = IntClass.requiredMethod(nme.GE, List(IntType))
620620
@tu lazy val Int_<= : Symbol = IntClass.requiredMethod(nme.LE, List(IntType))
621+
@tu lazy val Int_> : Symbol = IntClass.requiredMethod(nme.GT, List(IntType))
621622
@tu lazy val LongType: TypeRef = valueTypeRef("scala.Long", java.lang.Long.TYPE, LongEnc, nme.specializedTypeNames.Long)
622623
def LongClass(using Context): ClassSymbol = LongType.symbol.asClass
623624
@tu lazy val Long_+ : Symbol = LongClass.requiredMethod(nme.PLUS, List(LongType))

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ object StdNames {
425425
val array_length : N = "array_length"
426426
val array_update : N = "array_update"
427427
val arraycopy: N = "arraycopy"
428+
val arity: N = "arity"
428429
val as: N = "as"
429430
val asTerm: N = "asTerm"
430431
val asModule: N = "asModule"

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -523,21 +523,47 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
523523
* ```
524524
* type MirroredMonoType = C[?]
525525
* ```
526+
*
527+
* However, if the last parameter is annotated `@unroll` then we generate:
528+
*
529+
* def fromProduct(x$0: Product): MirroredMonoType =
530+
* val arity = x$0.productArity
531+
* val a$1 = x$0.productElement(0).asInstanceOf[U]
532+
* val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem]
533+
* val c$1 = (
534+
* if arity > 2 then
535+
* x$0.productElement(2)
536+
* else
537+
* <default getter for the third parameter of C>
538+
* ).asInstanceOf[Seq[String]]
539+
* new C[U](a$1, b$1, c$1*)
526540
*/
527541
def fromProductBody(caseClass: Symbol, productParam: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
528542
val classRef = optInfo match
529543
case Some(info) => TypeRef(info.pre, caseClass)
530544
case _ => caseClass.typeRef
531-
val (newPrefix, constrMeth) =
545+
val (newPrefix, constrMeth, constrSyms) =
532546
val constr = TermRef(classRef, caseClass.primaryConstructor)
547+
val symss = caseClass.primaryConstructor.paramSymss
533548
(constr.info: @unchecked) match
534549
case tl: PolyType =>
535550
val tvars = constrained(tl)
536551
val targs = for tvar <- tvars yield
537552
tvar.instantiate(fromBelow = false)
538-
(AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType])
553+
(AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType], symss(1))
539554
case mt: MethodType =>
540-
(classRef, mt)
555+
(classRef, mt, symss.head)
556+
557+
// Index of the first parameter marked `@unroll` or -1
558+
val unrolledFrom =
559+
constrSyms.indexWhere(_.hasAnnotation(defn.UnrollAnnot))
560+
561+
// `val arity = x$0.productArity`
562+
val arityDef: Option[ValDef] =
563+
if unrolledFrom != -1 then
564+
Some(SyntheticValDef(nme.arity, productParam.select(defn.Product_productArity).withSpan(ctx.owner.span.focus)))
565+
else None
566+
val arityRefTree = arityDef.map(vd => ref(vd.symbol))
541567

542568
// Create symbols for the vals corresponding to each parameter
543569
// If there are dependent parameters, the infos won't be correct yet.
@@ -550,16 +576,29 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
550576
bindingSyms.foreach: bindingSym =>
551577
bindingSym.info = bindingSym.info.substParams(constrMeth, bindingRefs)
552578

579+
def defaultGetterAtIndex(idx: Int): Tree =
580+
val defaultGetterPrefix = caseClass.primaryConstructor.name.toTermName
581+
ref(caseClass.companionModule).select(NameKinds.DefaultGetterName(defaultGetterPrefix, idx))
582+
553583
val bindingDefs = bindingSyms.zipWithIndex.map: (bindingSym, idx) =>
554-
ValDef(bindingSym,
555-
productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
556-
.ensureConforms(bindingSym.info))
584+
val selection = productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
585+
val rhs = (
586+
if unrolledFrom != -1 && idx >= unrolledFrom then
587+
If(arityRefTree.get.select(defn.Int_>).appliedTo(Literal(Constant(idx))),
588+
thenp =
589+
selection,
590+
elsep =
591+
defaultGetterAtIndex(idx))
592+
else
593+
selection
594+
).ensureConforms(bindingSym.info)
595+
ValDef(bindingSym, rhs)
557596

558597
val newArgs = bindingRefs.lazyZip(constrMeth.paramInfos).map: (bindingRef, paramInfo) =>
559598
val refTree = ref(bindingRef)
560599
if paramInfo.isRepeatedParam then ctx.typer.seqToRepeated(refTree) else refTree
561600
Block(
562-
bindingDefs,
601+
arityDef.toList ::: bindingDefs,
563602
New(newPrefix, newArgs)
564603
)
565604
end fromProductBody

compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala

Lines changed: 19 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -228,46 +228,9 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
228228
forwarderDef
229229
}
230230

231-
private def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
232-
cpy.DefDef(defdef)(
233-
name = defdef.name,
234-
paramss = defdef.paramss,
235-
tpt = defdef.tpt,
236-
rhs = Match(
237-
ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")),
238-
startParamIndices.map { paramIndex =>
239-
val Apply(select, args) = defdef.rhs: @unchecked
240-
CaseDef(
241-
Literal(Constant(paramIndex)),
242-
EmptyTree,
243-
Apply(
244-
select,
245-
args.take(paramIndex) ++
246-
Range(paramIndex, paramCount).map(n =>
247-
ref(defdef.symbol.owner.companionModule)
248-
.select(DefaultGetterName(defdef.symbol.owner.primaryConstructor.name.toTermName, n))
249-
)
250-
)
251-
)
252-
} :+ CaseDef(
253-
Underscore(defn.IntType),
254-
EmptyTree,
255-
defdef.rhs
256-
)
257-
)
258-
).setDefTree
259-
}
260-
261-
private enum Gen:
262-
case Substitute(origin: Symbol, newDef: DefDef)
263-
case Forwarders(origin: Symbol, forwarders: List[DefDef])
231+
case class Forwarders(origin: Symbol, forwarders: List[DefDef])
264232

265-
def origin: Symbol
266-
def extras: List[DefDef] = this match
267-
case Substitute(_, d) => d :: Nil
268-
case Forwarders(_, ds) => ds
269-
270-
private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Gen] = tree match {
233+
private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Forwarders] = tree match {
271234
case defdef: DefDef if defdef.paramss.nonEmpty =>
272235
import dotty.tools.dotc.core.NameOps.isConstructorName
273236

@@ -277,38 +240,29 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
277240
val isCaseApply =
278241
defdef.name == nme.apply && defdef.symbol.owner.companionClass.is(CaseClass)
279242

280-
val isCaseFromProduct = defdef.name == nme.fromProduct && defdef.symbol.owner.companionClass.is(CaseClass)
281-
282243
val annotated =
283244
if (isCaseCopy) defdef.symbol.owner.primaryConstructor
284245
else if (isCaseApply) defdef.symbol.owner.companionClass.primaryConstructor
285-
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
286246
else defdef.symbol
287247

288248
compute(annotated) match {
289249
case Nil => None
290250
case (paramClauseIndex, annotationIndices) :: Nil =>
291251
val paramCount = annotated.paramSymss(paramClauseIndex).size
292-
if isCaseFromProduct then
293-
Some(Gen.Substitute(
294-
origin = defdef.symbol,
295-
newDef = generateFromProduct(annotationIndices, paramCount, defdef)
296-
))
297-
else
298-
val generatedDefs =
299-
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
300-
indices.foldLeft(List.empty[DefDef]):
301-
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
302-
generateSingleForwarder(
303-
defdef,
304-
paramIndex,
305-
paramCount,
306-
nextParamIndex,
307-
paramClauseIndex,
308-
isCaseApply
309-
) :: defdefs
310-
case _ => unreachable("sliding with at least 2 elements")
311-
Some(Gen.Forwarders(origin = defdef.symbol, forwarders = generatedDefs))
252+
val generatedDefs =
253+
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
254+
indices.foldLeft(List.empty[DefDef]):
255+
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
256+
generateSingleForwarder(
257+
defdef,
258+
paramIndex,
259+
paramCount,
260+
nextParamIndex,
261+
paramClauseIndex,
262+
isCaseApply
263+
) :: defdefs
264+
case _ => unreachable("sliding with at least 2 elements")
265+
Some(Forwarders(origin = defdef.symbol, forwarders = generatedDefs))
312266

313267
case multiple =>
314268
report.error("Cannot have multiple parameter lists containing `@unroll` annotation", defdef.srcPos)
@@ -323,14 +277,12 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
323277
val generatedBody = tmpl.body.flatMap(generateSyntheticDefs(_, compute))
324278
val generatedConstr0 = generateSyntheticDefs(tmpl.constr, compute)
325279
val allGenerated = generatedBody ++ generatedConstr0
326-
val bodySubs = generatedBody.collect({ case s: Gen.Substitute => s.origin }).toSet
327-
val otherDecls = tmpl.body.filterNot(d => d.symbol.exists && bodySubs(d.symbol))
328280

329281
if allGenerated.nonEmpty then
330-
val byName = (tmpl.constr :: otherDecls).groupMap(_.symbol.name.toString)(_.symbol)
282+
val byName = (tmpl.constr :: tmpl.body).groupMap(_.symbol.name.toString)(_.symbol)
331283
for
332284
syntheticDefs <- allGenerated
333-
dcl <- syntheticDefs.extras
285+
dcl <- syntheticDefs.forwarders
334286
do
335287
val replaced = dcl.symbol
336288
byName.get(dcl.name.toString).foreach { syms =>
@@ -348,7 +300,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
348300
tmpl.parents,
349301
tmpl.derived,
350302
tmpl.self,
351-
otherDecls ++ allGenerated.flatMap(_.extras)
303+
tmpl.body ++ allGenerated.flatMap(_.forwarders)
352304
)
353305
}
354306

0 commit comments

Comments
 (0)