Skip to content

Commit 2a385ed

Browse files
committed
Simplify MainAnnotation interface
Remove the `Command` class and place the `argGetter`, `varargsGetter` and `run` methods directly in the `MainAnnotation` interface. Now `command` pre-processes the arguments which clearly states which strings will be used for each argument. This simplifies the implementation of the `MainAnnotation` methods.
1 parent e892e06 commit 2a385ed

17 files changed

+519
-464
lines changed

compiler/src/dotty/tools/dotc/ast/MainProxies.scala

Lines changed: 64 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -141,29 +141,31 @@ object MainProxies {
141141
* */
142142
* @myMain(80) def f(
143143
* @myMain.Alias("myX") x: S,
144+
* y: S,
144145
* ys: T*
145146
* ) = ...
146147
*
147148
* would be translated to something like
148149
*
149150
* final class f {
150151
* static def main(args: Array[String]): Unit = {
151-
* val cmd = new myMain(80).command(
152-
* info = new CommandInfo(
153-
* name = "f",
154-
* documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
155-
* parameters = Seq(
156-
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
157-
* new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq())
158-
* )
152+
* val annotation = new myMain(80)
153+
* val info = new Info(
154+
* name = "f",
155+
* documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
156+
* parameters = Seq(
157+
* new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))),
158+
* new scala.annotation.MainAnnotation.Parameter("y", "S", true, false, "", Seq()),
159+
* new scala.annotation.MainAnnotation.Parameter("ys", "T", false, true, "all my params y", Seq())
159160
* )
160-
* args = args
161-
* )
162-
*
163-
* val args0: () => S = cmd.argGetter[S](0, None)
164-
* val args1: () => Seq[T] = cmd.varargGetter[T]
165-
*
166-
* cmd.run(() => f(args0(), args1()*))
161+
* ),
162+
* val command = annotation.command(info, args)
163+
* if command.isDefined then
164+
* val cmd = command.get
165+
* val args0: () => S = annotation.argGetter[S](info.parameters(0), cmd(0), None)
166+
* val args1: () => S = annotation.argGetter[S](info.parameters(1), mainArgs(1), Some(() => sum$default$1()))
167+
* val args2: () => Seq[T] = annotation.varargGetter[T](info.parameters(2), cmd.drop(2))
168+
* annotation.run(() => f(args0(), args1(), args2()*))
167169
* }
168170
* }
169171
*/
@@ -227,7 +229,7 @@ object MainProxies {
227229
*
228230
* A ParamInfo has the following shape
229231
* ```
230-
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
232+
* new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
231233
* ```
232234
*/
233235
def parameterInfos(mt: MethodType): List[Tree] =
@@ -250,33 +252,34 @@ object MainProxies {
250252
val constructorArgs = List(param, paramTypeStr, hasDefault, isRepeated, paramDoc)
251253
.map(value => Literal(Constant(value)))
252254

253-
New(TypeTree(defn.MainAnnotationParameterInfo.typeRef), List(constructorArgs :+ paramAnnots))
255+
New(TypeTree(defn.MainAnnotationParameter.typeRef), List(constructorArgs :+ paramAnnots))
254256

255257
end parameterInfos
256258

257259
/**
258260
* Creates a list of references and definitions of arguments.
259261
* The goal is to create the
260-
* `val args0: () => S = cmd.argGetter[S](0, None)`
262+
* `val args0: () => S = annotation.argGetter[S](0, cmd(0), None)`
261263
* part of the code.
262264
*/
263265
def argValDefs(mt: MethodType): List[ValDef] =
264266
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
265-
val argName = nme.args ++ idx.toString
266-
val isRepeated = formal.isRepeatedParam
267-
val formalType = if isRepeated then formal.argTypes.head else formal
268-
val getterName = if isRepeated then nme.varargGetter else nme.argGetter
269-
val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
270-
case None => ref(defn.NoneModule.termRef)
271-
case Some(dvSym) =>
272-
val value = unitToValue(ref(dvSym.termRef))
273-
Apply(ref(defn.SomeClass.companionModule.termRef), value)
274-
val argGetter0 = TypeApply(Select(Ident(nme.cmd), getterName), TypeTree(formalType) :: Nil)
275-
val argGetter =
276-
if isRepeated then argGetter0
277-
else Apply(argGetter0, List(Literal(Constant(idx)), defaultValueGetterOpt))
278-
279-
ValDef(argName, TypeTree(), argGetter)
267+
val argName = nme.args ++ idx.toString
268+
val isRepeated = formal.isRepeatedParam
269+
val formalType = if isRepeated then formal.argTypes.head else formal
270+
val getterName = if isRepeated then nme.varargGetter else nme.argGetter
271+
val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
272+
case None => ref(defn.NoneModule.termRef)
273+
case Some(dvSym) =>
274+
val value = unitToValue(ref(dvSym.termRef))
275+
Apply(ref(defn.SomeClass.companionModule.termRef), value)
276+
val argGetter0 = TypeApply(Select(Ident(nme.annotation), getterName), TypeTree(formalType) :: Nil)
277+
val index = Literal(Constant(idx))
278+
val paramInfo = Apply(Select(Ident(nme.info), nme.parameters), index)
279+
val argGetter =
280+
if isRepeated then Apply(argGetter0, List(paramInfo, Apply(Select(Ident(nme.cmd), nme.drop), List(index))))
281+
else Apply(argGetter0, List(paramInfo, Apply(Ident(nme.cmd), List(index)), defaultValueGetterOpt))
282+
ValDef(argName, TypeTree(), argGetter)
280283
end argValDefs
281284

282285

@@ -316,18 +319,39 @@ object MainProxies {
316319
val nameTree = Literal(Constant(mainFun.showName))
317320
val docTree = Literal(Constant(documentation.mainDoc))
318321
val paramInfos = Apply(ref(defn.SeqModule.termRef), parameterInfos)
319-
New(TypeTree(defn.MainAnnotationCommandInfo.typeRef), List(List(nameTree, docTree, paramInfos)))
322+
New(TypeTree(defn.MainAnnotationInfo.typeRef), List(List(nameTree, docTree, paramInfos)))
320323

321-
val cmd = ValDef(
322-
nme.cmd,
324+
val annotVal = ValDef(
325+
nme.annotation,
326+
TypeTree(),
327+
instantiateAnnotation(mainAnnot)
328+
)
329+
val infoVal = ValDef(
330+
nme.info,
331+
TypeTree(),
332+
cmdInfo
333+
)
334+
val command = ValDef(
335+
nme.command,
323336
TypeTree(),
324337
Apply(
325-
Select(instantiateAnnotation(mainAnnot), nme.command),
326-
List(cmdInfo, Ident(nme.args))
338+
Select(Ident(nme.annotation), nme.command),
339+
List(Ident(nme.info), Ident(nme.args))
327340
)
328341
)
329-
val run = Apply(Select(Ident(nme.cmd), nme.run), mainCall)
330-
val body = Block(cmdInfo :: cmd :: args, run)
342+
val argsVal = ValDef(
343+
nme.cmd,
344+
TypeTree(),
345+
Select(Ident(nme.command), nme.get)
346+
)
347+
val run = Apply(Select(Ident(nme.annotation), nme.run), mainCall)
348+
val body0 = If(
349+
Select(Ident(nme.command), nme.isDefined),
350+
Block(argsVal :: args, run),
351+
EmptyTree
352+
)
353+
val body = Block(List(annotVal, infoVal, command), body0) // TODO add `if (cmd.nonEmpty)`
354+
331355
val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree)
332356
.withFlags(Param)
333357
/** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,8 +852,8 @@ class Definitions {
852852
@tu lazy val XMLTopScopeModule: Symbol = requiredModule("scala.xml.TopScope")
853853

854854
@tu lazy val MainAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MainAnnotation")
855-
@tu lazy val MainAnnotationCommandInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.CommandInfo")
856-
@tu lazy val MainAnnotationParameterInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterInfo")
855+
@tu lazy val MainAnnotationInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Info")
856+
@tu lazy val MainAnnotationParameter: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Parameter")
857857
@tu lazy val MainAnnotationParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation")
858858
@tu lazy val MainAnnotationCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command")
859859

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ object StdNames {
543543
val ordinalDollar: N = "$ordinal"
544544
val ordinalDollar_ : N = "_$ordinal"
545545
val origin: N = "origin"
546+
val parameters: N = "parameters"
546547
val parts: N = "parts"
547548
val postfixOps: N = "postfixOps"
548549
val prefix : N = "prefix"

docs/_docs/reference/experimental/main-annotation.md

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,35 @@ When a users annotates a method with an annotation that extends `MainAnnotation`
1313
* @param first Fist number to sum
1414
* @param rest The rest of the numbers to sum
1515
*/
16-
@myMain def sum(first: Int, rest: Int*): Int = first + rest.sum
16+
@myMain def sum(first: Int, second: Int = 0, rest: Int*): Int = first + second + rest.sum
1717
```
1818

1919
```scala
2020
object foo {
2121
def main(args: Array[String]): Unit = {
22-
23-
val cmd = new myMain().command(
24-
info = new CommandInfo(
25-
name = "sum",
26-
documentation = "Sum all the numbers",
27-
parameters = Seq(
28-
new ParameterInfo("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()),
29-
new ParameterInfo("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq())
30-
)
31-
),
32-
args = args
22+
val mainAnnot = new myMain()
23+
val info = new Info(
24+
name = "foo.main",
25+
documentation = "Sum all the numbers",
26+
parameters = Seq(
27+
new Parameter("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()),
28+
new Parameter("second", "scala.Int", hasDefault=true, isVarargs=false, "", Seq()),
29+
new Parameter("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq())
30+
)
3331
)
34-
val args0 = cmd.argGetter[Int](0, None) // using a parser of Int
35-
val args1 = cmd.varargGetter[Int] // using a parser of Int
36-
cmd.run(() => sum(args0(), args1()*))
32+
val mainArgsOpt = mainAnnot.command(info, args)
33+
if mainArgsOpt.isDefined then
34+
val mainArgs = mainArgsOpt.get
35+
val args0 = mainAnnot.argGetter[Int](info.parameters(0), mainArgs(0), None) // using a parser of Int
36+
val args1 = mainAnnot.argGetter[Int](info.parameters(1), mainArgs(1), Some(() => sum$default$1())) // using a parser of Int
37+
val args2 = mainAnnot.varargGetter[Int](info.parameters(2), mainArgs.drop(2)) // using a parser of Int
38+
mainAnnot.run(() => sum(args0(), args1(), args2()*))
3739
}
3840
}
3941
```
4042

41-
The implementation of the `main` method first instantiates the annotation and then creates a `Command`.
42-
When creating the `Command`, the arguments can be checked and preprocessed.
43+
The implementation of the `main` method first instantiates the annotation and then call `command`.
44+
When calling the `command`, the arguments can be checked and preprocessed.
4345
Then it defines a series of argument getters calling `argGetter` for each parameter and `varargGetter` for the last one if it is a varargs. `argGetter` gets an optional lambda that computes the default argument.
4446
Finally, the `run` method is called to run the application. It receives a by-name argument that contains the call the annotated method with the instantiations arguments (using the lambdas from `argGetter`/`varargGetter`).
4547

@@ -50,42 +52,46 @@ Example of implementation of `myMain` that takes all arguments positionally. It
5052
// Parser used to parse command line arguments
5153
import scala.util.CommandLineParser.FromString[T]
5254

53-
// Result type of the annotated method is Int
54-
class myMain extends MainAnnotation:
55-
import MainAnnotation.{ ParameterInfo, Command }
55+
// Result type of the annotated method is Int and arguments are parsed using FromString
56+
@experimental class myMain extends MainAnnotation[FromString, Int]:
57+
import MainAnnotation.{ Info, Parameter }
5658

57-
/** A new command with arguments from `args` */
58-
def command(info: CommandInfo, args: Array[String]): Command[FromString, Int] =
59+
def command(info: Info, args: Seq[String]): Option[Seq[String]] =
5960
if args.contains("--help") then
6061
println(info.documentation)
61-
// TODO: Print documentation of the parameters
62-
System.exit(0)
63-
assert(info.parameters.forall(!_.hasDefault), "Default arguments are not supported")
64-
val (plainArgs, varargs) =
65-
if info.parameters.last.isVarargs then
66-
val numPlainArgs = info.parameters.length - 1
67-
assert(numPlainArgs <= args.length, "Not enough arguments")
68-
(args.take(numPlainArgs), args.drop(numPlainArgs))
62+
None // do not parse or run the program
63+
else if info.parameters.exists(_.hasDefault) then
64+
println("Default arguments are not supported")
65+
None
66+
else if info.hasVarargs then
67+
val numPlainArgs = info.parameters.length - 1
68+
if numPlainArgs <= args.length then
69+
println("Not enough arguments")
70+
None
71+
else
72+
Some(args)
73+
else
74+
if info.parameters.length <= args.length then
75+
println("Not enough arguments")
76+
None
77+
else if info.parameters.length >= args.length then
78+
println("Too many arguments")
79+
None
6980
else
70-
assert(info.parameters.length <= args.length, "Not enough arguments")
71-
assert(info.parameters.length >= args.length, "Too many arguments")
72-
(args, Array.empty[String])
73-
new MyCommand(plainArgs, varargs)
81+
Some(args)
7482

75-
@experimental
76-
class MyCommand(plainArgs: Seq[String], varargs: Seq[String]) extends Command[FromString, Int]:
83+
def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T =
84+
() => parser.fromString(arg)
7785

78-
def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T =
79-
() => parser.fromString(plainArgs(idx))
86+
def varargGetter[T](param: Parameter, args: Seq[String])(using parser: FromString[T]): () => Seq[T] =
87+
() => args.map(arg => parser.fromString(arg))
8088

81-
def varargGetter[T](using parser: FromString[T]): () => Seq[T] =
82-
() => varargs.map(arg => parser.fromString(arg))
89+
def run(program: () => Int): Unit =
90+
println("executing program")
8391

84-
def run(program: () => Int): Unit =
85-
println("executing program")
92+
try {
8693
val result = program()
8794
println("result: " + result)
8895
println("executed program")
89-
end MyCommand
9096
end myMain
9197
```

0 commit comments

Comments
 (0)