Skip to content

Commit f6529c4

Browse files
committed
Store capability class information in a hash map during cc
1 parent 83a409d commit f6529c4

File tree

5 files changed

+58
-35
lines changed

5 files changed

+58
-35
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -203,21 +203,6 @@ extension (tp: Type)
203203
case _ =>
204204
false
205205

206-
def isCapabilityClassRef(using Context) = tp.dealiasKeepAnnots match
207-
case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot)
208-
case _ => false
209-
210-
/** Check if the class has universal capability, which means:
211-
* 1. the class has a capability annotation,
212-
* 2. the class is an impure function type,
213-
* 3. or one of its base classes has universal capability.
214-
*/
215-
def hasUniversalCapability(using Context): Boolean = tp match
216-
case CapturingType(parent, ref) =>
217-
ref.isUniversal || parent.hasUniversalCapability
218-
case tp =>
219-
tp.isCapabilityClassRef || tp.parents.exists(_.hasUniversalCapability)
220-
221206
/** Drop @retains annotations everywhere */
222207
def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling
223208
val tm = new TypeMap:

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ class CheckCaptures extends Recheck, SymTransformer:
528528
*/
529529
def addParamArgRefinements(core: Type, initCs: CaptureSet): (Type, CaptureSet) =
530530
var refined: Type = core
531-
var allCaptures: CaptureSet = if core.hasUniversalCapability
531+
var allCaptures: CaptureSet = if setup.isCapabilityClassRef(core)
532532
then CaptureSet.universal else initCs
533533
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
534534
val getter = cls.info.member(getterName).suchThat(_.is(ParamAccessor)).symbol

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ trait SetupAPI:
2323
def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit
2424
def isPreCC(sym: Symbol)(using Context): Boolean
2525
def postCheck()(using Context): Unit
26+
def isCapabilityClassRef(tp: Type)(using Context): Boolean
2627

2728
object Setup:
2829

@@ -67,6 +68,31 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
6768
&& !sym.owner.is(CaptureChecked)
6869
&& !defn.isFunctionSymbol(sym.owner)
6970

71+
private val capabilityClassMap = new util.HashMap[Symbol, Boolean]
72+
73+
/** Check if the class is capability, which means:
74+
* 1. the class has a capability annotation,
75+
* 2. or at least one of its parent type has universal capability.
76+
*/
77+
def isCapabilityClassRef(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match
78+
case _: TypeRef | _: AppliedType =>
79+
val sym = tp.classSymbol
80+
def checkSym: Boolean =
81+
sym.hasAnnotation(defn.CapabilityAnnot)
82+
|| sym.info.parents.exists(hasUniversalCapability)
83+
sym.isClass && capabilityClassMap.getOrElseUpdate(sym, checkSym)
84+
case _ => false
85+
86+
private def hasUniversalCapability(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match
87+
case CapturingType(parent, refs) =>
88+
refs.isUniversal || hasUniversalCapability(parent)
89+
case AnnotatedType(parent, ann) =>
90+
if ann.symbol.isRetains then
91+
try ann.tree.toCaptureSet.isUniversal || hasUniversalCapability(parent)
92+
catch case ex: IllegalCaptureRef => false
93+
else hasUniversalCapability(parent)
94+
case tp => isCapabilityClassRef(tp)
95+
7096
private def fluidify(using Context) = new TypeMap with IdempotentCaptRefMap:
7197
def apply(t: Type): Type = t match
7298
case t: MethodType =>
@@ -292,7 +318,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
292318
this(t.underlying)
293319
case t =>
294320
// Map references to capability classes C to C^
295-
if t.hasUniversalCapability
321+
if isCapabilityClassRef(t)
296322
then CapturingType(t, defn.expandedUniversalSet, boxed = false)
297323
else recur(t)
298324
end expandAliases
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
class F1 extends (Int => Unit) {
2+
def apply(x: Int): Unit = ()
3+
}
4+
5+
class F2 extends (Int -> Unit) {
6+
def apply(x: Int): Unit = ()
7+
}
8+
9+
def test =
10+
val x1 = new (Int => Unit) {
11+
def apply(x: Int): Unit = ()
12+
}
13+
14+
val x2: Int -> Unit = new (Int => Unit) { // error
15+
def apply(x: Int): Unit = ()
16+
}
17+
18+
val x3: Int -> Unit = new (Int -> Unit) {
19+
def apply(x: Int): Unit = ()
20+
}
21+
22+
val y1: Int => Unit = new F1
23+
val y2: Int -> Unit = new F1 // error
24+
val y3: Int => Unit = new F2
25+
val y4: Int -> Unit = new F2
26+
27+
val z1 = () => ()
28+
val z2: () -> Unit = () => ()
29+
val z3: () -> Unit = z1
30+
val z4: () => Unit = () => ()

tests/neg-custom-args/captures/extending-impure-function.scala.scala

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)