Skip to content

fix #12634 - port sbt/zinc#979 - add sealedDescendants to zinc sealed children #12636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,66 @@ object SymDenotations {

annotations.collect { case Annotation.Child(child) => child }.reverse
end children

/** Recursively assemble all children of this symbol, Preserves order of insertion.
*/
final def sealedStrictDescendants(using Context): List[Symbol] =

@tailrec
def findLvlN(
explore: mutable.ArrayDeque[Symbol],
seen: util.HashSet[Symbol],
acc: mutable.ListBuffer[Symbol]
): List[Symbol] =
if explore.isEmpty then
acc.toList
else
val sym = explore.head
val explore1 = explore.dropInPlace(1)
val lvlN = sym.children
val notSeen = lvlN.filterConserve(!seen.contains(_))
if notSeen.isEmpty then
findLvlN(explore1, seen, acc)
else
findLvlN(explore1 ++= notSeen, {seen ++= notSeen; seen}, acc ++= notSeen)
end findLvlN

/** Scans through `explore` to see if there are recursive children.
* If a symbol in `explore` has children that are not contained in
* `lvl1`, fallback to `findLvlN`, or else return `lvl1`.
*/
@tailrec
def findLvl2(
lvl1: List[Symbol], explore: List[Symbol], seenOrNull: util.HashSet[Symbol] | Null
): List[Symbol] = explore match
case sym :: explore1 =>
val lvl2 = sym.children
if lvl2.isEmpty then // no children, scan rest of explore1
findLvl2(lvl1, explore1, seenOrNull)
else // check if we have seen the children before
val seen = // initialise the seen set if not already
if seenOrNull != null then seenOrNull
else util.HashSet.from(lvl1)
val notSeen = lvl2.filterConserve(!seen.contains(_))
if notSeen.isEmpty then // we found children, but we had already seen them, scan the rest of explore1
findLvl2(lvl1, explore1, seen)
else // found unseen recursive children, we should fallback to the loop
findLvlN(
explore = mutable.ArrayDeque.from(explore1).appendAll(notSeen),
seen = {seen ++= notSeen; seen},
acc = mutable.ListBuffer.from(lvl1).appendAll(notSeen)
)
case nil =>
lvl1
end findLvl2

val lvl1 = children
findLvl2(lvl1, lvl1, seenOrNull = null)
end sealedStrictDescendants

/** Same as `sealedStrictDescendants` but prepends this symbol as well.
*/
final def sealedDescendants(using Context): List[Symbol] = this.symbol :: sealedStrictDescendants
}

/** The contents of a class definition during a period
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ private class ExtractAPICollector(using Context) extends ThunkHolder {
val modifiers = apiModifiers(sym)
val anns = apiAnnotations(sym).toArray
val topLevel = sym.isTopLevelClass
val childrenOfSealedClass = sym.children.sorted(classFirstSort).map(c =>
val childrenOfSealedClass = sym.sealedDescendants.sorted(classFirstSort).map(c =>
if (c.isClass)
apiType(c.typeRef)
else
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/util/HashSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ object HashSet:
*/
inline val DenseLimit = 8

def from[T](xs: IterableOnce[T]): HashSet[T] =
val set = new HashSet[T]()
set ++= xs
set

/** A hash set that allows some privileged protected access to its internals
* @param initialCapacity Indicates the initial number of slots in the hash table.
* The actual number of slots is always a power of 2, so the
Expand Down
108 changes: 108 additions & 0 deletions compiler/test/dotty/tools/dotc/core/SealedDescendantsTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package dotty.tools.dotc.core

import dotty.tools.dotc.core.Contexts.{Context, ctx}
import dotty.tools.dotc.core.Symbols.*

import org.junit.Assert._
import org.junit.Test

import dotty.tools.DottyTest

class SealedDescendantsTest extends DottyTest {

@Test
def zincIssue979: Unit =
val source =
"""
sealed trait Z
sealed trait A extends Z
class B extends A
class C extends A
class D extends A
"""

expectedDescendents(source, "Z",
"Z" ::
"A" ::
"B" ::
"C" ::
"D" :: Nil
)
end zincIssue979

@Test
def enumOpt: Unit =
val source =
"""
enum Opt[+T] {
case Some(t: T)
case None
}
"""

expectedDescendents(source, "Opt",
"Opt" ::
"Some" ::
"None.type" :: Nil
)
end enumOpt

@Test
def hierarchicalSharedChildren: Unit =
// Q is a child of both Z and A and should appear once
// X is a child of both A and Q and should appear once
val source =
"""
sealed trait Z
sealed trait A extends Z
sealed trait Q extends A with Z
trait X extends A with Q
case object Y extends Q
"""

expectedDescendents(source, "Z",
"Z" ::
"A" ::
"Q" ::
"X" ::
"Y.type" :: Nil
)
end hierarchicalSharedChildren

@Test
def hierarchicalSharedChildrenB: Unit =
val source =
"""
sealed trait Z
case object A extends Z with D with E
sealed trait B extends Z
trait C extends B
sealed trait D extends B
sealed trait E extends D
"""

expectedDescendents(source, "Z",
"Z" ::
"A.type" ::
"B" ::
"C" ::
"D" ::
"E" :: Nil
)
end hierarchicalSharedChildrenB

def expectedDescendents(source: String, root: String, expected: List[String]) =
exploreRoot(source, root) { rootCls =>
val descendents = rootCls.sealedDescendants.map(sym => s"${sym.name}${if (sym.isTerm) ".type" else ""}")
assertEquals(expected.toString, descendents.toString)
}

def exploreRoot(source: String, root: String)(op: Context ?=> ClassSymbol => Unit) =
val source0 = source.linesIterator.map(_.trim).mkString("\n|")
val source1 = s"""package testsealeddescendants
|$source0""".stripMargin
checkCompile("typer", source1) { (_, context) =>
given Context = context
op(requiredClass(s"testsealeddescendants.$root"))
}
}
4 changes: 4 additions & 0 deletions sbt-test/source-dependencies/sealed-extends-sealed/A.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
sealed trait Z
sealed trait A extends Z
class B extends A
class C extends A
6 changes: 6 additions & 0 deletions sbt-test/source-dependencies/sealed-extends-sealed/App.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
object App {
def foo(z: Z) = z match {
case _: B =>
case _: C =>
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
sealed trait Z
sealed trait A extends Z
class B extends A
class C extends A
class D extends A
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import sbt._
import Keys._

object DottyInjectedPlugin extends AutoPlugin {
override def requires = plugins.JvmPlugin
override def trigger = allRequirements

override val projectSettings = Seq(
scalaVersion := sys.props("plugin.scalaVersion"),
scalacOptions ++= Seq("-source:3.0-migration", "-Xfatal-warnings")
)
}
8 changes: 8 additions & 0 deletions sbt-test/source-dependencies/sealed-extends-sealed/test
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
> compile

# Introduce a new class C that also extends A
$ copy-file changes/A.scala A.scala

# App.scala needs recompiling because the pattern match in it
# is no longer exhaustive, which emits a warning
-> compile