Skip to content

Commit 434f066

Browse files
committed
Support goto-def on exported methods too
1 parent 4e532e9 commit 434f066

File tree

4 files changed

+66
-81
lines changed

4 files changed

+66
-81
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

+5
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,11 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11421142
buf.toList
11431143
}
11441144

1145+
def collectSubTrees[A](f: PartialFunction[Tree, A])(using Context): List[A] =
1146+
val buf = mutable.ListBuffer[A]()
1147+
foreachSubTree(f.runWith(buf += _)(_))
1148+
buf.toList
1149+
11451150
/** Set this tree as the `defTree` of its symbol and return this tree */
11461151
def setDefTree(using Context): ThisTree = {
11471152
val sym = tree.symbol

compiler/src/dotty/tools/dotc/core/Symbols.scala

+10-3
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ import DenotTransformers.*
1919
import StdNames.*
2020
import NameOps.*
2121
import NameKinds.LazyImplicitName
22-
import ast.tpd
23-
import tpd.{Tree, TreeProvider, TreeOps}
24-
import ast.TreeTypeMap
22+
import ast.*, tpd.*
2523
import Constants.Constant
2624
import Variances.Variance
2725
import reporting.Message
@@ -336,6 +334,15 @@ object Symbols extends SymUtils {
336334
denot.info.dropAlias.finalResultType.typeConstructor match
337335
case tp: NamedType => tp.symbol.sourceSymbol
338336
case _ => this
337+
else if denot.is(ExportedTerm) then
338+
val root = denot.maybeOwner match
339+
case cls: ClassSymbol => cls.rootTreeContaining(name.toString)
340+
case _ => EmptyTree
341+
val targets = root.collectSubTrees:
342+
case tree: DefDef if tree.name == name => methPart(tree.rhs).tpe
343+
targets.match
344+
case (tp: NamedType) :: _ => tp.symbol
345+
case _ => this
339346
else if (denot.is(Synthetic)) {
340347
val linked = denot.linkedClass
341348
if (linked.exists && !linked.is(Synthetic))

presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala

+9-14
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
1-
package dotty.tools.pc
1+
package dotty.tools
2+
package pc
23

34
import scala.annotation.tailrec
45

5-
import dotty.tools.dotc.ast.tpd
6-
import dotty.tools.dotc.ast.tpd.*
7-
import dotty.tools.dotc.ast.untpd
8-
import dotty.tools.dotc.core.Contexts.*
9-
import dotty.tools.dotc.core.Flags.*
10-
import dotty.tools.dotc.core.Names.Name
11-
import dotty.tools.dotc.core.StdNames
12-
import dotty.tools.dotc.core.Symbols.*
13-
import dotty.tools.dotc.core.Types.Type
14-
import dotty.tools.dotc.interactive.SourceTree
15-
import dotty.tools.dotc.util.SourceFile
16-
import dotty.tools.dotc.util.SourcePosition
6+
import dotc.*
7+
import ast.*, tpd.*
8+
import core.*, Contexts.*, Decorators.*, Flags.*, Names.*, Symbols.*, Types.*
9+
import interactive.*
10+
import util.*
11+
import util.SourcePosition
1712

1813
object MetalsInteractive:
1914

@@ -205,7 +200,7 @@ object MetalsInteractive:
205200
Nil
206201

207202
case path @ head :: tail =>
208-
if head.symbol.is(ExportedType) then
203+
if head.symbol.is(Exported) then
209204
val sym = head.symbol.sourceSymbol
210205
List((sym, sym.info))
211206
else if head.symbol.is(Synthetic) then

presentation-compiler/test/dotty/tools/pc/tests/definition/PcDefinitionSuite.scala

+42-64
Original file line numberDiff line numberDiff line change
@@ -199,89 +199,67 @@ class PcDefinitionSuite extends BasePcDefinitionSuite:
199199
|""".stripMargin
200200
)
201201

202-
@Test def `exportType1` =
203-
check(
204-
"""object enumerations:
205-
| trait <<SymbolKind>>
206-
| trait CymbalKind
207-
|
208-
|object all:
209-
| export enumerations.*
210-
|
211-
|@main def hello =
212-
| import all.SymbolKind
213-
| import enumerations.CymbalKind
214-
|
215-
| val x = new Symbo@@lKind {}
216-
| val y = new CymbalKind {}
202+
@Test def exportType0 =
203+
check(
204+
"""object Foo:
205+
| trait <<Cat>>
206+
|object Bar:
207+
| export Foo.*
208+
|class Test:
209+
| import Bar.*
210+
| def test = new Ca@@t {}
217211
|""".stripMargin
218212
)
219213

220-
@Test def `exportType1Wild` =
221-
check(
222-
"""object enumerations:
223-
| trait <<SymbolKind>>
224-
| trait CymbalKind
225-
|
226-
|object all:
227-
| export enumerations.SymbolKind
228-
|
229-
|@main def hello =
230-
| import all.SymbolKind
231-
| import enumerations.CymbalKind
232-
|
233-
| val x = new Symbo@@lKind {}
234-
| val y = new CymbalKind {}
214+
@Test def exportType1 =
215+
check(
216+
"""object Foo:
217+
| trait <<Cat>>[A]
218+
|object Bar:
219+
| export Foo.*
220+
|class Test:
221+
| import Bar.*
222+
| def test = new Ca@@t[Int] {}
235223
|""".stripMargin
236224
)
237225

238-
@Test def `exportTerm1` =
226+
@Test def exportTerm0Nullary =
239227
check(
240-
"""class BitMap
241-
|class Scanner:
242-
| def scan(): BitMap = ???
243-
|class Copier:
244-
| private val scanUnit = new Scanner
245-
| export scanUnit.<<scan>>
246-
| def t1 = sc@@an()
228+
"""trait Foo:
229+
| def <<meth>>: Int
230+
|class Bar(val foo: Foo):
231+
| export foo.*
232+
| def test(bar: Bar) = bar.me@@th
247233
|""".stripMargin
248234
)
249235

250-
@Test def `exportTerm2` =
236+
@Test def exportTerm0 =
251237
check(
252-
"""class BitMap
253-
|class Scanner:
254-
| def scan(): BitMap = ???
255-
|class Copier:
256-
| private val scanUnit = new Scanner
257-
| export scanUnit.<<scan>>
258-
|class Test:
259-
| def t2(cpy: Copier) = cpy.sc@@an()
238+
"""trait Foo:
239+
| def <<meth>>(): Int
240+
|class Bar(val foo: Foo):
241+
| export foo.*
242+
| def test(bar: Bar) = bar.me@@th()
260243
|""".stripMargin
261244
)
262245

263-
@Test def `exportTerm1Wild` =
246+
@Test def exportTerm1 =
264247
check(
265-
"""class BitMap
266-
|class Scanner:
267-
| def scan(): BitMap = ???
268-
|class Copier:
269-
| private val scanUnit = new Scanner
270-
| export scanUnit.<<*>>
271-
| def t1 = sc@@an()
248+
"""trait Foo:
249+
| def <<meth>>(x: Int): Int
250+
|class Bar(val foo: Foo):
251+
| export foo.*
252+
| def test(bar: Bar) = bar.me@@th(0)
272253
|""".stripMargin
273254
)
274255

275-
@Test def `exportTerm2Wild` =
256+
@Test def exportTerm1Poly =
276257
check(
277-
"""class BitMap
278-
|class Scanner:
279-
| def scan(): BitMap = ???
280-
|class Copier:
281-
| private val scanUnit = new Scanner
282-
| export scanUnit.<<*>>
283-
|class Test:
284-
| def t2(cpy: Copier) = cpy.sc@@an()
258+
"""trait Foo:
259+
| def <<meth>>[A](x: A): A
260+
|class Bar(val foo: Foo):
261+
| export foo.*
262+
| def test(bar: Bar) = bar.me@@th(0)
285263
|""".stripMargin
286264
)
287265

0 commit comments

Comments
 (0)