Skip to content

Commit 813e059

Browse files
committed
Simplify MainAnnotation interface
Remove the `Command` class and place the `argGetter`, `varargsGetter` and `run` methods directly in the `MainAnnotation` interface. Now `command` pre-processes the arguments which clearly states which strings will be used for each argument. This simplifies the implementation of the `MainAnnotation` methods.
1 parent e852aa7 commit 813e059

15 files changed

+511
-462
lines changed

compiler/src/dotty/tools/dotc/ast/MainProxies.scala

Lines changed: 64 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -143,29 +143,31 @@ object MainProxies {
143143
* */
144144
* @myMain(80) def f(
145145
* @myMain.Alias("myX") x: S,
146+
* y: S,
146147
* ys: T*
147148
* ) = ...
148149
*
149150
* would be translated to something like
150151
*
151152
* final class f {
152153
* static def main(args: Array[String]): Unit = {
153-
* val cmd = new myMain(80).command(
154-
* info = new CommandInfo(
155-
* name = "f",
156-
* documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
157-
* parameters = Seq(
158-
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
159-
* new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq())
160-
* )
154+
* val annotation = new myMain(80)
155+
* val info = new Info(
156+
* name = "f",
157+
* documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
158+
* parameters = Seq(
159+
* new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))),
160+
* new scala.annotation.MainAnnotation.Parameter("y", "S", true, false, "", Seq()),
161+
* new scala.annotation.MainAnnotation.Parameter("ys", "T", false, true, "all my params y", Seq())
161162
* )
162-
* args = args
163-
* )
164-
*
165-
* val args0: () => S = cmd.argGetter[S](0, None)
166-
* val args1: () => Seq[T] = cmd.varargGetter[T]
167-
*
168-
* cmd.run(() => f(args0(), args1()*))
163+
* ),
164+
* val command = annotation.command(info, args)
165+
* if command.isDefined then
166+
* val cmd = command.get
167+
* val args0: () => S = annotation.argGetter[S](info.parameters(0), cmd(0), None)
168+
* val args1: () => S = annotation.argGetter[S](info.parameters(1), mainArgs(1), Some(() => sum$default$1()))
169+
* val args2: () => Seq[T] = annotation.varargGetter[T](info.parameters(2), cmd.drop(2))
170+
* annotation.run(() => f(args0(), args1(), args2()*))
169171
* }
170172
* }
171173
*/
@@ -229,7 +231,7 @@ object MainProxies {
229231
*
230232
* A ParamInfo has the following shape
231233
* ```
232-
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
234+
* new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
233235
* ```
234236
*/
235237
def parameterInfos(mt: MethodType): List[Tree] =
@@ -252,33 +254,34 @@ object MainProxies {
252254
val constructorArgs = List(param, paramTypeStr, hasDefault, isRepeated, paramDoc)
253255
.map(value => Literal(Constant(value)))
254256

255-
New(TypeTree(defn.MainAnnotationParameterInfo.typeRef), List(constructorArgs :+ paramAnnots))
257+
New(TypeTree(defn.MainAnnotationParameter.typeRef), List(constructorArgs :+ paramAnnots))
256258

257259
end parameterInfos
258260

259261
/**
260262
* Creates a list of references and definitions of arguments.
261263
* The goal is to create the
262-
* `val args0: () => S = cmd.argGetter[S](0, None)`
264+
* `val args0: () => S = annotation.argGetter[S](0, cmd(0), None)`
263265
* part of the code.
264266
*/
265267
def argValDefs(mt: MethodType): List[ValDef] =
266268
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
267-
val argName = nme.args ++ idx.toString
268-
val isRepeated = formal.isRepeatedParam
269-
val formalType = if isRepeated then formal.argTypes.head else formal
270-
val getterName = if isRepeated then nme.varargGetter else nme.argGetter
271-
val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
272-
case None => ref(defn.NoneModule.termRef)
273-
case Some(dvSym) =>
274-
val value = unitToValue(ref(dvSym.termRef))
275-
Apply(ref(defn.SomeClass.companionModule.termRef), value)
276-
val argGetter0 = TypeApply(Select(Ident(nme.cmd), getterName), TypeTree(formalType) :: Nil)
277-
val argGetter =
278-
if isRepeated then argGetter0
279-
else Apply(argGetter0, List(Literal(Constant(idx)), defaultValueGetterOpt))
280-
281-
ValDef(argName, TypeTree(), argGetter)
269+
val argName = nme.args ++ idx.toString
270+
val isRepeated = formal.isRepeatedParam
271+
val formalType = if isRepeated then formal.argTypes.head else formal
272+
val getterName = if isRepeated then nme.varargGetter else nme.argGetter
273+
val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
274+
case None => ref(defn.NoneModule.termRef)
275+
case Some(dvSym) =>
276+
val value = unitToValue(ref(dvSym.termRef))
277+
Apply(ref(defn.SomeClass.companionModule.termRef), value)
278+
val argGetter0 = TypeApply(Select(Ident(nme.annotation), getterName), TypeTree(formalType) :: Nil)
279+
val index = Literal(Constant(idx))
280+
val paramInfo = Apply(Select(Ident(nme.info), nme.parameters), index)
281+
val argGetter =
282+
if isRepeated then Apply(argGetter0, List(paramInfo, Apply(Select(Ident(nme.cmd), nme.drop), List(index))))
283+
else Apply(argGetter0, List(paramInfo, Apply(Ident(nme.cmd), List(index)), defaultValueGetterOpt))
284+
ValDef(argName, TypeTree(), argGetter)
282285
end argValDefs
283286

284287

@@ -318,18 +321,39 @@ object MainProxies {
318321
val nameTree = Literal(Constant(mainFun.showName))
319322
val docTree = Literal(Constant(documentation.mainDoc))
320323
val paramInfos = Apply(ref(defn.SeqModule.termRef), parameterInfos)
321-
New(TypeTree(defn.MainAnnotationCommandInfo.typeRef), List(List(nameTree, docTree, paramInfos)))
324+
New(TypeTree(defn.MainAnnotationInfo.typeRef), List(List(nameTree, docTree, paramInfos)))
322325

323-
val cmd = ValDef(
324-
nme.cmd,
326+
val annotVal = ValDef(
327+
nme.annotation,
328+
TypeTree(),
329+
instantiateAnnotation(mainAnnot)
330+
)
331+
val infoVal = ValDef(
332+
nme.info,
333+
TypeTree(),
334+
cmdInfo
335+
)
336+
val command = ValDef(
337+
nme.command,
325338
TypeTree(),
326339
Apply(
327-
Select(instantiateAnnotation(mainAnnot), nme.command),
328-
List(cmdInfo, Ident(nme.args))
340+
Select(Ident(nme.annotation), nme.command),
341+
List(Ident(nme.info), Ident(nme.args))
329342
)
330343
)
331-
val run = Apply(Select(Ident(nme.cmd), nme.run), mainCall)
332-
val body = Block(cmdInfo :: cmd :: args, run)
344+
val argsVal = ValDef(
345+
nme.cmd,
346+
TypeTree(),
347+
Select(Ident(nme.command), nme.get)
348+
)
349+
val run = Apply(Select(Ident(nme.annotation), nme.run), mainCall)
350+
val body0 = If(
351+
Select(Ident(nme.command), nme.isDefined),
352+
Block(argsVal :: args, run),
353+
EmptyTree
354+
)
355+
val body = Block(List(annotVal, infoVal, command), body0) // TODO add `if (cmd.nonEmpty)`
356+
333357
val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree)
334358
.withFlags(Param)
335359
/** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -856,8 +856,8 @@ class Definitions {
856856
@tu lazy val XMLTopScopeModule: Symbol = requiredModule("scala.xml.TopScope")
857857

858858
@tu lazy val MainAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MainAnnotation")
859-
@tu lazy val MainAnnotationCommandInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.CommandInfo")
860-
@tu lazy val MainAnnotationParameterInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterInfo")
859+
@tu lazy val MainAnnotationInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Info")
860+
@tu lazy val MainAnnotationParameter: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Parameter")
861861
@tu lazy val MainAnnotationParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation")
862862
@tu lazy val MainAnnotationCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command")
863863

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ object StdNames {
543543
val ordinalDollar: N = "$ordinal"
544544
val ordinalDollar_ : N = "_$ordinal"
545545
val origin: N = "origin"
546+
val parameters: N = "parameters"
546547
val parts: N = "parts"
547548
val postfixOps: N = "postfixOps"
548549
val prefix : N = "prefix"

docs/_docs/reference/experimental/main-annotation.md

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,35 @@ When a users annotates a method with an annotation that extends `MainAnnotation`
1313
* @param first Fist number to sum
1414
* @param rest The rest of the numbers to sum
1515
*/
16-
@myMain def sum(first: Int, rest: Int*): Int = first + rest.sum
16+
@myMain def sum(first: Int, second: Int = 0, rest: Int*): Int = first + second + rest.sum
1717
```
1818

1919
```scala
2020
object foo {
2121
def main(args: Array[String]): Unit = {
22-
23-
val cmd = new myMain().command(
24-
info = new CommandInfo(
25-
name = "sum",
26-
documentation = "Sum all the numbers",
27-
parameters = Seq(
28-
new ParameterInfo("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()),
29-
new ParameterInfo("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq())
30-
)
31-
),
32-
args = args
22+
val mainAnnot = new myMain()
23+
val info = new Info(
24+
name = "foo.main",
25+
documentation = "Sum all the numbers",
26+
parameters = Seq(
27+
new Parameter("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()),
28+
new Parameter("second", "scala.Int", hasDefault=true, isVarargs=false, "", Seq()),
29+
new Parameter("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq())
30+
)
3331
)
34-
val args0 = cmd.argGetter[Int](0, None) // using a parser of Int
35-
val args1 = cmd.varargGetter[Int] // using a parser of Int
36-
cmd.run(() => sum(args0(), args1()*))
32+
val mainArgsOpt = mainAnnot.command(info, args)
33+
if mainArgsOpt.isDefined then
34+
val mainArgs = mainArgsOpt.get
35+
val args0 = mainAnnot.argGetter[Int](info.parameters(0), mainArgs(0), None) // using a parser of Int
36+
val args1 = mainAnnot.argGetter[Int](info.parameters(1), mainArgs(1), Some(() => sum$default$1())) // using a parser of Int
37+
val args2 = mainAnnot.varargGetter[Int](info.parameters(2), mainArgs.drop(2)) // using a parser of Int
38+
mainAnnot.run(() => sum(args0(), args1(), args2()*))
3739
}
3840
}
3941
```
4042

41-
The implementation of the `main` method first instantiates the annotation and then creates a `Command`.
42-
When creating the `Command`, the arguments can be checked and preprocessed.
43+
The implementation of the `main` method first instantiates the annotation and then call `command`.
44+
When calling the `command`, the arguments can be checked and preprocessed.
4345
Then it defines a series of argument getters calling `argGetter` for each parameter and `varargGetter` for the last one if it is a varargs. `argGetter` gets an optional lambda that computes the default argument.
4446
Finally, the `run` method is called to run the application. It receives a by-name argument that contains the call the annotated method with the instantiations arguments (using the lambdas from `argGetter`/`varargGetter`).
4547

@@ -50,42 +52,46 @@ Example of implementation of `myMain` that takes all arguments positionally. It
5052
// Parser used to parse command line arguments
5153
import scala.util.CommandLineParser.FromString[T]
5254

53-
// Result type of the annotated method is Int
54-
class myMain extends MainAnnotation:
55-
import MainAnnotation.{ ParameterInfo, Command }
55+
// Result type of the annotated method is Int and arguments are parsed using FromString
56+
@experimental class myMain extends MainAnnotation[FromString, Int]:
57+
import MainAnnotation.{ Info, Parameter }
5658

57-
/** A new command with arguments from `args` */
58-
def command(info: CommandInfo, args: Array[String]): Command[FromString, Int] =
59+
def command(info: Info, args: Seq[String]): Option[Seq[String]] =
5960
if args.contains("--help") then
6061
println(info.documentation)
61-
// TODO: Print documentation of the parameters
62-
System.exit(0)
63-
assert(info.parameters.forall(!_.hasDefault), "Default arguments are not supported")
64-
val (plainArgs, varargs) =
65-
if info.parameters.last.isVarargs then
66-
val numPlainArgs = info.parameters.length - 1
67-
assert(numPlainArgs <= args.length, "Not enough arguments")
68-
(args.take(numPlainArgs), args.drop(numPlainArgs))
62+
None // do not parse or run the program
63+
else if info.parameters.exists(_.hasDefault) then
64+
println("Default arguments are not supported")
65+
None
66+
else if info.hasVarargs then
67+
val numPlainArgs = info.parameters.length - 1
68+
if numPlainArgs <= args.length then
69+
println("Not enough arguments")
70+
None
71+
else
72+
Some(args)
73+
else
74+
if info.parameters.length <= args.length then
75+
println("Not enough arguments")
76+
None
77+
else if info.parameters.length >= args.length then
78+
println("Too many arguments")
79+
None
6980
else
70-
assert(info.parameters.length <= args.length, "Not enough arguments")
71-
assert(info.parameters.length >= args.length, "Too many arguments")
72-
(args, Array.empty[String])
73-
new MyCommand(plainArgs, varargs)
81+
Some(args)
7482

75-
@experimental
76-
class MyCommand(plainArgs: Seq[String], varargs: Seq[String]) extends Command[FromString, Int]:
83+
def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T =
84+
() => parser.fromString(arg)
7785

78-
def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T =
79-
() => parser.fromString(plainArgs(idx))
86+
def varargGetter[T](param: Parameter, args: Seq[String])(using parser: FromString[T]): () => Seq[T] =
87+
() => args.map(arg => parser.fromString(arg))
8088

81-
def varargGetter[T](using parser: FromString[T]): () => Seq[T] =
82-
() => varargs.map(arg => parser.fromString(arg))
89+
def run(program: () => Int): Unit =
90+
println("executing program")
8391

84-
def run(program: () => Int): Unit =
85-
println("executing program")
92+
try {
8693
val result = program()
8794
println("result: " + result)
8895
println("executed program")
89-
end MyCommand
9096
end myMain
9197
```

0 commit comments

Comments
 (0)