Skip to content

Commit 837ed3a

Browse files
authored
Carry and check universal capability from parents correctly (#20004)
Fix #18857 This PR checks universal capability from parent classes properly.
2 parents 8825b07 + f6529c4 commit 837ed3a

File tree

6 files changed

+77
-14
lines changed

6 files changed

+77
-14
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,6 @@ extension (tp: Type)
203203
case _ =>
204204
false
205205

206-
def isCapabilityClassRef(using Context) = tp.dealiasKeepAnnots match
207-
case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot)
208-
case _ => false
209-
210206
/** Drop @retains annotations everywhere */
211207
def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling
212208
val tm = new TypeMap:

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,8 @@ class CheckCaptures extends Recheck, SymTransformer:
537537
*/
538538
def addParamArgRefinements(core: Type, initCs: CaptureSet): (Type, CaptureSet) =
539539
var refined: Type = core
540-
var allCaptures: CaptureSet = initCs
540+
var allCaptures: CaptureSet = if setup.isCapabilityClassRef(core)
541+
then CaptureSet.universal else initCs
541542
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
542543
val getter = cls.info.member(getterName).suchThat(_.is(ParamAccessor)).symbol
543544
if getter.termRef.isTracked && !getter.is(Private) then

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ trait SetupAPI:
2323
def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit
2424
def isPreCC(sym: Symbol)(using Context): Boolean
2525
def postCheck()(using Context): Unit
26+
def isCapabilityClassRef(tp: Type)(using Context): Boolean
2627

2728
object Setup:
2829

@@ -67,6 +68,31 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
6768
&& !sym.owner.is(CaptureChecked)
6869
&& !defn.isFunctionSymbol(sym.owner)
6970

71+
private val capabilityClassMap = new util.HashMap[Symbol, Boolean]
72+
73+
/** Check if the class is capability, which means:
74+
* 1. the class has a capability annotation,
75+
* 2. or at least one of its parent type has universal capability.
76+
*/
77+
def isCapabilityClassRef(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match
78+
case _: TypeRef | _: AppliedType =>
79+
val sym = tp.classSymbol
80+
def checkSym: Boolean =
81+
sym.hasAnnotation(defn.CapabilityAnnot)
82+
|| sym.info.parents.exists(hasUniversalCapability)
83+
sym.isClass && capabilityClassMap.getOrElseUpdate(sym, checkSym)
84+
case _ => false
85+
86+
private def hasUniversalCapability(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match
87+
case CapturingType(parent, refs) =>
88+
refs.isUniversal || hasUniversalCapability(parent)
89+
case AnnotatedType(parent, ann) =>
90+
if ann.symbol.isRetains then
91+
try ann.tree.toCaptureSet.isUniversal || hasUniversalCapability(parent)
92+
catch case ex: IllegalCaptureRef => false
93+
else hasUniversalCapability(parent)
94+
case tp => isCapabilityClassRef(tp)
95+
7096
private def fluidify(using Context) = new TypeMap with IdempotentCaptRefMap:
7197
def apply(t: Type): Type = t match
7298
case t: MethodType =>
@@ -269,12 +295,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
269295
CapturingType(fntpe, cs, boxed = false)
270296
else fntpe
271297

272-
/** Map references to capability classes C to C^ */
273-
private def expandCapabilityClass(tp: Type): Type =
274-
if tp.isCapabilityClassRef
275-
then CapturingType(tp, defn.expandedUniversalSet, boxed = false)
276-
else tp
277-
278298
private def recur(t: Type): Type = normalizeCaptures(mapOver(t))
279299

280300
def apply(t: Type) =
@@ -297,7 +317,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
297317
case t: TypeVar =>
298318
this(t.underlying)
299319
case t =>
300-
if t.isCapabilityClassRef
320+
// Map references to capability classes C to C^
321+
if isCapabilityClassRef(t)
301322
then CapturingType(t, defn.expandedUniversalSet, boxed = false)
302323
else recur(t)
303324
end expandAliases

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
2323
import config.Config
2424
import reporting.*
2525
import collection.mutable
26-
import cc.{CapturingType, derivedCapturingType}
26+
import cc.{CapturingType, derivedCapturingType, stripCapturing}
2727

2828
import scala.annotation.internal.sharable
2929
import scala.compiletime.uninitialized
@@ -2225,7 +2225,7 @@ object SymDenotations {
22252225
tp match {
22262226
case tp @ TypeRef(prefix, _) =>
22272227
def foldGlb(bt: Type, ps: List[Type]): Type = ps match {
2228-
case p :: ps1 => foldGlb(bt & recur(p), ps1)
2228+
case p :: ps1 => foldGlb(bt & recur(p.stripCapturing), ps1)
22292229
case _ => bt
22302230
}
22312231

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import annotation.capability
2+
3+
class C1
4+
@capability class C2 extends C1
5+
class C3 extends C2
6+
7+
def test =
8+
val x1: C1 = new C1
9+
val x2: C1 = new C2 // error
10+
val x3: C1 = new C3 // error
11+
12+
val y1: C2 = new C2
13+
val y2: C2 = new C3
14+
15+
val z1: C3 = new C3
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
class F1 extends (Int => Unit) {
2+
def apply(x: Int): Unit = ()
3+
}
4+
5+
class F2 extends (Int -> Unit) {
6+
def apply(x: Int): Unit = ()
7+
}
8+
9+
def test =
10+
val x1 = new (Int => Unit) {
11+
def apply(x: Int): Unit = ()
12+
}
13+
14+
val x2: Int -> Unit = new (Int => Unit) { // error
15+
def apply(x: Int): Unit = ()
16+
}
17+
18+
val x3: Int -> Unit = new (Int -> Unit) {
19+
def apply(x: Int): Unit = ()
20+
}
21+
22+
val y1: Int => Unit = new F1
23+
val y2: Int -> Unit = new F1 // error
24+
val y3: Int => Unit = new F2
25+
val y4: Int -> Unit = new F2
26+
27+
val z1 = () => ()
28+
val z2: () -> Unit = () => ()
29+
val z3: () -> Unit = z1
30+
val z4: () => Unit = () => ()

0 commit comments

Comments
 (0)